"""
"""
import pypto
def test_matrix_matmul():
    """Check the result shapes of ``pypto.matmul`` on 2-D FP32 tensors.

    Multiplies a (32, 64) tensor by a (64, 32) tensor, once as given and once
    with both ``a_trans`` and ``b_trans`` set, and asserts that each result is
    a ``pypto.tensor`` with the expected shape.
    """
    fp32 = pypto.DT_FP32
    lhs = pypto.tensor((32, 64), fp32, "A")
    rhs = pypto.tensor((64, 32), fp32, "B")
    plain = None
    with pypto.function("MATMUL", lhs, rhs):
        pypto.set_cube_tile_shapes([64, 64], [64, 64], [64, 64])
        plain = pypto.matmul(lhs, rhs, fp32)
        transposed = pypto.matmul(lhs, rhs, fp32, a_trans=True, b_trans=True)

    # Untransposed: (32, 64) x (64, 32) -> [32, 32].
    # Both transposed: (64, 32) x (32, 64) -> [64, 64].
    expectations = ((plain, [32, 32]), (transposed, [64, 64]))
    for result, expected_shape in expectations:
        assert isinstance(result, pypto.tensor)
        assert result.shape == expected_shape
def test_matrix_batch_matmul():
    """Check the result shapes of ``pypto.matmul`` on batched 3-D FP32 tensors.

    Uses a (2, 64, 32) and a (2, 32, 64) operand and asserts the output shape
    both for the plain product and for the product with ``a_trans`` and
    ``b_trans`` set.
    """
    fp32 = pypto.DT_FP32
    lhs = pypto.tensor((2, 64, 32), fp32, "A")
    rhs = pypto.tensor((2, 32, 64), fp32, "B")
    plain = None
    with pypto.function("BATCH_MATMUL", lhs, rhs):
        pypto.set_cube_tile_shapes([64, 64], [64, 64], [64, 64])
        plain = pypto.matmul(lhs, rhs, fp32)
        transposed = pypto.matmul(lhs, rhs, fp32, a_trans=True, b_trans=True)

    # The leading batch dimension (2) is carried through in both cases.
    expectations = ((plain, [2, 64, 64]), (transposed, [2, 32, 32]))
    for result, expected_shape in expectations:
        assert isinstance(result, pypto.tensor)
        assert result.shape == expected_shape
def test_matrix_matmul_with_syntactic_sugar():
    """Check that the ``@`` operator on two FP16 tensors yields a tensor.

    Asserts that ``a @ b`` for a (64, 32) and a (32, 64) operand is a
    ``pypto.tensor`` with dtype ``DT_FP16`` and shape ``[64, 64]``.
    """
    fp16 = pypto.DT_FP16
    lhs = pypto.tensor((64, 32), fp16, "A")
    rhs = pypto.tensor((32, 64), fp16, "B")
    product = None
    with pypto.function("MATMUL", lhs, rhs):
        pypto.set_cube_tile_shapes([64, 64], [64, 64], [64, 64])
        product = lhs @ rhs

    assert isinstance(product, pypto.tensor)
    # No output dtype is passed to the operator; the test expects FP16 back.
    assert product.dtype == pypto.DT_FP16
    assert product.shape == [64, 64]
def test_matrix_matmul_with_tensor_interface():
    """Check the ``tensor.matmul`` method form with INT8 inputs and INT32 output.

    Calls ``a.matmul(b, ...)`` on batched (3, 64, 32) and (3, 32, 64) operands
    with ``a_trans`` and ``b_trans`` set, then asserts the result type, dtype
    and shape.
    """
    int8 = pypto.DT_INT8
    int32 = pypto.DT_INT32
    lhs = pypto.tensor((3, 64, 32), int8, "A")
    rhs = pypto.tensor((3, 32, 64), int8, "B")
    product = None
    with pypto.function("BATCH_MATMUL", lhs, rhs):
        pypto.set_cube_tile_shapes([64, 64], [64, 64], [64, 64])
        product = lhs.matmul(rhs, int32, a_trans=True, b_trans=True)

    assert isinstance(product, pypto.tensor)
    # The result must carry the requested output dtype, not the input dtype.
    assert product.dtype == pypto.DT_INT32
    assert product.shape == [3, 32, 32]
def test_matrix_matmul_with_trans_mode_cast_rint():
    """Check ``pypto.matmul`` with ``trans_mode`` set to ``TransMode.CAST_RINT``.

    Passes the mode through ``extend_params`` for FP32 (64, 32) and (32, 64)
    operands and asserts the result is a ``pypto.tensor`` of shape
    ``[64, 64]``.
    """
    in_dtype = pypto.DT_FP32
    result_dtype = pypto.DT_FP32
    lhs = pypto.tensor((64, 32), in_dtype, "A")
    rhs = pypto.tensor((32, 64), in_dtype, "B")
    extra = {"trans_mode": pypto.TransMode.CAST_RINT}
    product = None
    with pypto.function("MATMUL", lhs, rhs):
        pypto.set_cube_tile_shapes([64, 64], [64, 64], [64, 64])
        product = pypto.matmul(lhs, rhs, result_dtype, extend_params=extra)

    assert isinstance(product, pypto.tensor)
    assert product.shape == [64, 64]
def test_matrix_matmul_with_trans_mode_cast_round():
    """Check ``pypto.matmul`` with ``trans_mode`` set to ``TransMode.CAST_ROUND``.

    Passes the mode through ``extend_params`` for FP32 (64, 32) and (32, 64)
    operands and asserts the result is a ``pypto.tensor`` of shape
    ``[64, 64]``.

    This test used to be named ``..._cast_rint``, duplicating the name of the
    CAST_RINT test defined earlier in the module. The later definition
    replaced the earlier one, so the CAST_RINT test was never collected. The
    name now matches the mode under test.
    """
    input_dtype = pypto.DT_FP32
    out_dtype = pypto.DT_FP32
    a = pypto.tensor((64, 32), input_dtype, "A")
    b = pypto.tensor((32, 64), input_dtype, "B")
    c = None
    with pypto.function("MATMUL", a, b):
        pypto.set_cube_tile_shapes([64, 64], [64, 64], [64, 64])
        c = pypto.matmul(
            a,
            b,
            out_dtype,
            extend_params={"trans_mode": pypto.TransMode.CAST_ROUND},
        )
    assert isinstance(c, pypto.tensor)
    assert c.shape == [64, 64]