"""
"""
import os
import pypto
import torch
import torch_npu
import numpy as np
from numpy.testing import assert_allclose
# Problem sizes shared by the kernels and tests in this file.
B = 3  # trip count of the outer "b_loop"; first tensor dimension
S = 4  # trip count of the inner "s_loop"; second tensor dimension
N1 = 64  # third dimension of the tensors and of each [1, 1, N1, D] view
D = 64  # fourth dimension of the tensors and of each [1, 1, N1, D] view
# JIT kernel exercising pypto.is_loop_begin on the outer (batch) loop index.
# For every (b_idx, s_idx) pair it takes a [1, 1, N1, D] view of in_tensor and
# assembles a result into out_tensor at the same [b_idx, s_idx, 0, 0] offsets:
# the view plus 1.0 when is_loop_begin(b_idx) holds, otherwise the view
# multiplied by 1.0. test_is_loop_begin expects only batch index 0 to get +1.
@pypto.frontend.jit()
def dyn_loop_with_loop_begin(
    in_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
    out_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_FP32)
):
    # NOTE(review): tile shape (1, 1, 64, 64) presumably matches one
    # [1, 1, N1, D] view; confirm it should track N1/D if those change.
    pypto.set_vec_tile_shapes(1, 1, 64, 64)
    for b_idx in pypto.loop(B, name="b_loop", idx_name="b_idx"):
        for s_idx in pypto.loop(S, name="s_loop", idx_name="s_idx"):
            # One [1, 1, N1, D] slice of the input at the current (b, s) position.
            a0 = pypto.view(in_tensor, [1, 1, N1, D], [b_idx, s_idx, 0, 0])
            # The branch is keyed on the OUTER index, so it is the same for
            # every s_idx within a given b_idx.
            if pypto.is_loop_begin(b_idx):
                a1 = pypto.add(a0, 1.0)
                pypto.assemble(a1, [b_idx, s_idx, 0, 0], out_tensor)
            else:
                # Multiplying by 1.0 leaves the values unchanged; presumably it
                # exists so this branch also emits a computed tile — TODO confirm.
                a1 = pypto.mul(a0, 1.0)
                pypto.assemble(a1, [b_idx, s_idx, 0, 0], out_tensor)
def test_is_loop_begin():
    """Run dyn_loop_with_loop_begin on the NPU and compare against a host golden.

    The device index is read from the TILE_FWK_DEVICE_ID environment variable
    (default 0). The golden result is the random input with 1 added to batch
    index 0 only.
    """
    npu_index = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
    torch.npu.set_device(npu_index)
    torch.manual_seed(42)

    dims = (B, S, N1, D)
    target = f'npu:{npu_index}'
    src = torch.rand(dims, dtype=torch.float32, device=target)
    # Output starts as all ones; the golden compares every element, so the
    # kernel is expected to overwrite the whole buffer.
    dst = torch.ones(dims, dtype=torch.float32, device=target)

    dyn_loop_with_loop_begin(src, dst)
    # Wait for the device before copying results back to the host.
    torch_npu.npu.synchronize()

    actual = dst.cpu()
    expected = src.clone().cpu()
    expected[0:1] += 1
    assert torch.allclose(actual, expected, atol=1e-5)
# JIT kernel exercising pypto.is_loop_end on the outer (batch) loop index.
# For every (b_idx, s_idx) pair it takes a [1, 1, N1, D] view of in_tensor and
# assembles a result into out_tensor at the same [b_idx, s_idx, 0, 0] offsets:
# the view plus 1.0 when is_loop_end(b_idx) holds, otherwise the view
# multiplied by 1.0. test_is_loop_end expects only batch index B - 1 to get +1.
@pypto.frontend.jit()
def dyn_loop_with_loop_end(
    in_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
    out_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_FP32)
):
    # NOTE(review): tile shape (1, 1, 64, 64) presumably matches one
    # [1, 1, N1, D] view; confirm it should track N1/D if those change.
    pypto.set_vec_tile_shapes(1, 1, 64, 64)
    for b_idx in pypto.loop(B, name="b_loop", idx_name="b_idx"):
        for s_idx in pypto.loop(S, name="s_loop", idx_name="s_idx"):
            # One [1, 1, N1, D] slice of the input at the current (b, s) position.
            a0 = pypto.view(in_tensor, [1, 1, N1, D], [b_idx, s_idx, 0, 0])
            # The branch is keyed on the OUTER index, so it is the same for
            # every s_idx within a given b_idx.
            if pypto.is_loop_end(b_idx):
                a1 = pypto.add(a0, 1.0)
                pypto.assemble(a1, [b_idx, s_idx, 0, 0], out_tensor)
            else:
                # Multiplying by 1.0 leaves the values unchanged; presumably it
                # exists so this branch also emits a computed tile — TODO confirm.
                a1 = pypto.mul(a0, 1.0)
                pypto.assemble(a1, [b_idx, s_idx, 0, 0], out_tensor)
def test_is_loop_end():
    """Run dyn_loop_with_loop_end on the NPU and compare against a host golden.

    The device index is read from the TILE_FWK_DEVICE_ID environment variable
    (default 0). The golden result is the random input with 1 added to the
    last batch index (B - 1) only.
    """
    dev_no = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
    torch.npu.set_device(dev_no)
    torch.manual_seed(42)

    tensor_shape = (B, S, N1, D)
    npu_name = f'npu:{dev_no}'
    kernel_in = torch.rand(tensor_shape, dtype=torch.float32, device=npu_name)
    # Output starts as all ones; the golden compares every element, so the
    # kernel is expected to overwrite the whole buffer.
    kernel_out = torch.ones(tensor_shape, dtype=torch.float32, device=npu_name)

    dyn_loop_with_loop_end(kernel_in, kernel_out)
    # Wait for the device before copying results back to the host.
    torch_npu.npu.synchronize()

    got = kernel_out.cpu()
    want = kernel_in.clone().cpu()
    want[B - 1:B] += 1
    assert torch.allclose(got, want, atol=1e-5)
if __name__ == "__main__":
    # Script entry point: run both device tests in declaration order; the
    # first failing assertion stops the run.
    for run_case in (test_is_loop_begin, test_is_loop_end):
        run_case()