"""Test pypto.frontend.jit kernel dynamic shape error behavior."""
import logging
import pytest
import torch
import pypto
# Runtime options shared by the JIT kernels in this module: run with RunMode.SIM.
SIM_RUNTIME_OPTIONS = dict(run_mode=pypto.RunMode.SIM)

# Message fragment the dynamic-shape tests pass as `match` to pytest.raises.
DYNAMIC_SHAPE_ERROR = (
    "Dynamic shape tensors are not allowed as operation operands"
)

# Configure root logging at INFO; force=True replaces any handlers that were
# already attached to the root logger before this module was imported.
logging.basicConfig(format="", level=logging.INFO, force=True)
# JIT kernel (compiled with SIM_RUNTIME_OPTIONS) that writes pypto.exp(a) into
# `out`. Both tensors are annotated with a DYNAMIC first dimension and a STATIC
# second dimension. test_kernel_dynamic_shape_error expects calling this kernel
# to raise an error matching DYNAMIC_SHAPE_ERROR.
@pypto.frontend.jit(runtime_options=SIM_RUNTIME_OPTIONS)
def kernel_with_dynamic(
    a: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_FP32),
    out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_FP32),
):
    # Vector tile configuration must be set before the vector op below.
    # NOTE(review): presumably one tile extent per tensor dimension - confirm.
    pypto.set_vec_tile_shapes(16, 16)
    # exp() consumes `a`, whose first dimension is DYNAMIC; this is the operand
    # the test expects to be rejected.
    out[:] = pypto.exp(a)
def test_kernel_dynamic_shape_error():
    """Calling the exp kernel with a dynamic-shape operand must raise.

    The raised exception's message has to match DYNAMIC_SHAPE_ERROR
    (described by the original test as the CheckTensorDynamicShape error).
    """
    dims = (1, 8)
    source = torch.ones(*dims, dtype=torch.float32)
    result = torch.zeros(*dims, dtype=torch.float32)

    with pytest.raises(Exception, match=DYNAMIC_SHAPE_ERROR):
        kernel_with_dynamic(source, result)
# JIT kernel (compiled with SIM_RUNTIME_OPTIONS) that writes
# pypto.matmul(a, b) into `out` with FP16 output dtype.
# `a` is annotated with a DYNAMIC leading dimension followed by `...`;
# `b` and `out` are annotated with an empty shape list.
# NOTE(review): the exact meaning of `...` and `[]` in pypto.Tensor annotations
# is not visible in this file - presumably "remaining dims unspecified"; confirm.
# test_matmul_dynamic_shape_error expects calling this kernel to raise an error
# matching DYNAMIC_SHAPE_ERROR.
@pypto.frontend.jit(runtime_options=SIM_RUNTIME_OPTIONS)
def matmul_kernel(
    a: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_FP16),
    b: pypto.Tensor([], pypto.DT_FP16),
    out: pypto.Tensor([], pypto.DT_FP16),
):
    # Cube (matmul) tile configuration; three [128, 128] pairs are passed.
    # NOTE(review): which pair maps to which matmul dimension is not shown here.
    pypto.set_cube_tile_shapes([128, 128], [128, 128], [128, 128])
    # `a` carries the DYNAMIC dimension and is used directly as a matmul operand.
    out[:] = pypto.matmul(a, b, out_dtype=pypto.DT_FP16)
def test_matmul_dynamic_shape_error():
    """Calling the matmul kernel with a dynamic-shape operand must raise.

    The raised exception's message has to match DYNAMIC_SHAPE_ERROR
    (described by the original test as the CheckTensorDynamicShape error).
    """
    rows = inner = cols = 128
    half = torch.float16

    # Operands are created in the same order as before (lhs, rhs, output).
    lhs = torch.rand([rows, inner], dtype=half)
    rhs = torch.rand([inner, cols], dtype=half)
    product = torch.zeros([rows, cols], dtype=half)

    with pytest.raises(Exception, match=DYNAMIC_SHAPE_ERROR):
        matmul_kernel(lhs, rhs, product)
# JIT kernel (compiled with SIM_RUNTIME_OPTIONS) that writes
# pypto.one_hot(a, 5) into `out`.
# `a` is an INT32 tensor annotated with a DYNAMIC leading dimension followed by
# `...`; `out` is annotated with an empty shape list.
# NOTE(review): the exact meaning of `...` and `[]` in pypto.Tensor annotations
# is not visible in this file - confirm against the pypto frontend docs.
# test_one_hot_dynamic_shape_error expects calling this kernel to raise an
# error matching DYNAMIC_SHAPE_ERROR.
@pypto.frontend.jit(runtime_options=SIM_RUNTIME_OPTIONS)
def one_hot_kernel(
    a: pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_INT32),
    out: pypto.Tensor([], pypto.DT_INT32),
):
    # Vector tile configuration with three extents.
    # NOTE(review): how these extents map onto the one_hot operands is not
    # shown here - confirm.
    pypto.set_vec_tile_shapes(4, 5, 32)
    # The second argument (5) is fixed; the test's output tensor is [3, 5].
    out[:] = pypto.one_hot(a, 5)
def test_one_hot_dynamic_shape_error():
    """Calling the one_hot kernel with a dynamic-shape operand must raise.

    The raised exception's message has to match DYNAMIC_SHAPE_ERROR
    (described by the original test as the CheckTensorDynamicShape error).
    """
    indices = torch.tensor([0, 2, 4], dtype=torch.int32)
    # Three input indices, and 5 columns to match the 5 passed to
    # pypto.one_hot inside one_hot_kernel.
    encoded = torch.zeros([3, 5], dtype=torch.int32)

    with pytest.raises(Exception, match=DYNAMIC_SHAPE_ERROR):
        one_hot_kernel(indices, encoded)
# JIT kernel (compiled with SIM_RUNTIME_OPTIONS) that reshapes `q` using a
# target shape derived from q.shape[0], which is annotated DYNAMIC, and does so
# without inplace=True. test_dynamic_reshape_error expects calling this kernel
# to raise "reshape() requires integer shape when using non-inplace reshape".
@pypto.frontend.jit(runtime_options=SIM_RUNTIME_OPTIONS)
def kernel_dynamic_reshape(
    q: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
    out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_FP32),
):
    # Fixed sizes matching the (1, 128, 64) input built by the test.
    sq = 128
    d = 64
    # Leading dimension of `q`; annotated DYNAMIC in the signature.
    b = q.shape[0]
    # Non-inplace reshape whose first target extent (b * sq) depends on the
    # dynamic dimension - this is the call the test expects to be rejected.
    q_reshaped = pypto.reshape(q, [b * sq, d])
    pypto.set_vec_tile_shapes(64, 64)
    for idx in pypto.loop(b):
        # NOTE(review): presumably view(tensor, shape, offsets), i.e. an
        # [sq, d] window starting at row idx * sq - confirm.
        temp = pypto.view(q_reshaped, [sq, d], [idx * sq, 0])
        temp1 = pypto.exp(temp)
        # Place the result into `out` using the same [idx * sq, 0] offsets.
        pypto.assemble(temp1, [idx * sq, 0], out)
def test_dynamic_reshape_error():
    """A non-inplace reshape with a dynamic target shape must raise.

    kernel_dynamic_reshape calls pypto.reshape without inplace=True using a
    shape derived from the dynamic leading dimension; the resulting error
    message is matched against the pattern below.
    """
    expected_message = "reshape\\(\\) requires integer shape when using non-inplace reshape"

    query = torch.ones(1, 128, 64, dtype=torch.float32)
    result = torch.zeros(128, 64, dtype=torch.float32)

    with pytest.raises(Exception, match=expected_message):
        kernel_dynamic_reshape(query, result)
# JIT kernel (compiled with SIM_RUNTIME_OPTIONS) that takes a view of `q` at a
# loop-dependent offset (with valid_shape) and then reshapes that view with
# inplace=True. test_view_valid_shape_reshape_inplace_error expects calling
# this kernel to raise "Reshape(Inplace=true) is not supported for tensors
# derived from dynamic view".
@pypto.frontend.jit(runtime_options=SIM_RUNTIME_OPTIONS)
def kernel_view_valid_shape_reshape_inplace_error(
    q: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
    out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_FP32),
):
    # Fixed sizes matching the (1, 128, 64) input built by the test.
    sq = 128
    d = 64
    # Leading dimension of `q`; annotated DYNAMIC in the signature.
    b = q.shape[0]
    pypto.set_vec_tile_shapes(64, 64)
    for idx in pypto.loop(b):
        # NOTE(review): presumably view(tensor, shape, offsets, valid_shape=...),
        # i.e. a [1, sq, d] slice at batch index idx - confirm.
        q_view = pypto.view(q, [1, sq, d], [idx, 0, 0], valid_shape=[1, sq, d])
        # In-place reshape of the view produced above - this is the call the
        # test expects to be rejected.
        q_reshaped = pypto.reshape(q_view, [sq, d], valid_shape=[sq, d], inplace=True)
        pypto.assemble(q_reshaped, [idx * sq, 0], out)
def test_view_valid_shape_reshape_inplace_error():
    """An inplace reshape of a tensor derived from a dynamic view must raise.

    The kernel views `q` at a loop-dependent offset with valid_shape and then
    reshapes the view with inplace=True; the resulting error message is
    matched against the pattern below.
    """
    expected_message = "Reshape\\(Inplace=true\\) is not supported for tensors derived from dynamic view"

    query = torch.ones(1, 128, 64, dtype=torch.float32)
    result = torch.zeros(128, 64, dtype=torch.float32)

    with pytest.raises(Exception, match=expected_message):
        kernel_view_valid_shape_reshape_inplace_error(query, result)
if __name__ == "__main__":
    # Script entry point: run each test in declaration order, then report.
    for _run_case in (
        test_kernel_dynamic_shape_error,
        test_matmul_dynamic_shape_error,
        test_one_hot_dynamic_shape_error,
        test_dynamic_reshape_error,
        test_view_valid_shape_reshape_inplace_error,
    ):
        _run_case()
    logging.info("All dynamic shape error tests passed.")