from typing import Tuple
import logging
import argparse
import torch
try:
import torch_npu
except ModuleNotFoundError:
pass
import asc
import asc.runtime.config as config
import asc.lib.runtime as rt
import asc.lib.host as host
USE_CORE_NUM = 24
IS_TRANS_A = False
IS_TRANS_B = False
ENABLE_BIAS = False
logging.basicConfig(level=logging.INFO)
@asc.jit(matmul_cube_only=True, always_compile=True)
def matmul_kernel(a: asc.GlobalAddress, b: asc.GlobalAddress, c: asc.GlobalAddress, bias: asc.GlobalAddress,
tiling: asc.adv.TCubeTiling, workspace: asc.GlobalAddress):
if asc.ascend_is_aic():
offset_a, offset_b, offset_c, offset_bias, tail_m, tail_n = calc_offsets(tiling, IS_TRANS_A, IS_TRANS_B)
a_global = asc.GlobalTensor()
b_global = asc.GlobalTensor()
c_global = asc.GlobalTensor()
bias_global = asc.GlobalTensor()
a_global.set_global_buffer(a + offset_a)
b_global.set_global_buffer(b + offset_b)
c_global.set_global_buffer(c + offset_c)
bias_global.set_global_buffer(bias + offset_bias)
pipe = asc.TPipe()
matmul = asc.adv.Matmul(
a=asc.adv.MatmulType(asc.TPosition.GM, asc.CubeFormat.ND, a_global.dtype, IS_TRANS_A),
b=asc.adv.MatmulType(asc.TPosition.GM, asc.CubeFormat.ND, b_global.dtype, IS_TRANS_B),
c=asc.adv.MatmulType(asc.TPosition.GM, asc.CubeFormat.ND, c_global.dtype),
bias=asc.adv.MatmulType(asc.TPosition.GM, asc.CubeFormat.ND, bias_global.dtype),
)
asc.adv.register_matmul(pipe, workspace, matmul, tiling)
if asc.get_block_idx() < tiling.used_core_num:
matmul.set_tensor_a(a_global, IS_TRANS_A)
matmul.set_tensor_b(b_global, IS_TRANS_B)
if tiling.is_bias:
matmul.set_bias(bias_global)
matmul.set_tail(tail_m, tail_n, tiling.k_a)
matmul.iterate_all(c_global)
matmul.end()
asc.pipe_barrier(asc.PipeID.PIPE_ALL)
@asc.jit
def calc_offsets(tiling: asc.adv.TCubeTiling, is_trans_a: bool = False,
is_trans_b: bool = False) -> Tuple[int, int, int, int, int, int]:
block_idx = asc.get_block_idx()
m_single_blocks = tiling.m.ceildiv(tiling.single_core_m)
n_index = block_idx // m_single_blocks
m_index = block_idx % m_single_blocks
offset_a = m_index * tiling.k_a * tiling.single_core_m
if is_trans_a:
offset_a = m_index * tiling.single_core_m
offset_b = n_index * tiling.single_core_n
if is_trans_b:
offset_b = n_index * tiling.k_b * tiling.single_core_n
offset_c = m_index * tiling.n * tiling.single_core_m + n_index * tiling.single_core_n
offset_bias = n_index * tiling.single_core_n
tail_m = tiling.m - m_index * tiling.single_core_m
if tail_m <= 0 or tail_m >= tiling.single_core_m:
tail_m = tiling.single_core_m
tail_n = tiling.n - n_index * tiling.single_core_n
if tail_n <= 0 or tail_n >= tiling.single_core_n:
tail_n = tiling.single_core_n
return offset_a, offset_b, offset_c, offset_bias, tail_m, tail_n
def matmul_launch(a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor,
tiling: asc.adv.TCubeTiling, device) -> torch.Tensor:
size_m, size_k = a.shape
_, size_n = b.shape
c = torch.zeros((size_m, size_n), dtype=torch.float32, device=device)
workspace = torch.zeros(16 * 1024 * 1024, dtype=torch.uint8, device=device)
matmul_kernel[tiling.used_core_num, rt.current_stream()](a, b, c, bias, tiling, workspace)
return c
def generate_tiling(m, n, k) -> asc.adv.TCubeTiling:
matmul_tiling = host.MultiCoreMatmulTiling(host.get_ascendc_platform())
matmul_tiling.set_a_type(host.TPosition.GM, host.CubeFormat.ND, host.DataType.DT_FLOAT16, IS_TRANS_A)
matmul_tiling.set_b_type(host.TPosition.GM, host.CubeFormat.ND, host.DataType.DT_FLOAT16, IS_TRANS_B)
matmul_tiling.set_c_type(host.TPosition.GM, host.CubeFormat.ND, host.DataType.DT_FLOAT)
matmul_tiling.set_bias_type(host.TPosition.GM, host.CubeFormat.ND, host.DataType.DT_FLOAT)
matmul_tiling.set_dim(USE_CORE_NUM)
matmul_tiling.set_org_shape(m, n, k)
matmul_tiling.set_shape(m, n, k)
matmul_tiling.enable_bias(ENABLE_BIAS)
matmul_tiling.set_buffer_space(-1, -1, -1)
tiling = asc.adv.TCubeTiling()
matmul_tiling.get_tiling(tiling)
return tiling
def matmul_cube_only_custom(backend: config.Backend, platform: config.Platform):
config.set_platform(backend, platform)
device = "npu" if config.Backend(backend) == config.Backend.NPU else "cpu"
m, k, n = 128, 64, 30720
a = torch.randint(-5, 5, (m, k), device=device).to(torch.float16)
b = torch.randint(-5, 5, (k, n), device=device).to(torch.float16)
bias = torch.randint(-5, 5, (1, n), device=device).to(torch.float32)
if ENABLE_BIAS:
matmul = (torch.matmul(a.to(torch.float32), b.to(torch.float32)).to(torch.float32) + bias).to(torch.float32)
else:
matmul = torch.matmul(a.to(torch.float32), b.to(torch.float32)).to(torch.float32)
tiling = generate_tiling(m, n, k)
c = matmul_launch(a, b, bias, tiling, device)
assert torch.allclose(c, matmul, rtol=1e-3, atol=1e-3)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-r", type=str, default="Model", help="backend to run")
parser.add_argument("-v", type=str, default=None, help="platform to run")
args = parser.parse_args()
backend = args.r
platform = args.v
if backend not in config.Backend.__members__:
raise ValueError("Unsupported Backend! Supported: ['Model', 'NPU']")
backend = config.Backend(backend)
if platform is not None:
platform_values = [platform.value for platform in config.Platform]
if platform not in platform_values:
raise ValueError(f"Unsupported Platform! Supported: {platform_values}")
platform = config.Platform(platform)
logging.info("[INFO] start process sample matmul_cube_only.")
matmul_cube_only_custom(backend, platform)
logging.info("[INFO] Sample matmul_cube_only run success.")