"""
Cast+Matmul 融合算子 ST 测试脚本。
场景:先 Cast 输入到目标 dtype,再执行 Matmul。
支持 pytest 参数化执行和直接执行两种模式。
"""
import os
import sys
import pytest
import pypto
import torch
import numpy as np
from testcase.matmul_ub2l1_test_case import (
CastMatmulConfig,
CAST_RIGHT_MATMUL_TESTS,
CAST_LEFT_MATMUL_TESTS,
CAST_BOTH_MATMUL_TESTS,
)
# Cast+Matmul kernel under test.
#
# Walks the output in (m_view x n_view) blocks. For each block it slices the
# matching part of A and B, optionally casts each slice to
# config.matmul_pto_dtype, multiplies the two slices and stores the product
# into the corresponding region of out_tensor.
#
# Inputs:
#   a_tensor   - left operand; sliced along its second axis when
#                config.a_trans is set, otherwise along its first axis.
#   b_tensor   - right operand; sliced along its first axis when
#                config.b_trans is set, otherwise along its second axis.
#   out_tensor - destination, written block by block.
#   config     - CastMatmulConfig providing shape, view_shape, tile shapes,
#                dtypes and the cast/transpose flags read below.
@pypto.frontend.jit(debug_options={"runtime_debug_mode": 0, "compile_debug_mode": 0})
def cast_matmul_pto_kernel(
    a_tensor: pypto.Tensor(),
    b_tensor: pypto.Tensor(),
    out_tensor: pypto.Tensor(),
    config: CastMatmulConfig,
):
    # k is unpacked for symmetry but not used: the K axis is never tiled here.
    m, k, n = config.shape
    m_view, n_view = config.view_shape
    pypto.set_cube_tile_shapes(*config.cube_tile_shape)
    # Ceiling division: number of blocks needed to cover M and N.
    m_loop = (m + m_view - 1) // m_view
    n_loop = (n + n_view - 1) // n_view
    # NOTE(review): presumably opens a subgraph scope that spans both loops;
    # the meaning of 10000 is not visible in this file - confirm against pypto docs.
    pypto.set_pass_options(sg_set_scope=10000)
    for m_idx in pypto.loop(0, m_loop, 1, name="LOOP_L0_mIdx", idx_name="m_idx"):
        for n_idx in pypto.loop(0, n_loop, 1, name="LOOP_L0_nIdx", idx_name="n_idx"):
            # Casting to INT8 uses truncation; every other target uses CAST_NONE.
            mode = pypto.CastMode.CAST_NONE
            if config.matmul_pto_dtype == pypto.DT_INT8:
                mode = pypto.CastMode.CAST_TRUNC
            # Left operand: take the slice that belongs to output rows of block m_idx.
            if config.a_trans:
                a_tile = a_tensor[:, m_idx * m_view: m_idx * m_view + m_view]
            else:
                a_tile = a_tensor[m_idx * m_view: m_idx * m_view + m_view, :]
            if config.a_cast:
                # Vector tile shape is set immediately before the cast it applies to.
                pypto.set_vec_tile_shapes(*config.a_vec_tile_shape)
                a_compute = pypto.cast(a_tile, config.matmul_pto_dtype, mode)
            else:
                a_compute = a_tile
            # Right operand: take the slice that belongs to output columns of block n_idx.
            if config.b_trans:
                b_tile = b_tensor[n_idx * n_view: n_idx * n_view + n_view, :]
            else:
                b_tile = b_tensor[:, n_idx * n_view: n_idx * n_view + n_view]
            if config.b_cast:
                pypto.set_vec_tile_shapes(*config.b_vec_tile_shape)
                b_compute = pypto.cast(b_tile, config.matmul_pto_dtype, mode)
            else:
                b_compute = b_tile
            # The transpose flags are forwarded so matmul interprets the
            # slices in the same layout they were cut from.
            out_view = pypto.matmul(
                a_compute,
                b_compute,
                out_dtype=config.out_pto_dtype,
                a_trans=config.a_trans,
                b_trans=config.b_trans,
            )
            out_tensor[
                m_idx * m_view: m_idx * m_view + m_view,
                n_idx * n_view: n_idx * n_view + n_view,
            ] = out_view
    # NOTE(review): presumably clears the sg_set_scope override set before the
    # loops - confirm that -1 is the "unset" value.
    pypto.set_pass_options(sg_set_scope=-1)
def _make_random_input(shape, dtype):
    """Return a random CPU tensor: integers in [-5, 5] for int8, uniform [0, 1) otherwise."""
    if dtype == torch.int8:
        return torch.randint(-5, 6, shape, dtype=dtype)
    return torch.rand(shape, dtype=dtype)


def run_cast_matmul_test(case: dict):
    """Run one Cast+Matmul case on the NPU and compare it with a CPU golden.

    The NPU device index is read from the ``TILE_FWK_DEVICE_ID`` environment
    variable (default 0).

    Args:
        case: Test-case dict. Keys read directly here are ``a_input_dtype``,
            ``b_input_dtype``, ``matmul_dtype``, ``out_dtype``, ``id`` and
            ``name``; the dict is also passed to
            ``CastMatmulConfig.from_test_case``.

    Raises:
        AssertionError: If the kernel output differs from the golden by more
            than the tolerance returned by ``CastMatmulConfig.get_tolerance``.
    """
    device_id = int(os.environ.get("TILE_FWK_DEVICE_ID", 0))
    torch.npu.set_device(device_id)
    device = f"npu:{device_id}"

    config = CastMatmulConfig.from_test_case(case)
    m, k, n = config.shape
    # A transposed operand is stored with its two axes swapped.
    a_shape = [k, m] if config.a_trans else [m, k]
    b_shape = [n, k] if config.b_trans else [k, n]
    c_shape = [m, n]

    a_input_torch_dtype = CastMatmulConfig.get_torch_dtype(case["a_input_dtype"])
    b_input_torch_dtype = CastMatmulConfig.get_torch_dtype(case["b_input_dtype"])
    c_torch_dtype = CastMatmulConfig.get_torch_dtype(case["out_dtype"])

    # A is generated before B so the random stream is consumed in a fixed order.
    a_tensor_cpu = _make_random_input(a_shape, a_input_torch_dtype)
    b_tensor_cpu = _make_random_input(b_shape, b_input_torch_dtype)

    # Golden: replay the cast on the CPU, undo the storage transpose, then
    # multiply in a wide accumulator type before converting to the output dtype.
    # NOTE(review): if any case casts a *float* input to int8, torch.rand's
    # [0, 1) values truncate to 0, the golden becomes all zeros and matches the
    # zero-initialised output even if the kernel wrote nothing - consider
    # widening the input range for such cases.
    matmul_dtype = CastMatmulConfig.get_torch_dtype(case["matmul_dtype"])
    a_cpu = a_tensor_cpu.to(matmul_dtype)
    b_cpu = b_tensor_cpu.to(matmul_dtype)
    if config.a_trans:
        a_cpu = a_cpu.T
    if config.b_trans:
        b_cpu = b_cpu.T
    acc_dtype = torch.int32 if matmul_dtype == torch.int8 else torch.float32
    golden = torch.matmul(a_cpu.to(acc_dtype), b_cpu.to(acc_dtype)).to(c_torch_dtype)

    a_tensor = a_tensor_cpu.to(device)
    b_tensor = b_tensor_cpu.to(device)
    c_tensor = torch.zeros(c_shape, dtype=c_torch_dtype, device=device)
    cast_matmul_pto_kernel(a_tensor, b_tensor, c_tensor, config)

    atol, rtol = CastMatmulConfig.get_tolerance(case["out_dtype"])
    actual = c_tensor.cpu()
    # The message is only evaluated on failure, so the extra diff is free on success.
    assert torch.allclose(actual, golden, atol=atol, rtol=rtol), (
        f"Test case {case['id']} ({case['name']}) failed"
        f": max abs diff "
        f"{(actual.to(torch.float64) - golden.to(torch.float64)).abs().max().item()}"
        f" (atol={atol}, rtol={rtol})"
    )
# Every Cast+Matmul case, in order: cast-right, cast-left, then cast-both.
ALL_CAST_MATMUL_TESTS = CAST_RIGHT_MATMUL_TESTS + CAST_LEFT_MATMUL_TESTS + CAST_BOTH_MATMUL_TESTS
def _soc_param(case):
    """Wrap a case dict in a pytest param marked with the SoC products it lists."""
    return pytest.param(case, marks=pytest.mark.soc(*case["products"]))


@pytest.mark.parametrize("case", list(map(_soc_param, ALL_CAST_MATMUL_TESTS)))
def test_cast_matmul(case: dict):
    """Pytest entry point: execute a single Cast+Matmul case."""
    run_cast_matmul_test(case)
def run_cast_matmul_demo(run_mode):
    """Run a fixed-size Cast+Matmul demo: cast FP32 A to FP16, then A @ B in FP16.

    The problem size is 256 x 256 x 256 and the output is computed in
    128 x 128 blocks. In "npu" mode the device index is read from the
    ``TILE_FWK_DEVICE_ID`` environment variable (default 0), the same way
    ``run_cast_matmul_test`` selects its device; in "sim" mode tensors are
    created on the CPU.

    Args:
        run_mode: "npu" to run on the device, "sim" to run in the simulator.

    Returns:
        The FP16 output tensor of shape [256, 256] that was passed to the kernel.

    Raises:
        ValueError: If ``run_mode`` is neither "npu" nor "sim".
    """
    m_size, k_size, n_size = 256, 256, 256
    m_view_size, n_view_size = 128, 128

    if run_mode == "npu":
        mode = pypto.RunMode.NPU
    elif run_mode == "sim":
        mode = pypto.RunMode.SIM
    else:
        raise ValueError(f"Invalid run_mode: {run_mode}. Must be 'npu' or 'sim'")

    @pypto.frontend.jit(
        debug_options={"runtime_debug_mode": 0, "compile_debug_mode": 0},
        runtime_options={"run_mode": mode}
    )
    def cast_matmul_demo_kernel(
        a: pypto.Tensor([], pypto.DT_FP32),
        b: pypto.Tensor([], pypto.DT_FP16),
        out: pypto.Tensor([], pypto.DT_FP16),
    ):
        pypto.set_cube_tile_shapes([128, 128], [128, 128], [128, 128])
        # NOTE(review): unlike cast_matmul_pto_kernel, this kernel never sets
        # sg_set_scope back to -1 after the loops - confirm whether that matters.
        pypto.set_pass_options(sg_set_scope=10000)
        # Ceiling division: number of blocks needed to cover M and N.
        m_loop = (m_size + m_view_size - 1) // m_view_size
        n_loop = (n_size + n_view_size - 1) // n_view_size
        for m_idx in pypto.loop(0, m_loop, 1, name="LOOP_L0_mIdx", idx_name="m_idx"):
            for n_idx in pypto.loop(0, n_loop, 1, name="LOOP_L0_nIdx", idx_name="n_idx"):
                a_tile = a[m_idx * m_view_size: m_idx * m_view_size + m_view_size, :]
                # Vector tile shape covers the full (m_view_size, k_size) slice being cast.
                pypto.set_vec_tile_shapes(m_view_size, k_size)
                a_fp16_tile = pypto.cast(a_tile, pypto.DT_FP16)
                b_view = b[:, n_idx * n_view_size: n_idx * n_view_size + n_view_size]
                out_view = pypto.matmul(a_fp16_tile, b_view, pypto.DT_FP16)
                out[
                    m_idx * m_view_size: m_idx * m_view_size + m_view_size,
                    n_idx * n_view_size: n_idx * n_view_size + n_view_size,
                ] = out_view

    if run_mode == "npu":
        # Honour the same device-selection variable as run_cast_matmul_test
        # instead of hard-coding device 0; behaviour is unchanged when it is unset.
        device_id = int(os.environ.get("TILE_FWK_DEVICE_ID", 0))
        torch.npu.set_device(device_id)
        device = f"npu:{device_id}"
    else:
        device = "cpu"

    a = torch.randn([m_size, k_size], dtype=torch.float32, device=device)
    b = torch.randn([k_size, n_size], dtype=torch.float16, device=device)
    out = torch.empty(m_size, n_size, dtype=torch.float16, device=device)
    cast_matmul_demo_kernel(a, b, out)
    return out
# Direct execution (outside pytest) runs the fixed-size demo on the NPU.
if __name__ == "__main__":
    run_cast_matmul_demo("npu")