"""
"""
import os
import pytest
import torch
import torch_npu
import pypto
import numpy as np
# Verification settings shared by this file: they are handed to
# pypto.frontend.jit(verify_options=...) on the decorated kernels and to
# pypto.set_verify_options(**verify_options) in test_verify_set_options.
verify_options = dict(
    enable_pass_verify=True,
    pass_verify_save_tensor=True,
)
# Tiled element-wise FP16 kernel used to exercise pass verification.
#
# x, y : 2-D FP16 inputs (both dimensions declared pypto.STATIC).
# out  : 2-D FP16 output, written in place through pypto.assemble.
#
# The tensors are walked in 64x64 views; for every view whose row-block index
# is below 2 the kernel computes ((x * (x + y)) - y) * y, otherwise it copies
# the x view unchanged.
#
# Documentation is kept in comments (not a docstring) so that no extra
# statement is added to the body handed to the JIT frontend.
@pypto.frontend.jit(
    runtime_options={"run_mode": pypto.RunMode.NPU},
    verify_options=verify_options,
)
def add_dyn_kernel(
        x: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_FP16),
        y: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_FP16),
        out: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_FP16)):
    first_dim, second_dim = x.shape
    # view_shape: extent of each view taken from x / y.
    # tile_shape: passed to pypto.set_vec_tile_shapes before the arithmetic.
    view_shape, tile_shape = (64, 64), (32, 32)
    first_view_shape, second_view_shape = view_shape
    # Loop counts are ceil(dim / 64), so a trailing partial block is visited.
    # NOTE(review): the view extent stays 64x64 even when a dimension is not a
    # multiple of 64 (the test uses 72x144); how pypto.view / pypto.assemble
    # treat the out-of-range part is not visible here - confirm.
    for b_idx in pypto.loop(int(np.ceil(first_dim / view_shape[0])), name="LOOP_L0", idx_name="b_idx"):
        for s_idx in pypto.loop(int(np.ceil(second_dim / view_shape[1])), name="LOOP_L1", idx_name="s_idx"):
            # Views of x and y at the same (row, column) offset.
            tile_tensor_0 = pypto.view(
                x, view_shape,
                [b_idx * first_view_shape, s_idx * second_view_shape]
            )
            tile_tensor_1 = pypto.view(
                y, view_shape,
                [b_idx * first_view_shape, s_idx * second_view_shape]
            )
            pypto.set_vec_tile_shapes(*tile_shape)
            # Branch on the outer loop index: the first two row blocks get the
            # arithmetic expression, later row blocks pass x through.
            if b_idx < 2:
                res = ((tile_tensor_0 * (tile_tensor_0 + tile_tensor_1)) - tile_tensor_1) * tile_tensor_1
            else:
                res = tile_tensor_0
            # Write the result into `out` at the offset the views were read from.
            pypto.assemble(
                res,
                [b_idx * first_view_shape, s_idx * second_view_shape],
                out,
            )
            # Drop the per-iteration names before the next iteration.
            # NOTE(review): whether the frontend requires this `del` is not
            # visible here - confirm before removing it.
            del res, tile_tensor_0, tile_tensor_1
def test_verify_dyn():
    """Run add_dyn_kernel on a 72x144 FP16 input and compare with a torch golden.

    The NPU device index is read from the TILE_FWK_DEVICE_ID environment
    variable (default 0). The golden value is computed in float32 and cast
    back to FP16, then registered through pypto.set_verify_golden_data
    before the kernel is launched.
    """
    shape = [72, 144]
    device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
    device = f'npu:{device_id}'
    torch.npu.set_device(device_id)
    a = torch.rand(shape, dtype=torch.float16, device=device)
    b = torch.rand(shape, dtype=torch.float16, device=device)
    output_data = torch.zeros(shape, dtype=torch.float16, device=device)
    # With 72 rows and 64-row views there are only two row blocks, so every
    # block takes the kernel's `b_idx < 2` branch and this single formula
    # covers the whole output.
    golden = (((a.float() * (a.float() + b.float())) - b.float()) * b.float()).half()
    # Three entries for the kernel's three arguments; only the last one (the
    # output) carries golden data.
    # NOTE(review): presumably matched positionally to (x, y, out) - confirm.
    pypto.set_verify_golden_data(goldens=[None, None, golden.cpu()])
    add_dyn_kernel(a, b, output_data)
    # NOTE(review): torch.allclose is called with its default tolerances;
    # confirm they are intended for an FP16 result compared against a golden
    # computed in float32.
    assert torch.allclose(output_data, golden)
# Threshold kernel with verification enabled through the decorator.
#
# a   : 2-D FP16 input.
# out : 2-D FP32 output; receives 1.0 where pypto.ge(a, 0.5) holds and 0.0
#       elsewhere.
#
# Documentation is kept in comments (not a docstring) so that no extra
# statement is added to the body handed to the JIT frontend.
@pypto.frontend.jit(
    runtime_options={"run_mode": pypto.RunMode.NPU},
    verify_options=verify_options,
)
def cmp_where_kenrel(
        a: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_FP16),
        out: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_FP32)):
    # Single-iteration pypto loop wrapping the whole computation.
    for _ in pypto.loop(1):
        pypto.set_vec_tile_shapes(16, 16)
        mask = pypto.ge(a, 0.5)
        # Slice assignment writes the selected values into the output tensor.
        out[:] = pypto.where(mask, 1.0, 0.0)
def test_verify_where():
    """Run cmp_where_kenrel on a 64x64 FP16 input and compare with torch.where.

    The NPU device index is read from the TILE_FWK_DEVICE_ID environment
    variable (default 0). Input and golden are created on the host; the
    input and the zero-initialised output are then moved to the NPU for the
    kernel call, and the result is copied back for the comparison.
    """
    device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
    device = f"npu:{device_id}"
    torch.npu.set_device(device_id)
    a = torch.rand((64, 64), dtype=torch.float16)
    c = torch.zeros((64, 64))
    golden = torch.where(a >= 0.5, 1.0, 0.0)
    # Two entries for the kernel's two arguments; only the output has golden
    # data.
    # NOTE(review): presumably matched positionally to (a, out) - confirm.
    pypto.set_verify_golden_data(goldens=[None, golden])
    inputs = [a.to(device)]
    outputs = [c.to(device)]
    cmp_where_kenrel(*inputs, *outputs)
    assert torch.allclose(outputs[0].cpu(), golden)
# Threshold kernel declared WITHOUT verify_options on the decorator; the test
# that uses it configures verification through pypto.set_verify_options.
#
# a   : 2-D FP16 input.
# out : 2-D FP32 output; receives 1.0 where pypto.ge(a, 0.5) holds and 0.0
#       elsewhere.
#
# Documentation is kept in comments (not a docstring) so that no extra
# statement is added to the body handed to the JIT frontend.
@pypto.frontend.jit(runtime_options={"run_mode": pypto.RunMode.NPU})
def cmp_where_kenrel2(
        a: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_FP16),
        out: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_FP32)):
    # Single-iteration pypto loop wrapping the whole computation.
    for _ in pypto.loop(1):
        pypto.set_vec_tile_shapes(16, 16)
        mask = pypto.ge(a, 0.5)
        # Slice assignment writes the selected values into the output tensor.
        out[:] = pypto.where(mask, 1.0, 0.0)
def test_verify_set_options():
    """Run cmp_where_kenrel2 with verification enabled via set_verify_options.

    Same data flow as the decorator-configured where test, except that the
    verify options are applied with pypto.set_verify_options(**verify_options)
    instead of being passed to pypto.frontend.jit.
    """
    device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
    device = f"npu:{device_id}"
    torch.npu.set_device(device_id)
    # NOTE(review): the options are set here and not restored afterwards;
    # confirm whether they persist into tests that run later in the session.
    pypto.set_verify_options(**verify_options)
    a = torch.rand((64, 64), dtype=torch.float16)
    c = torch.zeros((64, 64))
    golden = torch.where(a >= 0.5, 1.0, 0.0)
    # Two entries for the kernel's two arguments; only the output has golden
    # data.
    # NOTE(review): presumably matched positionally to (a, out) - confirm.
    pypto.set_verify_golden_data(goldens=[None, golden])
    inputs = [a.to(device)]
    outputs = [c.to(device)]
    cmp_where_kenrel2(*inputs, *outputs)
    assert torch.allclose(outputs[0].cpu(), golden)