"""
BatchMatmul BASIC_TESTS test script.
Supports both pytest and direct execution modes.
"""
import os
from dataclasses import dataclass
from typing import Any, Tuple

import pytest
import pypto
import torch
import torch_npu

from testcase.batchmatmul_test_case import BASIC_3D_TESTS, BASIC_4D_TESTS, BatchMatmulConfig
@dataclass
class SliceContext3D:
    """Slicing and tiling parameters for one operand of the 3D batch matmul."""
    # (start, end) bounds applied to the leading (batch) dimension of the operand.
    batch_slice: Tuple[int, int]
    # Start index of the tile along the operand's non-K matrix dimension.
    offset: int
    # Extent of the K dimension; the view helpers always take the full 0:k range.
    k: int
    # Tile length along the non-K dimension (slice is offset:offset + tile_size).
    tile_size: int
@dataclass
class SliceContext4D:
    """Slicing and tiling parameters for one operand of the 4D batch matmul."""
    # (start, end) bounds applied to the first batch dimension of the operand.
    b0_slice: Tuple[int, int]
    # (start, end) bounds applied to the second batch dimension of the operand.
    b1_slice: Tuple[int, int]
    # Start index of the tile along the operand's non-K matrix dimension.
    offset: int
    # Extent of the K dimension; the view helpers always take the full 0:k range.
    k: int
    # Tile length along the non-K dimension (slice is offset:offset + tile_size).
    tile_size: int
@dataclass
class TileProcessContext4D:
    """Unified bundle of the parameters needed to process tiles of the 4D batch matmul.

    One instance is created per (b0_idx, b1_idx) batch tile; ``m_idx`` and
    ``n_idx`` are overwritten by ``process_mn_loops_4d`` before each call to
    ``process_tile_4d``.
    """
    # Operand and output tensors. Annotated as ``Any`` (the typing construct,
    # not the builtin ``any`` function) because their concrete type is
    # supplied by the kernel caller.
    a_tensor: Any
    b_tensor: Any
    out_tensor: Any
    # Test configuration; process_tile_4d reads a_trans, b_trans and out_dtype from it.
    config: BatchMatmulConfig
    # Tile indices along the two batch dimensions (multiplied by tile_b0 / tile_b1).
    b0_idx: int
    b1_idx: int
    # Tile indices along M and N (multiplied by tile_m / tile_n).
    m_idx: int = 0
    n_idx: int = 0
    # Tile extents along the two batch dimensions and along M and N.
    tile_b0: int = 0
    tile_b1: int = 0
    tile_m: int = 0
    tile_n: int = 0
    # Full extent of the K dimension.
    k: int = 0
    # Sizes of the two batch dimensions of A and of B; passed to
    # get_batch_slice to choose the batch slice of each operand.
    b0_a: int = 0
    b1_a: int = 0
    b0_b: int = 0
    b1_b: int = 0
    # Number of tiles along M and N.
    m_loop: int = 0
    n_loop: int = 0
def get_batch_slice(batch_size, tile_size, offset, other_batch_size):
    """Return the (start, end) batch bounds for one operand of a tile.

    When this operand's batch size is 1 while the other operand's is not,
    the single batch entry ``(0, 1)`` is always selected; otherwise the
    bounds are ``(offset, offset + tile_size)``.
    """
    is_single_batch = batch_size == 1 and batch_size != other_batch_size
    if not is_single_batch:
        return offset, offset + tile_size
    return 0, 1
def get_a_view_3d(a_tensor, config, ctx: SliceContext3D):
    """Slice the A tile described by ``ctx`` out of a 3D tensor.

    The tiled range ``offset:offset + tile_size`` is applied to the last
    dimension when ``config.a_trans`` is set and to the middle dimension
    otherwise; the remaining matrix dimension takes the full ``0:k`` range.
    """
    batch_lo, batch_hi = ctx.batch_slice
    tile_lo = ctx.offset
    tile_hi = ctx.offset + ctx.tile_size
    if not config.a_trans:
        return a_tensor[batch_lo:batch_hi, tile_lo:tile_hi, 0:ctx.k]
    return a_tensor[batch_lo:batch_hi, 0:ctx.k, tile_lo:tile_hi]
def get_b_view_3d(b_tensor, config, ctx: SliceContext3D):
    """Slice the B tile described by ``ctx`` out of a 3D tensor.

    The tiled range ``offset:offset + tile_size`` is applied to the middle
    dimension when ``config.b_trans`` is set and to the last dimension
    otherwise; the remaining matrix dimension takes the full ``0:k`` range.
    """
    batch_lo, batch_hi = ctx.batch_slice
    tile_lo = ctx.offset
    tile_hi = ctx.offset + ctx.tile_size
    if not config.b_trans:
        return b_tensor[batch_lo:batch_hi, 0:ctx.k, tile_lo:tile_hi]
    return b_tensor[batch_lo:batch_hi, tile_lo:tile_hi, 0:ctx.k]
def get_a_view_4d(a_tensor, config, ctx: SliceContext4D):
    """Slice the A tile described by ``ctx`` out of a 4D tensor.

    The two leading dimensions are bounded by ``ctx.b0_slice`` and
    ``ctx.b1_slice``. The tiled range ``offset:offset + tile_size`` goes on
    the last dimension when ``config.a_trans`` is set and on the third
    dimension otherwise; the other matrix dimension takes ``0:k``.
    """
    outer_lo, outer_hi = ctx.b0_slice
    inner_lo, inner_hi = ctx.b1_slice
    tile_lo = ctx.offset
    tile_hi = ctx.offset + ctx.tile_size
    if not config.a_trans:
        return a_tensor[outer_lo:outer_hi, inner_lo:inner_hi, tile_lo:tile_hi, 0:ctx.k]
    return a_tensor[outer_lo:outer_hi, inner_lo:inner_hi, 0:ctx.k, tile_lo:tile_hi]
def get_b_view_4d(b_tensor, config, ctx: SliceContext4D):
    """Slice the B tile described by ``ctx`` out of a 4D tensor.

    The two leading dimensions are bounded by ``ctx.b0_slice`` and
    ``ctx.b1_slice``. The tiled range ``offset:offset + tile_size`` goes on
    the third dimension when ``config.b_trans`` is set and on the last
    dimension otherwise; the other matrix dimension takes ``0:k``.
    """
    outer_lo, outer_hi = ctx.b0_slice
    inner_lo, inner_hi = ctx.b1_slice
    tile_lo = ctx.offset
    tile_hi = ctx.offset + ctx.tile_size
    if not config.b_trans:
        return b_tensor[outer_lo:outer_hi, inner_lo:inner_hi, 0:ctx.k, tile_lo:tile_hi]
    return b_tensor[outer_lo:outer_hi, inner_lo:inner_hi, tile_lo:tile_hi, 0:ctx.k]
def process_tile_4d(ctx: TileProcessContext4D):
    """Compute the matmul of a single 4D tile and store it in the output.

    All inputs come from ``ctx``: tile indices are turned into element
    offsets, the matching A and B views are sliced, multiplied with
    ``pypto.matmul`` and written into the corresponding region of
    ``ctx.out_tensor``.
    """
    cfg = ctx.config

    # Element offsets of this tile along M, N and the two batch dimensions.
    row_start = ctx.m_idx * ctx.tile_m
    col_start = ctx.n_idx * ctx.tile_n
    outer_start = ctx.b0_idx * ctx.tile_b0
    inner_start = ctx.b1_idx * ctx.tile_b1

    # Batch bounds are resolved per operand because A and B may have
    # different batch sizes (see get_batch_slice).
    a_slice_ctx = SliceContext4D(
        b0_slice=get_batch_slice(ctx.b0_a, ctx.tile_b0, outer_start, ctx.b0_b),
        b1_slice=get_batch_slice(ctx.b1_a, ctx.tile_b1, inner_start, ctx.b1_b),
        offset=row_start,
        k=ctx.k,
        tile_size=ctx.tile_m,
    )
    b_slice_ctx = SliceContext4D(
        b0_slice=get_batch_slice(ctx.b0_b, ctx.tile_b0, outer_start, ctx.b0_a),
        b1_slice=get_batch_slice(ctx.b1_b, ctx.tile_b1, inner_start, ctx.b1_a),
        offset=col_start,
        k=ctx.k,
        tile_size=ctx.tile_n,
    )

    lhs = get_a_view_4d(ctx.a_tensor, cfg, a_slice_ctx)
    rhs = get_b_view_4d(ctx.b_tensor, cfg, b_slice_ctx)
    product = pypto.matmul(
        lhs,
        rhs,
        out_dtype=cfg.out_dtype,
        a_trans=cfg.a_trans,
        b_trans=cfg.b_trans,
    )

    ctx.out_tensor[
        outer_start:outer_start + ctx.tile_b0,
        inner_start:inner_start + ctx.tile_b1,
        row_start:row_start + ctx.tile_m,
        col_start:col_start + ctx.tile_n,
    ] = product
def process_mn_loops_4d(ctx: TileProcessContext4D):
    """Run the M/N tile loops for a single batch pair.

    Factored out of the 4D kernel to reduce the nesting depth of the main
    function. ``ctx`` is mutated in place: ``m_idx`` and ``n_idx`` are set to
    the current loop indices before each ``process_tile_4d`` call.
    """
    for m_idx in pypto.loop(0, ctx.m_loop, 1, name="LOOP_L0_mIdx", idx_name="m_idx"):
        for n_idx in pypto.loop(0, ctx.n_loop, 1, name="LOOP_L1_nIdx", idx_name="n_idx"):
            # The same ctx object is reused for every tile; only the M/N indices change.
            ctx.m_idx = m_idx
            ctx.n_idx = n_idx
            process_tile_4d(ctx)
# Tiled 3D batch matmul kernel.
# Iterates over batch, M and N tiles (tile sizes come from config.view_shape),
# slices the matching A and B views, multiplies them with pypto.matmul and
# writes each result tile into out_tensor.
@pypto.frontend.jit(debug_options={"runtime_debug_mode": 0, "compile_debug_mode": 0})
def batch_matmul_kernel_3d(
    a_tensor: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC]),
    b_tensor: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC]),
    out_tensor: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC]),
    config: BatchMatmulConfig,
):
    b, m, k, n = config.get_logical_dims_3d()
    pypto.set_cube_tile_shapes(*config.tile_shape, config.is_acc)
    pypto.set_vec_tile_shapes(128, 128)
    # view_shape holds the per-iteration tile extents in (batch, M, N) order.
    tile_b = config.view_shape[0]
    tile_m = config.view_shape[1]
    tile_n = config.view_shape[2]
    # Physical batch sizes of A and B; they are passed to get_batch_slice,
    # which special-cases an operand whose batch size is 1.
    batch_a = config.a_shape[0]
    batch_b = config.b_shape[0]
    # Ceiling division: number of tiles needed to cover each dimension.
    m_loop = (m + tile_m - 1) // tile_m
    n_loop = (n + tile_n - 1) // tile_n
    b_loop = (b + tile_b - 1) // tile_b
    pypto.set_matrix_size([m, k, n])
    for b_idx in pypto.loop(0, b_loop, 1, name="LOOP_L0_bIdx", idx_name="b_idx"):
        for m_idx in pypto.loop(0, m_loop, 1, name="LOOP_L0_mIdx", idx_name="m_idx"):
            for n_idx in pypto.loop(0, n_loop, 1, name="LOOP_L1_nIdx", idx_name="n_idx"):
                # Element offsets of the current tile.
                m_offset = m_idx * tile_m
                n_offset = n_idx * tile_n
                b_offset = b_idx * tile_b
                # Batch bounds are resolved separately for A and B.
                batch_a_slice = get_batch_slice(batch_a, tile_b, b_offset, batch_b)
                batch_b_slice = get_batch_slice(batch_b, tile_b, b_offset, batch_a)
                a_ctx = SliceContext3D(batch_slice=batch_a_slice, offset=m_offset, k=k, tile_size=tile_m)
                b_ctx = SliceContext3D(batch_slice=batch_b_slice, offset=n_offset, k=k, tile_size=tile_n)
                a_view = get_a_view_3d(a_tensor, config, a_ctx)
                b_view = get_b_view_3d(b_tensor, config, b_ctx)
                out_view = pypto.matmul(a_view, b_view,
                                        out_dtype=config.out_dtype,
                                        a_trans=config.a_trans, b_trans=config.b_trans)
                # Store the result tile at its (batch, M, N) position in the output.
                out_tensor[b_offset:b_offset + tile_b,
                           m_offset:m_offset + tile_m,
                           n_offset:n_offset + tile_n] = out_view
# Tiled 4D batch matmul kernel.
# Iterates over the two batch dimensions here and delegates the M/N tile
# loops to process_mn_loops_4d, which calls process_tile_4d for every tile.
@pypto.frontend.jit(debug_options={"runtime_debug_mode": 0, "compile_debug_mode": 0})
def batch_matmul_kernel_4d(
    a_tensor: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC, pypto.STATIC]),
    b_tensor: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC, pypto.STATIC]),
    out_tensor: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC, pypto.STATIC]),
    config: BatchMatmulConfig,
):
    b0, b1, m, k, n = config.get_logical_dims_4d()
    # Physical sizes of the two batch dimensions of A and of B; they are
    # forwarded to get_batch_slice through the tile context.
    b0_a = config.a_shape[0]
    b1_a = config.a_shape[1]
    b0_b = config.b_shape[0]
    b1_b = config.b_shape[1]
    pypto.set_cube_tile_shapes(*config.tile_shape, config.is_acc)
    pypto.set_vec_tile_shapes(1, 128, 128)
    # view_shape holds the per-iteration tile extents in (b0, b1, M, N) order.
    tile_b0 = config.view_shape[0]
    tile_b1 = config.view_shape[1]
    tile_m = config.view_shape[2]
    tile_n = config.view_shape[3]
    # Ceiling division: number of tiles needed to cover each dimension.
    m_loop = (m + tile_m - 1) // tile_m
    n_loop = (n + tile_n - 1) // tile_n
    b0_loop = (b0 + tile_b0 - 1) // tile_b0
    b1_loop = (b1 + tile_b1 - 1) // tile_b1
    pypto.set_matrix_size([m, k, n])
    for b0_idx in pypto.loop(0, b0_loop, 1, name="LOOP_L0_b0Idx", idx_name="b0_idx"):
        for b1_idx in pypto.loop(0, b1_loop, 1, name="LOOP_L0_b1Idx", idx_name="b1_idx"):
            # One context per batch tile; m_idx / n_idx keep their defaults
            # here and are filled in by process_mn_loops_4d.
            ctx = TileProcessContext4D(
                a_tensor=a_tensor, b_tensor=b_tensor, out_tensor=out_tensor,
                config=config, b0_idx=b0_idx, b1_idx=b1_idx,
                tile_b0=tile_b0, tile_b1=tile_b1, tile_m=tile_m, tile_n=tile_n,
                k=k, b0_a=b0_a, b1_a=b1_a, b0_b=b0_b, b1_b=b1_b,
                m_loop=m_loop, n_loop=n_loop
            )
            process_mn_loops_4d(ctx)
def prepare_tensors_3d(config, a_dtype, b_dtype, c_dtype, device_id):
    """Build the NPU inputs, zeroed NPU output and CPU golden result for a 3D case.

    Args:
        config: Test configuration providing shapes, transpose flags and formats.
        a_dtype: torch dtype of operand A.
        b_dtype: torch dtype of operand B.
        c_dtype: torch dtype of the output and of the golden result.
        device_id: Index of the NPU device the tensors are moved to.

    Returns:
        Tuple ``(a_tensor, b_tensor, c_tensor, golden)``; the first three live
        on the NPU, ``golden`` is computed on the CPU.
    """
    batch, rows, _, cols = config.get_logical_dims_3d()
    npu_device = f"npu:{device_id}"
    lhs_shape = config.a_shape
    rhs_shape = config.b_shape

    # int8 inputs are drawn from the small integer range [-2, 2]; every other
    # dtype uses uniform random values from torch.rand.
    is_int8 = a_dtype == torch.int8
    if is_int8:
        a_host = torch.randint(-2, 3, lhs_shape, dtype=a_dtype)
        b_host = torch.randint(-2, 3, rhs_shape, dtype=b_dtype)
    else:
        a_host = torch.rand(lhs_shape, dtype=a_dtype)
        b_host = torch.rand(rhs_shape, dtype=b_dtype)

    # The golden result is computed on logical (un-transposed) operands in a
    # wider accumulation dtype, then cast to the output dtype.
    lhs = a_host.transpose(1, 2) if config.a_trans else a_host
    rhs = b_host.transpose(1, 2) if config.b_trans else b_host
    accum_dtype = torch.int32 if is_int8 else torch.float32
    golden = torch.matmul(lhs.to(accum_dtype), rhs.to(accum_dtype)).to(c_dtype)

    a_tensor = a_host.to(npu_device)
    b_tensor = b_host.to(npu_device)
    # 29 is the format id used here for "NZ" — presumably the FRACTAL_NZ
    # format code; confirm against the torch_npu documentation.
    if config.a_format == "NZ":
        a_tensor = torch_npu.npu_format_cast(a_tensor, 29)
    if config.b_format == "NZ":
        b_tensor = torch_npu.npu_format_cast(b_tensor, 29)

    c_tensor = torch.zeros([batch, rows, cols], dtype=c_dtype, device=npu_device)
    return a_tensor, b_tensor, c_tensor, golden
def prepare_tensors_4d(config, a_dtype, b_dtype, c_dtype, device_id):
    """Build the NPU inputs, zeroed NPU output and CPU golden result for a 4D case.

    Args:
        config: Test configuration providing shapes and transpose flags.
        a_dtype: torch dtype of operand A.
        b_dtype: torch dtype of operand B.
        c_dtype: torch dtype of the output and of the golden result.
        device_id: Index of the NPU device the tensors are moved to.

    Returns:
        Tuple ``(a_tensor, b_tensor, c_tensor, golden)``; the first three live
        on the NPU, ``golden`` is computed on the CPU.
    """
    outer, inner, rows, _, cols = config.get_logical_dims_4d()
    npu_device = f"npu:{device_id}"
    lhs_shape = config.a_shape
    rhs_shape = config.b_shape

    # int8 inputs are drawn from the small integer range [-2, 2]; every other
    # dtype uses uniform random values from torch.rand.
    is_int8 = a_dtype == torch.int8
    if is_int8:
        a_host = torch.randint(-2, 3, lhs_shape, dtype=a_dtype)
        b_host = torch.randint(-2, 3, rhs_shape, dtype=b_dtype)
    else:
        a_host = torch.rand(lhs_shape, dtype=a_dtype)
        b_host = torch.rand(rhs_shape, dtype=b_dtype)

    # The golden result is computed on logical (un-transposed) operands in a
    # wider accumulation dtype, then cast to the output dtype.
    lhs = a_host.transpose(2, 3) if config.a_trans else a_host
    rhs = b_host.transpose(2, 3) if config.b_trans else b_host
    accum_dtype = torch.int32 if is_int8 else torch.float32
    golden = torch.matmul(lhs.to(accum_dtype), rhs.to(accum_dtype)).to(c_dtype)

    a_tensor = a_host.to(npu_device)
    b_tensor = b_host.to(npu_device)
    c_tensor = torch.zeros([outer, inner, rows, cols], dtype=c_dtype, device=npu_device)
    return a_tensor, b_tensor, c_tensor, golden
def run_batch_matmul_test(case: dict):
    """Run one batch-matmul test case on the NPU and compare with the golden result.

    The device index is read from the ``TILE_FWK_DEVICE_ID`` environment
    variable (default 0). Cases whose ``config.dim`` is 3 use the 3D kernel;
    all other cases use the 4D kernel.

    Args:
        case: Test-case dictionary; this function reads the ``a_dtype``,
            ``b_dtype``, ``c_dtype``, ``id`` and ``name`` keys directly.

    Raises:
        AssertionError: If the kernel output is not close to the golden
            result within the tolerance for the output dtype.
    """
    device_id = int(os.environ.get("TILE_FWK_DEVICE_ID", 0))
    torch.npu.set_device(device_id)

    config = BatchMatmulConfig.from_test_case(case)
    a_dtype, b_dtype, c_dtype = (
        BatchMatmulConfig.get_torch_dtype(case[key])
        for key in ("a_dtype", "b_dtype", "c_dtype")
    )

    # Pick the tensor factory and kernel that match the case dimensionality.
    if config.dim == 3:
        prepare, kernel = prepare_tensors_3d, batch_matmul_kernel_3d
    else:
        prepare, kernel = prepare_tensors_4d, batch_matmul_kernel_4d
    a_tensor, b_tensor, c_tensor, golden = prepare(config, a_dtype, b_dtype, c_dtype, device_id)
    kernel(a_tensor, b_tensor, c_tensor, config)

    atol, rtol = BatchMatmulConfig.get_tolerance(case["c_dtype"])
    matches = torch.allclose(c_tensor.cpu(), golden.cpu(), atol=atol, rtol=rtol)
    assert matches, f"Test case {case['id']} ({case['name']}) failed"
@pytest.mark.parametrize(
    "case",
    [
        pytest.param(tc, marks=pytest.mark.soc(*tc["products"]))
        for tc in BASIC_3D_TESTS
    ],
)
def test_batch_matmul_3d_nd(case: dict):
    """Run every BASIC_3D_TESTS case, each marked with the ``soc`` mark built from its products."""
    run_batch_matmul_test(case)
@pytest.mark.parametrize(
    "case",
    [
        pytest.param(tc, marks=pytest.mark.soc(*tc["products"]))
        for tc in BASIC_4D_TESTS
    ],
)
def test_batch_matmul_4d_nd(case: dict):
    """Run every BASIC_4D_TESTS case, each marked with the ``soc`` mark built from its products."""
    run_batch_matmul_test(case)
def run_batch_matmul_demo():
    """Run a standalone FP16 x FP16 -> FP32 batch matmul demo on the NPU.

    Uses fixed sizes (batch 4, M = K = N = 256) tiled as 2 batches by
    128 x 128 output elements per iteration. The device index is read from
    the ``TILE_FWK_DEVICE_ID`` environment variable (default 0), the same way
    ``run_batch_matmul_test`` selects its device. The result is written into
    a local output tensor and is not validated.
    """
    b_size, m_size, k_size, n_size = 4, 256, 256, 256
    m_view_size, n_view_size = 128, 128
    b_view_size = 2

    @pypto.frontend.jit(debug_options={"runtime_debug_mode": 0, "compile_debug_mode": 0})
    def batch_matmul_demo_kernel(
        a: pypto.Tensor([], pypto.DT_FP16),
        b: pypto.Tensor([], pypto.DT_FP16),
        out: pypto.Tensor([], pypto.DT_FP32),
    ):
        pypto.set_cube_tile_shapes([128, 128], [128, 128], [128, 128])
        pypto.set_vec_tile_shapes(128, 128)
        pypto.set_matrix_size([m_size, k_size, n_size])
        # Ceiling division: number of tiles needed to cover each dimension.
        m_loop = (m_size + m_view_size - 1) // m_view_size
        n_loop = (n_size + n_view_size - 1) // n_view_size
        b_loop = (b_size + b_view_size - 1) // b_view_size
        for b_idx in pypto.loop(0, b_loop, 1, name="LOOP_L0_bIdx", idx_name="b_idx"):
            for m_idx in pypto.loop(0, m_loop, 1, name="LOOP_L0_mIdx", idx_name="m_idx"):
                for n_idx in pypto.loop(0, n_loop, 1, name="LOOP_L1_nIdx", idx_name="n_idx"):
                    # A tile: batch and M are tiled, the last dimension is taken in full.
                    a_view = a[
                        b_idx * b_view_size: b_idx * b_view_size + b_view_size,
                        m_idx * m_view_size: m_idx * m_view_size + m_view_size,
                        :,
                    ]
                    # B tile: batch and N are tiled, the middle dimension is taken in full.
                    b_view = b[
                        b_idx * b_view_size: b_idx * b_view_size + b_view_size,
                        :,
                        n_idx * n_view_size: n_idx * n_view_size + n_view_size,
                    ]
                    out_view = pypto.matmul(a_view, b_view, pypto.DT_FP32)
                    out[
                        b_idx * b_view_size: b_idx * b_view_size + b_view_size,
                        m_idx * m_view_size: m_idx * m_view_size + m_view_size,
                        n_idx * n_view_size: n_idx * n_view_size + n_view_size,
                    ] = out_view

    # Honour TILE_FWK_DEVICE_ID instead of hard-coding device 0, so the demo
    # runs on the same device as the pytest path; the default is still 0.
    device_id = int(os.environ.get("TILE_FWK_DEVICE_ID", 0))
    torch.npu.set_device(device_id)
    device = f"npu:{device_id}"

    a = torch.randn([b_size, m_size, k_size], dtype=torch.float16, device=device)
    b = torch.randn([b_size, k_size, n_size], dtype=torch.float16, device=device)
    out = torch.empty(b_size, m_size, n_size, dtype=torch.float32, device=device)
    batch_matmul_demo_kernel(a, b, out)
# Direct execution runs the standalone demo; the parametrized cases above are run through pytest.
if __name__ == "__main__":
    run_batch_matmul_demo()