"""
"""
import os
import torch
import torch_npu
import pypto
# Kernel compiled through pypto's JIT frontend. It writes `a + b` into
# `result` in a first loop and then adds `b` once more in a second loop,
# changing the vector tile shape inside the first loop only.
# All three arguments are 2-D int32 pypto tensors with static dimensions;
# `result` is the output and is updated in place via `move`.
@pypto.frontend.jit()
def loop_scope(a: pypto.Tensor[[pypto.STATIC, pypto.STATIC], pypto.DT_INT32],
               b: pypto.Tensor[[pypto.STATIC, pypto.STATIC], pypto.DT_INT32],
               result: pypto.Tensor[[pypto.STATIC, pypto.STATIC], pypto.DT_INT32]):
    # Tile shape set at function level, before any loop is opened.
    pypto.set_vec_tile_shapes(64, 64)
    # First single-iteration loop: override the tile shape, then compute a + b.
    for _ in pypto.loop(1, name="s0", idx_name="k"):
        pypto.set_vec_tile_shapes(32, 32)
        result.move(a + b)
    # Second single-iteration loop (same name/idx_name as the first).
    for _ in pypto.loop(1, name="s0", idx_name="k"):
        # The 32x32 override from the previous loop must not be visible here.
        # NOTE(review): presumably tile-shape settings are scoped to the loop
        # body in which they are made -- confirm against the pypto docs.
        assert [64, 64] == pypto.get_vec_tile_shapes()
        result.move(result + b)
def test_loop_scope():
    """Run ``loop_scope`` on an NPU and verify the element-wise result.

    The device index is read from the ``TILE_FWK_DEVICE_ID`` environment
    variable (default ``0``). Inputs are 32x32 int32 tensors filled with
    2 (``a``) and 1 (``b``); the kernel computes ``a + b`` and then adds ``b``
    again, so every element of ``result`` is expected to equal ``a + 2 * b``.

    Raises:
        AssertionError: if the device result differs from the expected tensor.
    """
    # The environment yields a str while the fallback is an int; normalise
    # once so the same integer is used for set_device and the device string.
    device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
    torch.npu.set_device(device_id)
    device = f'npu:{device_id}'

    tiling = 32
    shape = (tiling, tiling)
    a_value, b_value = 2, 1

    a_data = torch.full(shape, a_value, dtype=torch.int32, device=device)
    b_data = torch.full(shape, b_value, dtype=torch.int32, device=device)
    result = torch.zeros(shape, dtype=torch.int32, device=device)

    loop_scope(a_data, b_data, result)
    # Wait for the device to finish before reading the output back.
    torch_npu.npu.synchronize()

    # First loop stores a + b, second loop adds b once more.
    golden = torch.full(shape, a_value + 2 * b_value, dtype=torch.int32)
    actual = result.cpu()
    # Integer tensors are compared exactly; a float tolerance has no meaning here.
    assert torch.equal(golden, actual), (
        f"loop_scope produced unexpected values: expected all "
        f"{a_value + 2 * b_value}, got min={actual.min().item()} "
        f"max={actual.max().item()}"
    )
# Allow running this file directly as a script, without a test runner.
if __name__ == "__main__":
    test_loop_scope()