from typing import Tuple
import numpy as np
import asc
import asc.runtime.config as config
import asc.lib.runtime as rt
# Fused device kernel: C = leaky_relu(A @ B + bias, alpha).
#
# Each launched core handles its own region of C. calc_offsets(tiling) gives
# the per-core element offsets, and the GM tensors below are bound to the
# shifted addresses. The high-level Matmul object then yields the region one
# tile at a time. Each tile is activated in a local tensor and copied back to C.
#
# Parameters:
#   a, b, c    -- GM addresses of A, B and the output C. They are addressed
#                 with row-major offsets (see calc_offsets).
#   alpha      -- slope that asc.leaky_relu applies to negative values.
#   tiling     -- matmul tiling. Its share_l1_size / share_l0c_size fields
#                 are overwritten here before use.
#   workspace  -- GM scratch buffer handed to asc.adv.register_matmul.
#   bias       -- GM address of the bias, offset per core along N.
#   blocksize  -- element count of the per-tile local buffer. The launcher
#                 passes base_m * base_n. NOTE(review): it must be at least
#                 base_m * base_n, since leaky_relu processes that many
#                 elements.
@asc.jit(insert_sync=True)
def matmul_kernel(a: asc.GlobalAddress, b: asc.GlobalAddress, c: asc.GlobalAddress, alpha: float,
                  tiling: asc.adv.TCubeTiling, workspace: asc.GlobalAddress, bias: asc.GlobalAddress,
                  blocksize: asc.ConstExpr):
    # Set the shared L1 / L0C sizes to the target's total capacities.
    # Presumably this lets the matmul use all of them; confirm the semantics
    # of share_*_size.
    tiling.share_l1_size = asc.property(asc.TOTAL_L1_SIZE)
    tiling.share_l0c_size = asc.property(asc.TOTAL_L0C_SIZE)
    # Element offsets of this core's slices of A, B, C and bias.
    offset_a, offset_b, offset_c, offset_bias = calc_offsets(tiling)
    a_global = asc.GlobalTensor()
    b_global = asc.GlobalTensor()
    c_global = asc.GlobalTensor()
    bias_global = asc.GlobalTensor()
    a_global.set_global_buffer(a + offset_a)
    b_global.set_global_buffer(b + offset_b)
    c_global.set_global_buffer(c + offset_c)
    bias_global.set_global_buffer(bias + offset_bias)
    pipe = asc.TPipe()
    # A, B and bias are read from GM in ND layout. C is fetched into a local
    # (VECCALC) tensor so that asc.leaky_relu can run on each tile before it
    # is copied out.
    matmul = asc.adv.Matmul(
        a=asc.adv.MatmulType(asc.TPosition.GM, asc.CubeFormat.ND, a_global.dtype),
        b=asc.adv.MatmulType(asc.TPosition.GM, asc.CubeFormat.ND, b_global.dtype),
        c=asc.adv.MatmulType(asc.TPosition.VECCALC, asc.CubeFormat.ND, c_global.dtype),
        bias=asc.adv.MatmulType(asc.TPosition.GM, asc.CubeFormat.ND, bias_global.dtype),
        matmul_config=asc.adv.MatmulConfig()
    )
    asc.adv.register_matmul(pipe, workspace, matmul)
    matmul.init(tiling)
    matmul.set_tensor_a(a_global)
    matmul.set_tensor_b(b_global)
    matmul.set_bias(bias_global)
    # One iteration per output tile of this core. `count` identifies the tile
    # and is used below to locate it inside C.
    with matmul.iterate() as count:
        relu_out_local = asc.LocalTensorAuto(c.dtype, blocksize)
        # Fetch the current base_m x base_n result tile into the local tensor.
        # Presumably en_sequential_write=True writes it densely; confirm.
        matmul.get_tensor_c(relu_out_local, en_sequential_write=True)
        # LeakyReLU in place over the whole tile.
        asc.leaky_relu(relu_out_local, relu_out_local, alpha, count=tiling.base_m * tiling.base_n)
        # Tile coordinates inside this core's region: count % round_m is the
        # tile row and count // round_m is the tile column, so tiles advance
        # along M first.
        round_m = tiling.single_core_m // tiling.base_m
        start_offset = count % round_m * tiling.base_m * tiling.n + count // round_m * tiling.base_n
        # Copy base_m rows back to C. Lengths and strides are in units of
        # DEFAULT_C0_SIZE bytes. Each row is base_n elements, and dst_stride
        # skips the (n - base_n) columns of the C row that belong to other
        # tiles.
        # NOTE(review): full tiles are always copied (no tail handling), so m
        # and n must be multiples of base_m and base_n.
        params = asc.DataCopyParams(
            block_count=tiling.base_m,
            block_len=tiling.base_n * c.dtype.sizeof() // asc.property(asc.DEFAULT_C0_SIZE),
            src_stride=0,
            dst_stride=(tiling.n - tiling.base_n) * c.dtype.sizeof() // asc.property(asc.DEFAULT_C0_SIZE),
        )
        asc.data_copy(c_global[start_offset:], relu_out_local, repeat_params=params)
    matmul.end()
    # Barrier across all pipes before the kernel returns.
    asc.pipe_barrier(asc.PipeID.PIPE_ALL)
@asc.jit
def calc_offsets(tiling: asc.adv.TCubeTiling) -> Tuple[int, int, int, int]:
    # Compute this core's element offsets into A, B, C and bias.
    #
    # Output regions of size (single_core_m x single_core_n) are assigned to
    # block indices with the M index varying fastest. The offsets assume
    # row-major A (rows of k_a elements) and row-major B and C (rows of n
    # elements).
    #
    # Returns (offset_a, offset_b, offset_c, offset_bias), all in elements.
    block_idx = asc.get_block_idx()
    # Number of single-core regions along M and along K.
    m_tiles = tiling.m.ceildiv(tiling.single_core_m)
    k_tiles = tiling.k_a.ceildiv(tiling.single_core_k)
    # Cores per K split. Block indices beyond this wrap around.
    cores_per_k_split = tiling.used_core_num // k_tiles
    tile_idx = block_idx % cores_per_k_split
    m_index = tile_idx % m_tiles
    n_index = tile_idx // m_tiles
    # A: skip m_index bands of single_core_m full rows.
    offset_a = m_index * tiling.k_a * tiling.single_core_m
    # B and bias: skip n_index bands of single_core_n columns.
    offset_b = n_index * tiling.single_core_n
    # C: combine the row-band and column-band offsets.
    offset_c = m_index * tiling.n * tiling.single_core_m + n_index * tiling.single_core_n
    offset_bias = n_index * tiling.single_core_n
    return offset_a, offset_b, offset_c, offset_bias
def matmul_launch(a: np.ndarray, b: np.ndarray, bias: np.ndarray, alpha: float) -> np.ndarray:
    """Compute ``leaky_relu(a @ b + bias, alpha)`` with ``matmul_kernel``.

    Args:
        a: Left operand of shape ``(m, k)``.
        b: Right operand of shape ``(k, n)``.
        bias: Bias with ``n`` elements, added to every row (e.g. shape ``(1, n)``).
        alpha: Slope the LeakyReLU applies to negative values.

    Returns:
        A new ``(m, n)`` array with the dtype of ``a``.

    Raises:
        ValueError: If ``a`` or ``b`` is not 2-D, the inner dimensions differ,
            ``bias`` does not have ``n`` elements, or ``m`` / ``n`` is not a
            multiple of the 64-element tile size.
    """
    # Tile sizes. Each core computes one (blocksize_m x blocksize_n) region of
    # C over the whole K dimension.
    blocksize_m = 64
    blocksize_n = 64
    blocksize_k = 64
    # GM scratch buffer size handed to the kernel (16 MiB).
    workspace_bytes = 16 * 1024 * 1024

    # Use real exceptions rather than ``assert`` so the checks survive ``python -O``.
    if a.ndim != 2 or b.ndim != 2:
        raise ValueError(f"Expected 2-D matrices, got a.ndim={a.ndim}, b.ndim={b.ndim}")
    if a.shape[1] != b.shape[0]:
        raise ValueError("Matrices must be compatible for a multiplication")
    size_m, size_k = a.shape
    size_n = b.shape[1]
    if bias.size != size_n:
        raise ValueError(f"bias must have {size_n} elements, got {bias.size}")
    # The kernel always copies whole base_m x base_n tiles back to C and has
    # no tail handling, so a partial tile would write outside ``c``.
    if size_m % blocksize_m or size_n % blocksize_n:
        raise ValueError(
            f"m and n must be multiples of {blocksize_m} and {blocksize_n}, "
            f"got m={size_m}, n={size_n}"
        )

    c = np.empty((size_m, size_n), dtype=a.dtype)
    if c.size == 0:
        # Nothing to compute, and used_core_num would be 0; skip the launch.
        return c

    # The kernel addresses every buffer with row-major offsets. These calls
    # are no-ops for inputs that are already C-contiguous.
    a = np.ascontiguousarray(a)
    b = np.ascontiguousarray(b)
    bias = np.ascontiguousarray(bias)

    single_m = blocksize_m
    single_n = blocksize_n
    single_k = size_k
    used_core_num = asc.ceildiv(size_m, single_m) * asc.ceildiv(size_n, single_n)
    tiling = asc.adv.TCubeTiling(used_core_num=used_core_num, m=size_m, k_a=size_k, k_b=size_k, n=size_n,
                                 base_m=blocksize_m, base_k=blocksize_k, base_n=blocksize_n, single_core_m=single_m,
                                 single_core_k=single_k, single_core_n=single_n, depth_a1=1, depth_b1=1, step_m=1,
                                 step_n=1, share_mode=0, is_bias=1)
    ws = np.zeros(workspace_bytes, dtype="uint8")
    matmul_kernel[used_core_num, rt.current_stream()](a, b, c, alpha, tiling, ws, bias, blocksize_m * blocksize_n)
    return c
def test_matmul(backend: config.Backend):
    """Compare the fused matmul + bias + LeakyReLU kernel with a NumPy reference."""
    config.set_platform(backend)
    rng = np.random.default_rng(seed=2025)
    m, k, n = 256, 256, 256

    def uniform(shape):
        # float32 samples drawn uniformly from [-5, 5).
        return (rng.random(shape, dtype="float32") - .5) * 10

    a = uniform((m, k))
    b = uniform((k, n))
    bias = uniform((1, n))
    alpha = 0.01

    linear = a @ b + bias
    actual = matmul_launch(a, b, bias, alpha)
    expected = np.where(linear >= 0, linear, linear * alpha)
    np.testing.assert_allclose(actual, expected, atol=1e-3)
if __name__ == "__main__":
    # When run as a script, execute the end-to-end check on the Model backend.
    test_matmul(backend=config.Backend.Model)