import os
import pypto
import torch
import torch_npu
@pypto.frontend.jit()
def cust_dyn_func_add(a: pypto.Tensor[[], pypto.DT_INT32],
b: pypto.Tensor[[], pypto.DT_INT32],
c: pypto.Tensor[[], pypto.DT_INT32]):
pypto.set_vec_tile_shapes(32, 32)
c.move(a + b)
@pypto.frontend.jit()
def cust_dyn_func_sub(a: pypto.Tensor[[...], pypto.DT_INT32],
b: pypto.Tensor[[...], pypto.DT_INT32],
c: pypto.Tensor[[...], pypto.DT_INT32]):
pypto.set_vec_tile_shapes(32, 32)
c.move(a - b)
def device_run(is_run_add):
device_id = os.environ.get('TILE_FWK_DEVICE_ID', 0)
torch.npu.set_device(int(device_id))
tiling = 32
n, m = tiling * 1, tiling * 1
shape = (n, m)
a_rawdata = torch.ones((n, m)) * 2
a_data = a_rawdata.to(dtype=torch.int32, device=f'npu:{device_id}')
b_rawdata = torch.ones((n, m))
b_data = b_rawdata.to(dtype=torch.int32, device=f'npu:{device_id}')
if is_run_add:
add_result = torch.zeros(shape, dtype=torch.int32, device=f'npu:{device_id}')
cust_dyn_func_add(a_data, b_data, add_result)
torch_npu.npu.synchronize()
golden = torch.ones((n, m), dtype=torch.int32) * 3
assert torch.allclose(golden.int(), add_result.cpu(), atol=1e-5)
else:
sub_result = torch.zeros(shape, dtype=torch.int32, device=f'npu:{device_id}')
cust_dyn_func_sub(a_data, b_data, sub_result)
torch_npu.npu.synchronize()
golden = torch.ones((n, m))
assert torch.allclose(golden.int(), sub_result.cpu(), atol=1e-5)
def test_run_multi_jit():
device_run(True)
device_run(False)
device_run(True)