"""
"""
import pypto
import pytest
import torch
def _to_device_tensor_data(tensor: torch.Tensor, shape=None, dtype=pypto.DT_FP32):
target_shape = list(tensor.shape) if shape is None else shape
return pypto.pypto_impl.DeviceTensorData(dtype, tensor.data_ptr(), target_shape)
@pytest.mark.skip(reason="Verify not supported")
def test_verify_dynamic_ops_assemble():
s = 32
n = 2
m = 1
t0 = pypto.tensor((n * s, m * s), pypto.DT_FP32)
t1 = pypto.tensor((n * s, m * s), pypto.DT_FP32)
out = pypto.tensor((n * s, m * s), pypto.DT_FP32)
t0_data = torch.ones((n * s, m * s))
t1_data = torch.ones((n * s, m * s)) * 2
out_data = torch.zeros((n * s, m * s))
golden = torch.ones((n * s, m * s)) * 3
pypto.set_verify_options(
enable_pass_verify=True
)
pypto.set_verify_golden_data([t0_data, t1_data, out_data], [t0_data, t1_data, golden])
with pypto.function("main", t0, t1, out):
pypto.set_vec_tile_shapes(8, 8)
for idx in pypto.loop(10):
pypto.pass_verify_print(t0)
t0a = pypto.view(t0, [s, s], [0, 0])
t0b = pypto.view(t0, [s, s], [s, 0])
t1a = pypto.view(t1, [s, s], [0, 0])
t1b = pypto.view(t1, [s, s], [s, 0])
t2a = t0a + t1a
t2b = t0b + t1b
pypto.assemble(t2a, [0, 0], out)
pypto.assemble(t2b, [s, 0], out)
pypto.pass_verify_save(out, "tensor_out_idx$idx", idx=idx)
def test_set_verify_data_dtype_mismatch_intercept():
input_tensor = torch.zeros((2, 3), dtype=torch.float32)
output_tensor = torch.zeros((2, 3), dtype=torch.float32)
input_golden = torch.zeros((2, 3), dtype=torch.float32)
output_golden = torch.zeros((2, 3), dtype=torch.float32)
input_data = _to_device_tensor_data(input_tensor, dtype=pypto.DT_FP32)
output_data = _to_device_tensor_data(output_tensor, dtype=pypto.DT_FP32)
input_golden_data = _to_device_tensor_data(input_golden, dtype=pypto.DT_FP32)
output_golden_data = _to_device_tensor_data(output_golden, dtype=pypto.DT_FP16)
with pytest.raises(Exception, match="ErrCode:\\s*FB4003"):
pypto.pypto_impl.SetVerifyData([input_data], [output_data], [input_golden_data, output_golden_data])
def test_set_verify_data_shape_mismatch_intercept():
input_tensor = torch.zeros((2, 3), dtype=torch.float32)
output_tensor = torch.zeros((2, 3), dtype=torch.float32)
input_golden = torch.zeros((2, 3), dtype=torch.float32)
output_golden = torch.zeros((2, 3), dtype=torch.float32)
input_data = _to_device_tensor_data(input_tensor)
input_golden_data = _to_device_tensor_data(input_golden)
output_golden_exact = _to_device_tensor_data(output_golden, [2, 3])
output_bad_shape = _to_device_tensor_data(output_tensor, [2, 4])
with pytest.raises(Exception, match="ErrCode:\\s*FB4002"):
pypto.pypto_impl.SetVerifyData([input_data], [output_bad_shape], [input_golden_data, output_golden_exact])