import random
from typing import Union
import unittest
import torch
import torch_catlass
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
def generate_sequence_split(n, sum_num):
split_points = sorted([random.randint(0, sum_num) for _ in range(n - 1)])
split_points = [0] + split_points + [sum_num]
sequence = [split_points[i + 1] - split_points[i] for i in range(n)]
sequence[-1] = sum_num - sum(sequence[:-1])
return sequence
def calculate_prefix_sum(sequence):
prefix_sum = []
current_sum = 0
for num in sequence:
current_sum += num
prefix_sum.append(current_sum)
return prefix_sum
class CatlassTest(TestCase):
def _run_case_basic(
self,
m: int,
n: int,
k: int,
trans_a: bool = False,
trans_b: bool = False,
dtype: torch.dtype = torch.float16,
):
shape1 = (m, k) if not trans_a else (k, m)
shape2 = (k, n) if not trans_b else (n, k)
a = torch.rand(shape1, device="npu").to(dtype)
b = torch.rand(shape2, device="npu").to(dtype)
a = a if not trans_a else a.T
b = b if not trans_b else b.T
torch.npu.synchronize()
result = torch_catlass.basic_matmul(a, b, str(dtype).split(".")[-1])
golden = torch.mm(a, b)
torch.npu.synchronize()
if dtype == torch.bfloat16:
result = result.to(torch.float32)
golden = golden.to(torch.float32)
self.assertRtolEqual(result, golden)
def test_basic_matmul_pybind(self):
self._run_case_basic(2, 3, 4)
@unittest.skip("Not ready")
def test_basic_matmul_pybind_cr(self):
self._run_case_basic(2, 3, 4, trans_a=True)
@unittest.skip("Not ready")
def test_basic_matmul_pybind_rc(self):
self._run_case_basic(2, 3, 4, trans_b=True)
@unittest.skip("Not ready")
def test_basic_matmul_pybind_cc(self):
self._run_case_basic(2, 3, 4, trans_a=True, trans_b=True)
@unittest.skip("Not ready")
def test_basic_matmul_pybind_bf16(self):
self._run_case_basic(2, 3, 4, trans_a=True, trans_b=True)
def __grouped_matmul_golden_split_m(
self,
a: torch.Tensor,
b: torch.Tensor,
group_list: Union[torch.Tensor, list[int]],
dtype: str,
transpose_a: bool = False,
transpose_b: bool = False,
):
self.assertEqual(transpose_a, False)
if isinstance(group_list, list):
a_list = torch.split(a, group_list)
result_list = []
for i in range(len(group_list)):
a_g = a_list[i] if not transpose_a else a_list[i].T
b_g = b[i] if not transpose_b else b[i].T
result_list.append(a_g @ b_g)
return torch.cat(result_list).to(eval(f"torch.{dtype}"))
def __grouped_matmul_golden_split_k(
self,
a: torch.Tensor,
b: torch.Tensor,
group_list: Union[torch.Tensor, list[int]],
dtype: str,
transpose_a: bool = True,
transpose_b: bool = False,
):
self.assertEqual(transpose_a and not transpose_b, True)
assert transpose_a and not transpose_b
if isinstance(group_list, list):
a_list = torch.split(a, group_list)
b_list = torch.split(b, group_list)
result_list = []
for i in range(len(group_list)):
a_g = a_list[i].T
b_g = b_list[i]
result_list.append(a_g @ b_g)
return torch.stack(result_list).to(eval(f"torch.{dtype}"))
def __grouped_matmul_golden(
self,
a: torch.Tensor,
b: torch.Tensor,
group_list: Union[torch.Tensor, list[int]],
dtype: str,
transpose_a: bool,
transpose_b: bool,
split_k: bool,
):
if not split_k:
return self.__grouped_matmul_golden_split_m(
a, b, group_list, dtype, transpose_a, transpose_b
)
else:
return self.__grouped_matmul_golden_split_k(
a, b, group_list, dtype, transpose_a, transpose_b
)
def test_grouped_matmul_list_m(self):
g = 128
group_list = generate_sequence_split(g, random.randint(256, 4096))
group_list_prefix_sum = calculate_prefix_sum(group_list)
m_sum, k, n = group_list_prefix_sum[-1], 256, 256
a = torch.randn((m_sum, k), device="npu").to(torch.float16)
b = torch.randn((g, k, n), device="npu").to(torch.float16)
group_list_prefix_sum_tensor = torch.tensor(
group_list_prefix_sum, device="npu"
).to(torch.int64)
result = torch_catlass.grouped_matmul(
a, b, group_list_prefix_sum_tensor, "float16", False, True, False
)
golden = self.__grouped_matmul_golden(
a, b, group_list, "float16", False, True, False
)
self.assertRtolEqual(result, golden)
def test_grouped_matmul_list_k(self):
g = 128
group_list = generate_sequence_split(g, random.randint(256, 4096))
group_list_prefix_sum = calculate_prefix_sum(group_list)
m, k_sum, n = 256, group_list_prefix_sum[-1], 256
a = torch.randn((k_sum, m), device="npu").to(torch.float16)
b = torch.randn((k_sum, n), device="npu").to(torch.float16)
group_list_prefix_sum_tensor = torch.tensor(
group_list_prefix_sum, device="npu"
).to(torch.int64)
result = torch_catlass.grouped_matmul(
a, b, group_list_prefix_sum_tensor, "float16", True, False, True
)
golden = self.__grouped_matmul_golden(
a, b, group_list, "float16", True, False, True
)
self.assertRtolEqual(result, golden)
def test_optimized_matmul_pybind(self):
a = torch.rand((2, 3), device="npu").to(torch.float16)
b = torch.rand((3, 4), device="npu").to(torch.float16)
torch.npu.synchronize()
result = torch_catlass.optimized_matmul(a, b, "float16")
golden = torch.mm(a, b)
torch.npu.synchronize()
self.assertRtolEqual(result, golden)
if __name__ == "__main__":
run_tests()