import os
import math
from typing import List
import pypto
import torch
# Lookup table from a torch dtype to the matching pypto data-type constant.
# run_indexadd uses it to declare pypto tensors with the same element type as
# the torch inputs; a dtype that is not listed raises KeyError on lookup.
TORCH_TO_PTO_TYPES = {
    # integer types
    torch.int8: pypto.DT_INT8,
    torch.int16: pypto.DT_INT16,
    torch.int32: pypto.DT_INT32,
    # floating-point types
    torch.float16: pypto.DT_FP16,
    torch.float32: pypto.DT_FP32,
    torch.bfloat16: pypto.DT_BF16,
}
class IndexAddArgs:
    """Bundle of the non-tensor parameters consumed by the index-add builder.

    Attributes:
        axis: Dimension handed to ``pypto.index_add_``; also selects which
            view dimension the index tensor is sliced along.
        value: The ``alpha`` constructor argument, stored under this name and
            later forwarded as ``alpha=`` to ``pypto.index_add_``.
        view_shape: Size of the window taken from the source tensor on each
            loop iteration.
        tile_shape: Two sizes forwarded to ``pypto.set_vec_tile_shapes``.
    """

    def __init__(self, axis: int, alpha, view_shape, tile_shape):
        # Operator parameters; note that ``alpha`` is exposed as ``value``.
        self.axis, self.value = axis, alpha
        # Tiling parameters, stored as given (no copy is made).
        self.view_shape, self.tile_shape = view_shape, tile_shape
def indexadd_2dim_build(inputs: List[pypto.Tensor], args: IndexAddArgs):
    """Build the pypto function "INDEXADD" for a 2-D, window-by-window index-add.

    The source tensor ``inputs[1]`` is walked in ``args.view_shape``-sized
    windows. For each window, a view of the source and the matching slice of
    the index tensor ``inputs[2]`` are passed to ``pypto.index_add_`` with
    ``inputs[0]`` as its first (target) argument.

    Args:
        inputs: Three pypto tensors, used as ``[target, source, index]``.
        args: Supplies ``axis``, ``value`` (forwarded as ``alpha``),
            ``view_shape`` and ``tile_shape``.
    """
    full_shape = inputs[1].shape
    window = args.view_shape
    tile = args.tile_shape
    dim = args.axis
    scale = args.value

    # Number of windows per dimension, rounded up so a partial trailing
    # window is still visited.
    row_blocks = math.ceil(full_shape[0] / window[0])
    col_blocks = math.ceil(full_shape[1] / window[1])

    with pypto.function("INDEXADD", inputs[0], inputs[1], inputs[2]):
        for b_idx in pypto.loop(row_blocks, name="LOOP_B0", idx_name="b_idx"):
            for s_idx in pypto.loop(col_blocks, name="LOOP_S0", idx_name="s_idx"):
                pypto.set_vec_tile_shapes(tile[0], tile[1])

                # Top-left corner of the current window inside the source.
                row_start = b_idx * window[0]
                col_start = s_idx * window[1]
                starts = [row_start, col_start]

                # Valid extent of the window: the full window size, clipped
                # to whatever remains of the source in each dimension.
                extents = [
                    pypto.min(full_shape[0] - row_start, window[0]),
                    pypto.min(full_shape[1] - col_start, window[1]),
                ]

                src_window = pypto.view(inputs[1], window, starts, valid_shape=extents)
                # The index tensor is 1-D and is sliced only along `dim`,
                # using that dimension's window size, start and valid extent.
                index_window = pypto.view(
                    inputs[2],
                    [window[dim]],
                    [starts[dim]],
                    valid_shape=[extents[dim]],
                )

                pypto.index_add_(inputs[0], dim, index_window, src_window, alpha=scale)

                # Drop the local references to the two views.
                del src_window, index_window
def run_indexadd(inputs: List[torch.Tensor], args: IndexAddArgs) -> None:
    """Build the index-add function and run it once on the NPU device.

    The device id is read from the ``TILE_FWK_DEVICE_ID`` environment
    variable and defaults to 3 when the variable is unset.

    Args:
        inputs: Three torch tensors, passed on in the order
            ``[inputs[0], inputs[1], inputs[2]]`` to the builder and to the
            device run. Each dtype must be a key of ``TORCH_TO_PTO_TYPES``.
        args: Parameters forwarded unchanged to ``indexadd_2dim_build``.

    Raises:
        KeyError: If an input tensor's dtype is not in ``TORCH_TO_PTO_TYPES``.
    """
    device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 3))
    torch.npu.set_device(device_id)
    pypto.runtime._device_init()
    # Everything after a successful init runs under try/finally so the device
    # runtime is always finalised, even when the build or the run raises.
    try:
        # Placeholder pypto tensors with the same shape/dtype as the torch
        # inputs; they are only used to build the function.
        inputs_tensors = [pypto.tensor(x.shape, TORCH_TO_PTO_TYPES[x.dtype]) for x in inputs]
        indexadd_2dim_build(inputs_tensors, args)

        # Wrap the actual torch tensors for the device run.
        pto_x1_tensor = pypto.from_torch(inputs[0], "x1_tensor")
        pto_x2_tensor = pypto.from_torch(inputs[1], "x2_tensor")
        pto_x3_tensor = pypto.from_torch(inputs[2], "x3_tensor")
        pypto.runtime._device_run_once_data_from_host(pto_x1_tensor, pto_x2_tensor, pto_x3_tensor)
    finally:
        pypto.runtime._device_fini()
def test_indexadd__onboard():
    """Compare the on-device index-add with ``torch.Tensor.index_add``.

    The source has 13 rows while the view window has 7, so the second window
    along axis 0 is only partially valid (6 rows). The tile shape [5, 8] does
    not evenly divide the [7, 28] view.
    """
    dim = 0
    scale = 1.3
    target_shape = [7, 28]
    source_shape = [13, 28]
    # One index entry per source slice along `dim`.
    idx_shape = [source_shape[dim]]
    window_shape = [7, 28]
    vec_tile_shape = [5, 8]
    args = IndexAddArgs(dim, scale, window_shape, vec_tile_shape)

    # Random values in [-100, 100); the creation order (target, source,
    # index) is kept so the RNG stream is consumed the same way.
    target = torch.rand(target_shape, dtype=torch.float32) * 200 - 100
    source = torch.rand(source_shape, dtype=torch.float32) * 200 - 100
    index = torch.randint(0, target_shape[dim], idx_shape, dtype=torch.int32)
    inputs = [target, source, index]

    # Reference result from the out-of-place torch op, taken before the run.
    expected = target.index_add(dim, index, source, alpha=scale)

    run_indexadd(inputs, args)

    # The result is read from inputs[0], i.e. the test expects the run to
    # have updated that tensor in place.
    actual = inputs[0]
    assert torch.allclose(actual.flatten(), expected.flatten(), rtol=1e-4, atol=1e-5)