"""Test input tensor annotation check of @jit decorator."""
import os
import logging
import pypto
import torch
logging.basicConfig(level=logging.INFO, format='', force=True)
def create_compute_func(shape, dtype):
"""Factory function that creates a new compute function with annotations."""
tiling = 16
@pypto.frontend.jit()
def compute_add(
a: pypto.Tensor(shape, dtype),
b: pypto.Tensor(shape, dtype),
c: pypto.Tensor(shape, dtype),
):
pypto.set_cube_tile_shapes([tiling, tiling], [tiling, tiling], [tiling, tiling])
c[:] = pypto.matmul(a, b, a.dtype)
return compute_add
def test_correct_annotation():
"""Test that correct annotation causes no error."""
device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
torch.npu.set_device(device_id)
shape = (16, 16)
dtype = pypto.DT_FP32
kernel = create_compute_func(shape, dtype)
kernel(torch.rand(shape, dtype=torch.float32, device=f'npu:{device_id}'),
torch.rand(shape, dtype=torch.float32, device=f'npu:{device_id}'),
torch.rand(shape, dtype=torch.float32, device=f'npu:{device_id}'))
logging.info(f"✓ Verified: Correct annotation causes no error.")
def test_number_of_input_not_match():
"""Test that number of input not match causes an error."""
device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
torch.npu.set_device(device_id)
shape = (16, 16)
dtype = pypto.DT_FP32
kernel = create_compute_func(shape, dtype)
try:
kernel(torch.rand(shape, dtype=torch.float32, device=f'npu:{device_id}'),
torch.rand(shape, dtype=torch.float32, device=f'npu:{device_id}'))
except pypto.error.FeError as e:
logging.info(f"✓ Verified: Number of input not match causes an error: {e}")
else:
raise RuntimeError("Number of input not match causes no error.")
def test_incorrect_number_of_dimensions_annotation():
"""Test that incorrect number of dimensions annotation causes an error."""
device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
torch.npu.set_device(device_id)
shape = (16, 16)
dtype = pypto.DT_FP32
kernel = create_compute_func(shape, dtype)
try:
kernel(torch.rand((16, 16, 16), dtype=torch.float32, device=f'npu:{device_id}'),
torch.rand(shape, dtype=torch.float32, device=f'npu:{device_id}'),
torch.rand(shape, dtype=torch.float32, device=f'npu:{device_id}'))
except pypto.error.FeError as e:
logging.info(f"✓ Verified: Incorrect number of dimensions annotation causes an error: {e}")
else:
raise ValueError("Incorrect number of dimensions annotation causes no error.")
def test_incorrect_shape_annotation():
"""Test that incorrect shape annotation causes an error."""
device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
torch.npu.set_device(device_id)
shape = (16, 16)
dtype = pypto.DT_FP32
kernel = create_compute_func(shape, dtype)
try:
kernel(torch.rand((16, 18), dtype=torch.float32, device=f'npu:{device_id}'),
torch.rand((16, 16), dtype=torch.float32, device=f'npu:{device_id}'),
torch.rand((16, 16), dtype=torch.float32, device=f'npu:{device_id}'))
except pypto.error.FeError as e:
logging.info(f"✓ Verified: Incorrect shape annotation causes an error: {e}")
else:
raise ValueError("Incorrect shape annotation causes no error.")
def test_incorrect_dtype_annotation():
"""Test that incorrect annotation causes an error."""
device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
torch.npu.set_device(device_id)
shape = (16, 16)
dtype = pypto.DT_FP16
kernel = create_compute_func(shape, dtype)
try:
kernel(torch.rand(shape, dtype=torch.float16, device=f'npu:{device_id}'),
torch.rand(shape, dtype=torch.float32, device=f'npu:{device_id}'),
torch.rand(shape, dtype=torch.float16, device=f'npu:{device_id}'))
except pypto.error.FeError as e:
logging.info(f"✓ Verified: Incorrect dtype annotation causes an error: {e}")
else:
raise ValueError("Incorrect dtype annotation causes no error.")
if __name__ == "__main__":
logging.info("=" * 70)
logging.info("Testing @pypto.frontend.jit decorator annotation check")
logging.info("=" * 70)
test_correct_annotation()
test_number_of_input_not_match()
test_incorrect_number_of_dimensions_annotation()
test_incorrect_shape_annotation()
test_incorrect_dtype_annotation()
logging.info("\n" + "=" * 70)
logging.info("All tests passed!")
logging.info("=" * 70)