"""测试 pypto.frontend.jit 对非 pypto.Tensor 参数的支持。"""
import os
import pypto
import pytest
import torch
import torch_npu
@pytest.mark.skip("Failed for sync")
def test_add_with_kwargs_run():
    """Run a jit kernel that takes a non-tensor ``scalar`` keyword argument.

    Four int32 tensors of ones plus ``scalar=1`` are summed into ``res``;
    every element of the result is expected to equal 5.
    """

    # Inner function invoked from the jit entry point below.
    @pypto.frontend.function
    def sub_add_kernel(
        a: pypto.Tensor([], pypto.DT_INT32),
        b: pypto.Tensor(dtype=pypto.DT_INT32),
        c: pypto.Tensor([...], pypto.DT_INT32),
        d: pypto.Tensor([pypto.STATIC, 32], pypto.DT_INT32),
        res: pypto.Tensor([32, 32], pypto.DT_INT32),
        scalar=0
    ):
        res.move(a + b + c + d + scalar)

    # Jit entry point: sets the vector tile shape and forwards every argument.
    @pypto.frontend.jit(
        runtime_options={"run_mode": pypto.RunMode.NPU}
    )
    def add_kernel(
        a: pypto.Tensor(dtype=pypto.DT_INT32),
        b: pypto.Tensor([], pypto.DT_INT32),
        c: pypto.Tensor([...], pypto.DT_INT32),
        d: pypto.Tensor([pypto.STATIC, 32], pypto.DT_INT32),
        res: pypto.Tensor([32, 32], pypto.DT_INT32),
        scalar=0
    ):
        pypto.set_vec_tile_shapes(16, 16)
        sub_add_kernel(a, b, c, d, res, scalar)

    # The device index comes from the environment and defaults to 0.
    raw_id = os.environ.get('TILE_FWK_DEVICE_ID', 0)
    torch.npu.set_device(int(raw_id))
    target = f"npu:{raw_id}"

    # Four inputs and one output buffer, all 32x32 int32 tensors of ones.
    in_a, in_b, in_c, in_d, out = (
        torch.ones(32, 32, dtype=torch.int32, device=target) for _ in range(5)
    )

    add_kernel(in_a, in_b, in_c, in_d, out, scalar=1)

    assert out.shape == (32, 32)
    expected = torch.ones(32, 32) * 4 + 1
    assert torch.allclose(out.cpu().float(), expected)
def test_add_with_kwargs_check_stable():
    """Check that a tensor whose shape contradicts its annotation is rejected.

    ``d`` is annotated as ``[pypto.STATIC, 32]`` in ``add_kernel`` but is
    passed with shape ``(32, 31)``. The call must raise an exception whose
    message contains ``"does not match the shape"``.
    """

    # Inner function invoked from the jit entry point below.
    @pypto.frontend.function
    def sub_add_kernel(
        a: pypto.Tensor([], pypto.DT_INT32),
        b: pypto.Tensor(dtype=pypto.DT_INT32),
        c: pypto.Tensor([...], pypto.DT_INT32),
        d: pypto.Tensor([pypto.STATIC, 32], pypto.DT_INT32),
        res: pypto.Tensor([32, 32], pypto.DT_INT32),
        scalar=0
    ):
        res.move(a + b + c + d + scalar)

    # Jit entry point: sets the vector tile shape and forwards every argument.
    @pypto.frontend.jit(
        runtime_options={"run_mode": pypto.RunMode.NPU}
    )
    def add_kernel(
        a: pypto.Tensor(dtype=pypto.DT_INT32),
        b: pypto.Tensor([], pypto.DT_INT32),
        c: pypto.Tensor([...], pypto.DT_INT32),
        d: pypto.Tensor([pypto.STATIC, 32], pypto.DT_INT32),
        res: pypto.Tensor([32, 32], pypto.DT_INT32),
        scalar=0
    ):
        pypto.set_vec_tile_shapes(16, 16)
        sub_add_kernel(a, b, c, d, res, scalar)

    # The device index comes from the environment and defaults to 0.
    device_id = os.environ.get('TILE_FWK_DEVICE_ID', 0)
    torch.npu.set_device(int(device_id))
    a = torch.ones(32, 32, dtype=torch.int32, device=f"npu:{device_id}")
    b = torch.ones(32, 32, dtype=torch.int32, device=f"npu:{device_id}")
    c = torch.ones(32, 32, dtype=torch.int32, device=f"npu:{device_id}")
    # Deliberately 32x31: the last dimension disagrees with the annotation.
    d = torch.ones(32, 31, dtype=torch.int32, device=f"npu:{device_id}")
    res = torch.ones(32, 32, dtype=torch.int32, device=f"npu:{device_id}")

    # pytest.raises fails the test when no exception is raised; the previous
    # try/except passed silently in that case, so the check could never fail.
    with pytest.raises(Exception, match="does not match the shape"):
        add_kernel(a, b, c, d, res, scalar=1)
def test_add_with_kwargs_check_dtype():
    """Check that a tensor whose dtype contradicts its annotation is rejected.

    ``a`` is annotated as ``pypto.DT_INT32`` in ``add_kernel`` but is passed
    as a ``torch.float32`` tensor. The call must raise an exception whose
    message contains ``"does not match the dtype"``.
    """

    # Inner function invoked from the jit entry point below.
    @pypto.frontend.function
    def sub_add_kernel(
        a: pypto.Tensor([], pypto.DT_INT32),
        b: pypto.Tensor(dtype=pypto.DT_INT32),
        c: pypto.Tensor([...], pypto.DT_INT32),
        d: pypto.Tensor([pypto.STATIC, 32], pypto.DT_INT32),
        res: pypto.Tensor([32, 32], pypto.DT_INT32),
        scalar=0
    ):
        res.move(a + b + c + d + scalar)

    # Jit entry point: sets the vector tile shape and forwards every argument.
    @pypto.frontend.jit(
        runtime_options={"run_mode": pypto.RunMode.NPU}
    )
    def add_kernel(
        a: pypto.Tensor(dtype=pypto.DT_INT32),
        b: pypto.Tensor([], pypto.DT_INT32),
        c: pypto.Tensor([...], pypto.DT_INT32),
        d: pypto.Tensor([pypto.STATIC, 32], pypto.DT_INT32),
        res: pypto.Tensor([32, 32], pypto.DT_INT32),
        scalar=0
    ):
        pypto.set_vec_tile_shapes(16, 16)
        sub_add_kernel(a, b, c, d, res, scalar)

    # The device index comes from the environment and defaults to 0.
    device_id = os.environ.get('TILE_FWK_DEVICE_ID', 0)
    torch.npu.set_device(int(device_id))
    # Deliberately float32: the dtype disagrees with the DT_INT32 annotation.
    a = torch.ones(32, 32, dtype=torch.float32, device=f"npu:{device_id}")
    b = torch.ones(32, 32, dtype=torch.int32, device=f"npu:{device_id}")
    c = torch.ones(32, 32, dtype=torch.int32, device=f"npu:{device_id}")
    d = torch.ones(32, 32, dtype=torch.int32, device=f"npu:{device_id}")
    res = torch.ones(32, 32, dtype=torch.int32, device=f"npu:{device_id}")

    # pytest.raises fails the test when no exception is raised; the previous
    # try/except passed silently in that case, so the check could never fail.
    with pytest.raises(Exception, match="does not match the dtype"):
        add_kernel(a, b, c, d, res, scalar=1)