import unittest
import torch
import torch_npu
import numpy as np
import torch.nn.functional as F
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
# Epsilon added to |golden| in the relative-error helpers so that elements
# where golden == 0 yield a large but finite error instead of inf/nan.
MIN_ERR = 1e-7
# 2**-8 threshold. NOTE(review): not referenced by the code in this file as
# shown — presumably an error-balance threshold; confirm it is still needed.
EB_THRESHOLD = 2 ** -8
# Floor (2**-8) applied to the baseline error used as the denominator of the
# error ratios in compare_cv, so a near-exact baseline cannot blow them up.
ERR_THRESHOLD = 2 ** -8
def get_mare(golden: torch.Tensor, actual: torch.Tensor):
    """Return the maximum relative error of ``actual`` with respect to ``golden``.

    Both tensors are upcast to float32. The per-element error is
    ``|actual - golden| / (|golden| + MIN_ERR)``, and the result is a 0-d
    tensor holding the largest such value over all elements.
    """
    reference = golden.to(torch.float32)
    deviation = (actual.to(torch.float32) - reference).abs()
    # MIN_ERR keeps the ratio finite where the reference is exactly zero.
    per_element = deviation / (reference.abs() + MIN_ERR)
    return per_element.flatten().max()
def get_mere(golden: torch.Tensor, actual: torch.Tensor):
    """Return the mean relative error of ``actual`` with respect to ``golden``.

    Both tensors are upcast to float32. The per-element error is
    ``|actual - golden| / (|golden| + MIN_ERR)``, and the result is the 0-d
    tensor holding the mean of those errors.
    """
    reference = golden.to(torch.float32)
    relative = (actual.to(torch.float32) - reference).abs() / (reference.abs() + MIN_ERR)
    return relative.mean()
def get_rmse(golden: torch.Tensor, actual: torch.Tensor):
    """Return the root-mean-square error between ``actual`` and ``golden``.

    Both tensors are upcast to float32 before differencing; the result is a
    0-d float32 tensor ``sqrt(mean((actual - golden) ** 2))``.
    """
    residual = actual.to(torch.float32) - golden.to(torch.float32)
    return residual.pow(2).mean().sqrt()
def compare_cv(golden: torch.Tensor, golden_high_type: torch.Tensor, actual: torch.Tensor,
               *, mare_limit: float = 10, mere_limit: float = 2, rmse_limit: float = 2) -> bool:
    """Check ``actual`` against a baseline implementation, both measured against ``golden``.

    For each metric (max relative error ``get_mare``, mean relative error
    ``get_mere``, and ``get_rmse``), the error of ``actual`` w.r.t. ``golden``
    is divided by the error of ``golden_high_type`` w.r.t. ``golden``. The
    denominator is floored at ``ERR_THRESHOLD`` so a near-exact baseline
    cannot make the ratio explode.

    Args:
        golden: Reference tensor both candidates are measured against.
        golden_high_type: Baseline implementation whose error sets the scale.
        actual: Tensor under test.
        mare_limit: Exclusive upper bound for the max-relative-error ratio.
        mere_limit: Exclusive upper bound for the mean-relative-error ratio.
        rmse_limit: Exclusive upper bound for the RMSE ratio.

    Returns:
        ``True`` only if every ratio is strictly below its limit. A NaN ratio
        fails its comparison, so NaN errors yield ``False``.
    """
    def _ratio(metric) -> float:
        actual_err = float(metric(golden, actual))
        baseline_err = float(metric(golden, golden_high_type))
        return actual_err / max(baseline_err, ERR_THRESHOLD)

    # Short-circuits: later metrics are skipped once one ratio fails.
    return (_ratio(get_mare) < mare_limit
            and _ratio(get_mere) < mere_limit
            and _ratio(get_rmse) < rmse_limit)
class TestGroupedMatmulAdd(TestCase):
    """NPU accuracy tests for ``torch_npu.npu_quant_matmul_reduce_sum``.

    NOTE(review): the class name mentions "GroupedMatmulAdd", but every case
    here exercises ``npu_quant_matmul_reduce_sum``. The name is left unchanged
    to keep test IDs stable.
    """

    def _run_quant_matmul_reduce_sum(self, batch: int, m: int, k: int, n: int):
        """Build random int8 inputs, run the NPU op, and assert its accuracy.

        The CPU reference is
        ``sum_b pertoken_scale[b, :, None] * scale[None, :] * (x[b] @ w[b])``.
        It is evaluated in float32 (``golden_fp32``), plus a second estimate
        accumulated batch-by-batch in bfloat16 (``golden_bf16``). ``compare_cv``
        uses that second estimate as the error baseline.
        """
        x_nd = torch.randint(-10, 10, (batch, m, k), dtype=torch.int8)
        w_nd = torch.randint(-10, 10, (batch, k, n), dtype=torch.int8)
        scale = torch.rand((n,), dtype=torch.bfloat16)
        pertoken_scale = torch.rand((batch, m), dtype=torch.float32)

        # bmm on float inputs gives a (batch, m, n) float32 tensor. Multiplying
        # by the bfloat16 `scale` promotes to float32, so `golden` stays float32.
        golden = torch.bmm(x_nd.float(), w_nd.float())
        golden = scale[None, None, :] * golden
        golden = pertoken_scale[:, :, None] * golden
        # Kept in float32 (not rounded to bfloat16) so it is a higher-precision
        # reference than both the bf16 baseline and the NPU output.
        golden_fp32 = torch.sum(golden, dim=0)

        # use_deterministic_algorithms is process-global: remember the current
        # setting and restore it so it does not leak into other tests.
        prev_mode = torch.are_deterministic_algorithms_enabled()
        prev_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
        torch.use_deterministic_algorithms(True)
        try:
            golden_bf16 = torch.zeros(m, n, dtype=torch.bfloat16)
            # Iterating a tensor yields its dim-0 slices, in batch order.
            for batch_slice in golden:
                golden_bf16 += batch_slice.to(torch.bfloat16)

            device = torch.device("npu:0")
            x_nd_npu = x_nd.to(device)
            # 29 presumably selects the FRACTAL_NZ weight layout (cf. the
            # `w_nz` name) — TODO confirm against torch_npu format constants.
            w_nz_npu = torch_npu.npu_format_cast(w_nd.to(device).contiguous(), 29)
            custom_out = torch_npu.npu_quant_matmul_reduce_sum(
                x_nd_npu,
                w_nz_npu,
                x1_scale=pertoken_scale.to(device),
                x2_scale=scale.to(device),
            )
        finally:
            torch.use_deterministic_algorithms(prev_mode, warn_only=prev_warn_only)

        self.assertTrue(compare_cv(golden_fp32, golden_bf16, custom_out.cpu()))

    @unittest.skip('Skip until cann package after 20240813 is used.')
    @SupportedDevices(['Ascend910B'])
    def test_npu_quant_matmul_reduce_sum_01(self):
        """Case B=8, M=2048, K=1024, N=7168."""
        self._run_quant_matmul_reduce_sum(batch=8, m=2048, k=1024, n=7168)

    @unittest.skip('Skip until cann package after 20240813 is used.')
    @SupportedDevices(['Ascend910B'])
    def test_npu_quant_matmul_reduce_sum_02(self):
        """Case B=16, M=1024, K=1024, N=7168."""
        self._run_quant_matmul_reduce_sum(batch=16, m=1024, k=1024, n=7168)
if __name__ == "__main__":
    # Script entry point: delegate to torch_npu's test runner.
    run_tests()