"""
"""
import os
import math
import copy
import numpy as np
import torch
import pypto
import pytest
import torch_npu
def test_scatterupdate_onboard():
device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
torch.npu.set_device(device_id)
b = 1
s = 1
n = 1
d = 8
block_num = 1
block_size = 1
src_shape = (b, s, n, d)
index_shape = (b, s)
dst_shape = (block_num, block_size, n, d)
view_shape = (b, s, n, d)
tile_shape = (b, s, n, d)
pypto.runtime._device_init()
src_tensor = pypto.tensor(src_shape, pypto.DT_INT32, "PTO_TENSOR_SRC")
index_tensor = pypto.tensor(index_shape, pypto.DT_INT32, "PTO_TENSOR_INDEX")
update_tensor = pypto.tensor(dst_shape, pypto.DT_INT32, "PTO_TENSOR_DST")
dst_tensor = pypto.tensor(dst_shape, pypto.DT_INT32, "PTO_TENSOR_DST")
b_loop_num = math.ceil(src_shape[0] / view_shape[0])
s_loop_num = math.ceil(src_shape[1] / view_shape[1])
with pypto.function("MAIN", src_tensor, index_tensor, update_tensor, dst_tensor):
for b_idx in pypto.loop(b_loop_num, name="b0", idx_name="bidx"):
for s_idx in pypto.loop(s_loop_num, name="s0", idx_name="sidx"):
tmp_dst_tensor = pypto.tensor(
dst_shape, pypto.DT_INT32, "PTO_TENSOR_TMP")
view_tensor_src = pypto.view(src_tensor, view_shape,
[b_idx * view_shape[0], s_idx *
view_shape[1], 0, 0],
valid_shape=[
pypto.min(pypto.symbolic_scalar(
src_shape[0]) - b_idx * view_shape[0],
pypto.symbolic_scalar(view_shape[0])),
pypto.min(pypto.symbolic_scalar(
src_shape[1]) - s_idx * view_shape[1],
pypto.symbolic_scalar(view_shape[1])),
n, d
]
)
view_tensor_index = pypto.view(index_tensor, [view_shape[0], view_shape[1]],
[b_idx * view_shape[0],
s_idx * view_shape[1]],
valid_shape=[
pypto.min(pypto.symbolic_scalar(index_shape[0]) - b_idx * view_shape[0],
pypto.symbolic_scalar(view_shape[0])),
pypto.min(pypto.symbolic_scalar(index_shape[1]) - s_idx * view_shape[1],
pypto.symbolic_scalar(view_shape[1])),
],
)
view_tensor_dst = pypto.view(
update_tensor, dst_shape, [0, 0, 0, 0])
pypto.set_vec_tile_shapes(
tile_shape[0], tile_shape[1], tile_shape[2], tile_shape[3])
tmp_dst_tensor.move(pypto.scatter_update(
view_tensor_dst, -2, view_tensor_index, view_tensor_src))
pypto.set_vec_tile_shapes(1, 64, n, d)
pypto.assemble(tmp_dst_tensor, [0, 0, 0, 0], dst_tensor)
assert isinstance(dst_tensor, pypto.tensor)
input0_tensor = np.random.uniform(2, 3, src_shape).astype(np.int32)
input1_tensor = np.random.choice(
range(0, dst_shape[0] * dst_shape[1]), index_shape, replace=False).astype(np.int32)
input2_tensor = np.random.uniform(1, 2, dst_shape).astype(np.int32)
result = copy.copy(input2_tensor)
d_data = np.zeros(dst_shape[0] * dst_shape[1]
* dst_shape[2] * dst_shape[3]).astype(np.int32)
a_tensor = torch.from_numpy(input0_tensor)
b_tensor = torch.from_numpy(input1_tensor)
c_tensor = torch.from_numpy(input2_tensor)
d_tensor = torch.from_numpy(d_data)
pto_a_tensor = pypto.from_torch(a_tensor, "a_tensor")
pto_b_tensor = pypto.from_torch(b_tensor, "b_tensor")
pto_c_tensor = pypto.from_torch(c_tensor, "c_tensor")
pto_d_tensor = pypto.from_torch(d_tensor, "d_tensor")
pypto.runtime._device_run_once_data_from_host(
pto_a_tensor, pto_b_tensor, pto_c_tensor, pto_d_tensor)
for _b in range(b):
for _s in range(s):
result[input1_tensor[_b][_s] // block_size][input1_tensor[_b][_s] % block_size][:] \
= input0_tensor[_b][_s][:]
result_t = torch.from_numpy(result).to(d_tensor.device).to(d_tensor.dtype)
result_t = result_t.reshape_as(d_tensor)
assert (d_tensor == result_t).all().item()
pypto.runtime._device_fini()