from typing import Tuple
import pytest
import asc
from asc.runtime import config
import asc.lib.runtime as rt
import asc.lib.host as host
try:
import torch
except ModuleNotFoundError:
pytest.skip("torch is not installed", allow_module_level=True)
# Device-side matmul kernel: each core binds its slice of A, B and C (via the
# element offsets returned by calc_offsets) and runs a full Matmul iteration
# that writes into C.
#
# Parameters:
#   a, b, c   -- global addresses of the two inputs and the output buffer.
#   tiling    -- TCubeTiling produced on the host; two of its fields are
#                overwritten below before it is used.
#   workspace -- global address handed to asc.adv.register_matmul.
@asc.jit
def matmul_kernel(a: asc.GlobalAddress, b: asc.GlobalAddress, c: asc.GlobalAddress, tiling: asc.adv.TCubeTiling,
                  workspace: asc.GlobalAddress):
    # Override the shared L1 / L0C sizes in the tiling with the values reported
    # by asc.property for TOTAL_L1_SIZE / TOTAL_L0C_SIZE.
    tiling.share_l1_size = asc.property(asc.TOTAL_L1_SIZE)
    tiling.share_l0c_size = asc.property(asc.TOTAL_L0C_SIZE)
    # Per-core starting offsets into A, B and C.
    offset_a, offset_b, offset_c = calc_offsets(tiling)
    a_global = asc.GlobalTensor()
    b_global = asc.GlobalTensor()
    c_global = asc.GlobalTensor()
    # Point each global tensor at this core's starting position in its buffer.
    a_global.set_global_buffer(a + offset_a)
    b_global.set_global_buffer(b + offset_b)
    c_global.set_global_buffer(c + offset_c)
    pipe = asc.TPipe()
    # All three operands are declared at the GM position in ND format, with the
    # element type taken from the corresponding global tensor.
    matmul = asc.adv.Matmul(
        a=asc.adv.MatmulType(asc.TPosition.GM, asc.CubeFormat.ND, a_global.dtype),
        b=asc.adv.MatmulType(asc.TPosition.GM, asc.CubeFormat.ND, b_global.dtype),
        c=asc.adv.MatmulType(asc.TPosition.GM, asc.CubeFormat.ND, c_global.dtype),
    )
    # The matmul object must be registered with the pipe, workspace and tiling
    # before its operands are set.
    asc.adv.register_matmul(pipe, workspace, matmul, tiling)
    matmul.set_tensor_a(a_global)
    matmul.set_tensor_b(b_global)
    # Run the whole computation in one call, writing the result to c_global.
    matmul.iterate_all(c_global)
    matmul.end()
    # NOTE(review): presumably waits for all pipelines to drain before the
    # kernel returns -- confirm against the pipe_barrier documentation.
    asc.pipe_barrier(asc.PipeID.PIPE_ALL)
# Compute the element offsets into A, B and C for the current block.
#
# The block index is first folded into the range of cores that share one
# K block (used_core_num // number of K blocks), then split into an M-block
# index (remainder) and an N-block index (quotient) by the number of M blocks.
# Returns (offset_a, offset_b, offset_c).
# NOTE(review): the offset formulas look like row-major addressing of
# A (m x k_a), B (k x n) and C (m x n) -- confirm against the tiling layout.
@asc.jit
def calc_offsets(tiling: asc.adv.TCubeTiling) -> Tuple[int, int, int]:
    core_id = asc.get_block_idx()
    m_blocks = tiling.m.ceildiv(tiling.single_core_m)
    k_blocks = tiling.k_a.ceildiv(tiling.single_core_k)
    cores_per_k_block = tiling.used_core_num // k_blocks
    # Position of this core among the cores that share one K block.
    local_id = core_id % cores_per_k_block
    row_block = local_id % m_blocks
    col_block = local_id // m_blocks
    a_start = row_block * tiling.k_a * tiling.single_core_m
    b_start = col_block * tiling.single_core_n
    c_start = row_block * tiling.n * tiling.single_core_m + col_block * tiling.single_core_n
    return a_start, b_start, c_start
def matmul_launch(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor,
                  workspace: torch.Tensor, tiling: asc.adv.TCubeTiling) -> torch.Tensor:
    """Launch ``matmul_kernel`` for ``a @ b`` and return the output tensor.

    Args:
        a: Left operand; its second dimension must equal ``b``'s first.
        b: Right operand.
        c: Output tensor, passed to the kernel and returned.
        workspace: Scratch tensor forwarded to the kernel.
        tiling: Tiling description; ``tiling.used_core_num`` is the first
            element of the launch configuration.

    Returns:
        The same ``c`` object that was passed in.

    Raises:
        AssertionError: If the inner dimensions of ``a`` and ``b`` differ.
    """
    assert a.shape[1] == b.shape[0], "Matrices must be compatible for a multiplication"
    # Launch configuration: (core count from the tiling, current runtime stream).
    launch_config = (tiling.used_core_num, rt.current_stream())
    launcher = matmul_kernel[launch_config]
    launcher(a, b, c, tiling, workspace)
    return c
def generate_tiling(m, n, k, dtype):
    """Build a ``TCubeTiling`` for an ``(m, k) @ (k, n)`` multi-core matmul.

    Args:
        m: First value passed to ``set_org_shape`` / ``set_shape``.
        n: Second value passed to ``set_org_shape`` / ``set_shape``.
        k: Third value passed to ``set_org_shape`` / ``set_shape``.
        dtype: Torch dtype of the inputs. ``torch.float32`` selects
            ``DT_FLOAT``; every other dtype selects ``DT_FLOAT16``.

    Returns:
        The ``asc.adv.TCubeTiling`` filled in by ``get_tiling``.
    """
    tiler = host.MultiCoreMatmulTiling(host.get_ascendc_platform())

    # NOTE(review): any dtype other than float32 (including integer dtypes)
    # falls through to DT_FLOAT16 -- confirm this is intended.
    if dtype == torch.float32:
        input_dtype = host.DataType.DT_FLOAT
    else:
        input_dtype = host.DataType.DT_FLOAT16

    # A and B share the same configuration: GM position, ND format, and False
    # as the trailing flag.
    for set_input_type in (tiler.set_a_type, tiler.set_b_type):
        set_input_type(host.TPosition.GM, host.CubeFormat.ND, input_dtype, False)
    # NOTE(review): C is declared at VECCALC here while matmul_kernel declares
    # its C operand at GM -- verify the two are meant to differ.
    tiler.set_c_type(host.TPosition.VECCALC, host.CubeFormat.ND, host.DataType.DT_FLOAT)

    tiler.set_dim(16)
    tiler.set_org_shape(m, n, k)
    tiler.set_shape(m, n, k)
    tiler.set_traverse(host.MatrixTraverse.FIRSTM)
    tiler.set_buffer_space(-1, -1, -1)

    cube_tiling = asc.adv.TCubeTiling()
    tiler.get_tiling(cube_tiling)
    return cube_tiling
# Test cases for test_matmul_iterate_all: each entry is [dtype, size], where
# size is unpacked by the test as (m, k, n).
param_list = [
    [torch.float16, (256, 64, 128)],
]
# Backends the test is parametrized over.
BACKENDS = [
    config.Backend.NPU,
]
@pytest.mark.parametrize("dtype, size", param_list)
@pytest.mark.parametrize("backend", BACKENDS)
def test_matmul_iterate_all(dtype, size, backend: config.Backend):
    """Check the kernel's ``iterate_all`` result against a torch reference.

    Args:
        dtype: Torch dtype of the two input matrices.
        size: ``(m, k, n)`` -- inputs are ``(m, k)`` and ``(k, n)``, the
            output is ``(m, n)``.
        backend: Backend to select via ``config.set_platform``.
    """
    config.set_platform(backend)
    if config.Backend(backend) == config.Backend.NPU:
        device = "npu"
    else:
        device = "cpu"

    m, k, n = size
    lhs_shape, rhs_shape, out_shape = (m, k), (k, n), (m, n)

    if dtype in {torch.float16, torch.float32}:
        # torch.rand samples [0, 1); shift and scale to [-5, 5).
        lhs = (torch.rand(lhs_shape, dtype=dtype, device=device) - 0.5) * 10
        rhs = (torch.rand(rhs_shape, dtype=dtype, device=device) - 0.5) * 10
        out = torch.zeros(out_shape, dtype=torch.float32, device=device)
    else:
        lhs = torch.randint(-5, 5, lhs_shape, dtype=dtype, device=device)
        rhs = torch.randint(-5, 5, rhs_shape, dtype=dtype, device=device)
        out = torch.zeros(out_shape, dtype=torch.int32, device=device)
    workspace = torch.zeros(16 * 1024 * 1024, dtype=torch.uint8, device=device)

    # The reference product is computed in int32 for int8 inputs and in
    # float32 for everything else.
    reference_dtype = torch.int32 if dtype == torch.int8 else torch.float32
    expected = lhs.to(reference_dtype) @ rhs.to(reference_dtype)

    tiling = generate_tiling(m, n, k, dtype)
    actual = matmul_launch(lhs, rhs, out, workspace, tiling)
    assert torch.allclose(actual, expected, atol=1e-3)