"""测试 pypto.frontend.jit 对非 pypto.Tensor 参数的支持。"""
import os
import pypto
import torch
import torch_npu
# JIT kernel under test: three int32 tensor parameters (two inputs, one output)
# plus a plain, non-tensor ``scalar`` argument that defaults to 0.
# The computed value a + b + scalar is handed to ``out`` via ``out.move``.
@pypto.frontend.jit(runtime_options={"run_mode": pypto.RunMode.NPU})
def add_kernel(
    a: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_INT32),
    b: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_INT32),
    out: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_INT32),
    scalar=0,
):
    # This kernel always uses a 16 x 16 vector tile shape.
    pypto.set_vec_tile_shapes(16, 16)
    # Keep the original left-to-right evaluation order: (a + b) + scalar.
    total = a + b + scalar
    out.move(total)
def test_add_with_kwargs():
    """Call ``add_kernel`` with ``scalar`` passed by keyword and check the result.

    The NPU index is taken from the ``TILE_FWK_DEVICE_ID`` environment
    variable and falls back to 0 when the variable is unset.
    """
    # The environment yields a string while the fallback is an int. Parse once
    # so that set_device() and the device string always name the same device.
    device_id = int(os.environ.get("TILE_FWK_DEVICE_ID", 0))
    torch.npu.set_device(device_id)
    device = f"npu:{device_id}"

    a = torch.ones(32, 32, dtype=torch.int32, device=device)
    b = torch.ones(32, 32, dtype=torch.int32, device=device)
    r = torch.zeros(32, 32, dtype=torch.int32, device=device)

    add_kernel(a, b, r, scalar=1)

    assert r.shape == (32, 32)
    # 1 + 1 + 1 == 3 in every element. The data is int32, so compare exactly
    # instead of casting to float and using a tolerance-based check.
    expected = torch.full((32, 32), 3, dtype=torch.int32)
    assert torch.equal(r.cpu(), expected)
# JIT kernel under test: three int32 tensor parameters (two inputs, one output)
# plus two non-tensor arguments, ``tiling`` and ``scalar``.
# ``tiling`` is used for both arguments of set_vec_tile_shapes; ``scalar`` is
# added to a + b before the value is handed to ``out`` via ``out.move``.
# NOTE(review): ``tiling`` defaults to None and is forwarded unchanged; the
# callers in this file always pass an int -- confirm whether None is accepted.
@pypto.frontend.jit(runtime_options={"run_mode": pypto.RunMode.NPU})
def add_npu_with_tiling(
    a: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_INT32),
    b: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_INT32),
    out: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_INT32),
    tiling=None,
    scalar=1,
):
    # The same caller-supplied value is used for both tile-shape arguments.
    pypto.set_vec_tile_shapes(tiling, tiling)
    # Keep the original left-to-right evaluation order: (a + b) + scalar.
    total = a + b + scalar
    out.move(total)
def test_add_npu_with_tiling():
    """Call ``add_npu_with_tiling`` with positional and default non-tensor args.

    Two calls are made on the same inputs: one with ``tiling=32`` and
    ``scalar=2``, and one with ``tiling=16`` and ``scalar`` left at its
    default of 1. The NPU index is taken from the ``TILE_FWK_DEVICE_ID``
    environment variable and falls back to 0 when the variable is unset.
    """
    # The environment yields a string while the fallback is an int. Parse once
    # so that set_device() and the device string always name the same device.
    device_id = int(os.environ.get("TILE_FWK_DEVICE_ID", 0))
    torch.npu.set_device(device_id)
    device = f"npu:{device_id}"

    a = torch.ones(32, 32, dtype=torch.int32, device=device)
    b = torch.ones(32, 32, dtype=torch.int32, device=device)

    # tiling passed positionally, scalar by keyword: 1 + 1 + 2 == 4.
    r1 = torch.zeros(32, 32, dtype=torch.int32, device=device)
    add_npu_with_tiling(a, b, r1, 32, scalar=2)
    # torch.equal is exact and also checks the size, unlike allclose, which
    # would accept any broadcastable result.
    assert torch.equal(r1.cpu(), torch.full((32, 32), 4, dtype=torch.int32))

    # Different tiling, scalar left at its default of 1: 1 + 1 + 1 == 3.
    r2 = torch.zeros(32, 32, dtype=torch.int32, device=device)
    add_npu_with_tiling(a, b, r2, 16)
    assert torch.equal(r2.cpu(), torch.full((32, 32), 3, dtype=torch.int32))