import math
import unittest
import copy
import struct
from struct import pack, unpack
import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
def quantize_w_per_group_asymmetric(w_fp32: torch.Tensor, group_size: int = 256):
    """Quantize a 2-D weight to the range [-7, 7] with one scale/zero-point per K-group.

    The (K, N) weight is viewed as (K // group_size, group_size, N); min and
    max are taken over each group of rows to derive a scale and zero point.

    Returns:
        (w_quant, w_scales, w_zeros): int8 quantized weight of shape (K, N),
        per-group scales (additionally divided by 15.0 before being returned),
        and per-group zero points cast to int8.
    """
    rows, cols = w_fp32.shape
    lo, hi = -7, 7
    grouped = w_fp32.view(-1, group_size, cols)
    group_min = grouped.min(dim=1, keepdim=True).values
    group_max = grouped.max(dim=1, keepdim=True).values
    span = group_max - group_min
    # Constant groups would give a zero scale; substitute a tiny range instead.
    span[span == 0] = 1e-5
    step = span / (hi - lo)
    zero_point = torch.round(lo - group_min / step)
    shifted = torch.round(grouped / step) + zero_point
    quantized = torch.clamp(shifted, lo, hi).to(torch.int8)
    scales_out = step.squeeze(1) / 15.0
    zeros_out = zero_point.squeeze(1).to(torch.int8)
    return quantized.view(rows, cols), scales_out, zeros_out
def unpack_int4_lsb_first_signed(w_packed: torch.Tensor) -> torch.Tensor:
    """Expand each int32 of a 2-D tensor into eight signed 4-bit values.

    Nibbles are extracted starting from the least-significant bits, so the
    output has shape (N, M * 8). Values above 7 are mapped to their negative
    two's-complement counterpart. Shifting is done on CPU and the result is
    moved back to the input's device.
    """
    assert w_packed.dtype == torch.int32, "Input tensor must be of dtype torch.int32"
    row_count, _col_count = w_packed.shape
    origin = w_packed.device
    nibble_shifts = torch.arange(0, 32, 4, device=origin).view(1, 1, 8).cpu()
    words = w_packed.unsqueeze(-1).cpu()
    nibbles = ((words >> nibble_shifts) & 0xF).to(torch.int8)
    nibbles = nibbles.view(row_count, -1).to(origin)
    signed = torch.where(nibbles > 7, nibbles - 16, nibbles)
    return signed.to(torch.int8)
def cpu_quant_matmul_a4w4_pergroup(x, weight, x1scale, x2scale, x2offset, groupsize=256):
    """Reference A4W4 per-group quant matmul.

    Both packed operands are unpacked to signed int4 values; the unpacked
    weight is transposed. K is walked in chunks of ``groupsize``: the group's
    offset is subtracted from the weight chunk, the integer matmul runs on
    CPU, the partial result is scaled by the group's ``x2scale`` and
    accumulated in float32. The accumulated result is multiplied by
    ``x1scale`` before being returned.
    """
    lhs = unpack_int4_lsb_first_signed(x).to(torch.int32)
    rhs = unpack_int4_lsb_first_signed(weight).T.to(torch.int32)
    m, _ = lhs.shape
    k, n = rhs.shape
    accum = torch.zeros((m, n), dtype=torch.float32, device=lhs.device)
    for group_idx, start in enumerate(range(0, k, groupsize)):
        stop = min(start + groupsize, k)
        lhs_chunk = lhs[:, start:stop]
        rhs_chunk = rhs[start:stop, :] - x2offset[group_idx].int()
        partial = torch.matmul(lhs_chunk.cpu(), rhs_chunk.cpu()).float().npu()
        partial *= x2scale[group_idx]
        accum += partial
    return accum.float() * x1scale
class TestQuantMatmul(TestCase):
@SupportedDevices(['Ascend910B', 'Ascend950'])
def test_npu_quant_matmul_a8w8(self):
    """int8 x int8 quant matmul with a single scale, fp16 output, checked against a CPU int32 matmul."""
    torch.manual_seed(0)
    lhs = torch.randint(-5, 5, (8192, 320), dtype=torch.int8)
    rhs = torch.randint(-5, 5, (320, 2560), dtype=torch.int8)
    lhs_npu = lhs.clone().npu()
    rhs_npu = rhs.clone().npu()
    scale = torch.randn(1, dtype=torch.float32)
    # The op receives the scale after conversion by npu_trans_quant_param.
    quant_scale = torch_npu.npu_trans_quant_param(scale.npu(), None)
    expected = torch.matmul(lhs.to(torch.int32), rhs.to(torch.int32)) * scale
    actual = torch_npu.npu_quant_matmul(lhs_npu, rhs_npu, quant_scale.npu(), output_dtype=torch.float16)
    self.assertRtolEqual(expected.float().cpu().numpy(), actual.float().cpu().numpy(), 0.01)
@SupportedDevices(['Ascend910B'])
def test_npu_quant_matmul_a4w4(self):
    """int4 x int4 quant matmul (operands packed via npu_quantize), fp16 output, vs. CPU int32 matmul."""
    torch.manual_seed(0)
    lhs = torch.randint(-5, 5, (8192, 320), dtype=torch.int32)
    rhs = torch.randint(-5, 5, (320, 2560), dtype=torch.int32)
    lhs_fp = lhs.clone().float().npu()
    rhs_fp = rhs.clone().float().npu()
    scale = torch.randn(1, dtype=torch.float32)
    quant_scale = torch_npu.npu_trans_quant_param(scale.clone().npu(), None)
    # A scale of exactly 1 is used for npu_quantize so the integer values are kept as generated.
    unit_scale = scale.clone().npu()
    unit_scale[0] = 1
    lhs_q = torch_npu.npu_quantize(lhs_fp, unit_scale, None, torch.quint4x2, -1, False)
    rhs_q = torch_npu.npu_quantize(rhs_fp, unit_scale, None, torch.quint4x2, -1, False)
    expected = torch.matmul(lhs.to(torch.int32), rhs.to(torch.int32)) * scale
    actual = torch_npu.npu_quant_matmul(lhs_q, rhs_q, quant_scale, output_dtype=torch.float16)
    self.assertRtolEqual(expected.float().cpu().numpy(), actual.float().cpu().numpy(), 0.01)
@SupportedDevices(['Ascend910B'])
def test_npu_quant_matmul_a4w4_pertoken(self):
    """int4 x int4 quant matmul with an extra per-token scale, checked for bf16 and fp16 outputs."""
    torch.manual_seed(0)
    lhs = torch.randint(-5, 5, (8192, 320), dtype=torch.int32)
    rhs = torch.randint(-5, 5, (320, 2560), dtype=torch.int32)
    lhs_fp = lhs.clone().float().npu()
    rhs_fp = rhs.clone().float().npu()
    scale = torch.randn(1, dtype=torch.float32)
    pertoken_scale = torch.randn(8192, dtype=torch.float32)
    scale_npu = scale.clone().npu()
    pertoken_npu = pertoken_scale.clone().npu()
    # A scale of exactly 1 is used for npu_quantize so the integer values are kept as generated.
    unit_scale = scale.clone().npu()
    unit_scale[0] = 1
    lhs_q = torch_npu.npu_quantize(lhs_fp, unit_scale, None, torch.quint4x2, -1, False)
    rhs_q = torch_npu.npu_quantize(rhs_fp, unit_scale, None, torch.quint4x2, -1, False)
    expected = torch.matmul(lhs.to(torch.int32), rhs.to(torch.int32)) * scale * pertoken_scale.unsqueeze(-1)
    expected_np = expected.float().cpu().numpy()
    actual_bf16 = torch_npu.npu_quant_matmul(lhs_q, rhs_q, scale_npu, pertoken_scale=pertoken_npu, output_dtype=torch.bfloat16)
    actual_fp16 = torch_npu.npu_quant_matmul(lhs_q, rhs_q, scale_npu, pertoken_scale=pertoken_npu, output_dtype=torch.float16)
    self.assertRtolEqual(expected_np, actual_bf16.float().cpu().numpy(), 0.01)
    self.assertRtolEqual(expected_np, actual_fp16.float().cpu().numpy(), 0.01)
def gen_inputs_for_npu_quant_matmul_a4w4_pergroup(self, dtype):
    """Build quantized inputs for the A4W4 per-group test.

    Random (256, 2048) activations and (5120, 2048) weights are created on the
    NPU in ``dtype``. Activations go through npu_dynamic_quant (int4); weights
    go through quantize_w_per_group_asymmetric on their transpose and are then
    packed with npu_convert_weight_to_int4pack.

    Returns:
        (x1, x2, x1_scale, x2_scale, x2_offset) where x1_scale is reshaped to
        a column, x2_scale is float32 and x2_offset is float16.
    """
    activation = torch.randint(-5, 5, (256, 2048), dtype=dtype, device="npu")
    weight = torch.randint(-5, 5, (5120, 2048), dtype=dtype, device="npu")
    act_q, act_scale = torch_npu.npu_dynamic_quant(activation, dst_type=torch.quint4x2)
    w_q, w_scale, w_zero = quantize_w_per_group_asymmetric(weight.transpose(-1, -2))
    w_offset = w_zero.to(torch.float16)
    w_scale = w_scale.to(torch.float32)
    w_int32 = w_q.to(torch.int32).transpose(-1, -2).contiguous()
    w_packed = torch_npu.npu_convert_weight_to_int4pack(w_int32).transpose(-1, -2)
    return act_q, w_packed, act_scale.reshape(-1, 1), w_scale, w_offset
@unittest.skip("skip test_npu_quant_matmul_a4w4_pergroup")
def test_npu_quant_matmul_a4w4_pergroup(self):
    """A4W4 per-group quant matmul vs. the CPU reference, first for bf16 then for fp16."""
    torch.manual_seed(0)
    for out_dtype in (torch.bfloat16, torch.float16):
        act, w, act_scale, w_scale, w_offset = self.gen_inputs_for_npu_quant_matmul_a4w4_pergroup(out_dtype)
        expected = cpu_quant_matmul_a4w4_pergroup(act, w.T, act_scale, w_scale, w_offset)
        actual = torch_npu.npu_quant_matmul(act, w, scale=w_scale, offset=w_offset, pertoken_scale=act_scale, bias=None, output_dtype=out_dtype, group_sizes=[1, 1, 256])
        self.assertRtolEqual(actual.float().cpu().numpy(), expected.float().cpu().numpy(), 0.01)
@SupportedDevices(['Ascend910B'])
def test_npu_quant_matmul_continuous_x2_tensor(self):
    """int4 quant matmul where x2 is quantized as (N, K) and handed to the op as a transposed view."""
    torch.manual_seed(0)
    lhs = torch.randint(-5, 5, (5, 160), dtype=torch.int32)
    rhs = torch.randint(-5, 5, (80, 160), dtype=torch.int32)
    lhs_fp = lhs.clone().float().npu()
    rhs_fp = rhs.clone().float().npu()
    scale = torch.randn(1, dtype=torch.float32)
    quant_scale = torch_npu.npu_trans_quant_param(scale.clone().npu(), None)
    # A scale of exactly 1 is used for npu_quantize so the integer values are kept as generated.
    unit_scale = scale.clone().npu()
    unit_scale[0] = 1
    lhs_q = torch_npu.npu_quantize(lhs_fp, unit_scale, None, torch.quint4x2, -1, False)
    rhs_q = torch_npu.npu_quantize(rhs_fp, unit_scale, None, torch.quint4x2, -1, False)
    expected = torch.matmul(lhs.to(torch.int32), rhs.t().to(torch.int32)) * scale
    actual = torch_npu.npu_quant_matmul(lhs_q, rhs_q.t(), quant_scale, output_dtype=torch.float16)
    self.assertRtolEqual(expected.float().cpu().numpy(), actual.float().cpu().numpy(), 0.01)
@SupportedDevices(['Ascend910B', 'Ascend950'])
def test_npu_quant_matmul_bf16(self):
    """int8 quant matmul with a raw float32 scale and bf16 output vs. CPU int32 matmul."""
    lhs = torch.randint(-5, 5, (8192, 320), dtype=torch.int8)
    rhs = torch.randint(-5, 5, (320, 2560), dtype=torch.int8)
    lhs_npu = lhs.clone().npu()
    rhs_npu = rhs.clone().npu()
    scale = torch.randn(1, dtype=torch.float32)
    expected = torch.matmul(lhs.to(torch.int32), rhs.to(torch.int32)) * scale
    actual = torch_npu.npu_quant_matmul(lhs_npu, rhs_npu, scale.npu(), output_dtype=torch.bfloat16)
    self.assertRtolEqual(expected.float().cpu().numpy(), actual.float().cpu().numpy(), 0.01)
@SupportedDevices(['Ascend910B', 'Ascend950'])
def test_npu_quant_matmul_bf16_nz(self):
    """int8 quant matmul, bf16 output, with x2 first passed through npu_format_cast(..., 29)."""
    lhs = torch.randint(-5, 5, (8192, 320), dtype=torch.int8)
    rhs = torch.randint(-5, 5, (320, 2560), dtype=torch.int8)
    lhs_npu = lhs.clone().npu()
    rhs_npu = rhs.clone().npu()
    scale = torch.randn(1, dtype=torch.float32)
    expected = torch.matmul(lhs.to(torch.int32), rhs.to(torch.int32)) * scale
    rhs_nz = torch_npu.npu_format_cast(rhs_npu.contiguous(), 29)
    actual = torch_npu.npu_quant_matmul(lhs_npu, rhs_nz.npu(), scale.npu(), output_dtype=torch.bfloat16)
    self.assertRtolEqual(expected.float().cpu().numpy(), actual.float().cpu().numpy(), 0.01)
@SupportedDevices(['Ascend910B', 'Ascend950'])
def test_npu_quant_matmul_fp16_nz(self):
    """int8 quant matmul, fp16 output, with x2 first passed through npu_format_cast(..., 29)."""
    lhs = torch.randint(-5, 5, (8192, 320), dtype=torch.int8)
    rhs = torch.randint(-5, 5, (320, 2560), dtype=torch.int8)
    lhs_npu = lhs.clone().npu()
    rhs_npu = rhs.clone().npu()
    scale = torch.randn(1, dtype=torch.float32)
    expected = torch.matmul(lhs.to(torch.int32), rhs.to(torch.int32)) * scale
    rhs_nz = torch_npu.npu_format_cast(rhs_npu.contiguous(), 29)
    actual = torch_npu.npu_quant_matmul(lhs_npu, rhs_nz.npu(), scale.npu(), output_dtype=torch.float16)
    self.assertRtolEqual(expected.float().cpu().numpy(), actual.float().cpu().numpy(), 0.01)
@unittest.skip("skip test_npu_quant_matmul_int32")
def test_npu_quant_matmul_int32(self):
    """int8 quant matmul with int32 output; the reference is the unscaled CPU int32 matmul."""
    lhs = torch.randint(-5, 5, (16, 6656), dtype=torch.int8)
    rhs = torch.randint(-5, 5, (6656, 4992), dtype=torch.int8)
    lhs_npu = lhs.clone().npu()
    rhs_npu = rhs.clone().npu()
    scale = torch.randn(1, dtype=torch.float32)
    expected = torch.matmul(lhs.to(torch.int32), rhs.to(torch.int32))
    actual = torch_npu.npu_quant_matmul(lhs_npu, rhs_npu, scale.npu(), output_dtype=torch.int32)
    expected_np = expected.cpu().numpy().astype(np.float32)
    actual_np = actual.cpu().numpy().astype(np.float32)
    self.assertRtolEqual(expected_np, actual_np, 0.001)
@SupportedDevices(['Ascend950'])
def test_npu_quant_matmul_fp8_pertensor(self):
    """fp8 (e4m3 x e5m2) quant matmul with one per-tensor scale and float32 output.

    x1 is generated as (K, M) and handed to the op as a transposed view, so the
    CPU reference transposes it as well.
    """
    # NOTE(review): confirm torch.randint accepts float8 dtypes on CPU in the targeted torch version.
    x1 = torch.randint(-5, 5, (16, 38), dtype=torch.float8_e4m3fn)
    x2 = torch.randint(-5, 5, (16, 112), dtype=torch.float8_e5m2)
    x1_clone = x1.clone()
    x2_clone = x2.clone()
    scale = torch.randn(1, dtype=torch.float32)
    # Clone from `scale`; the previous self-referencing clone read an unbound local.
    scale_clone = scale.clone()
    # (16, 38) @ (16, 112) is not a valid product; mirror the .t() used in the NPU call.
    supported_output = torch.matmul(x1.t().to(torch.float32), x2.to(torch.float32)) * scale
    custom_output = torch_npu.npu_quant_matmul(x1_clone.npu().t(), x2_clone.npu(), scale_clone.npu(), output_dtype=torch.float32)
    self.assertRtolEqual(supported_output.float().cpu().numpy(), custom_output.float().cpu().numpy(), 0.0001)
@SupportedDevices(['Ascend950'])
def test_npu_quant_matmul_hif8_perchannel(self):
    """hifloat8 quant matmul with a per-channel (length-N) scale and bf16 output.

    The raw bytes are sent to the NPU as-is; the CPU reference decodes them via
    trans_np_hifuint8_tensor_to_float32.
    """
    # Byte values in [0, 255) do not fit int8; use uint8 like the other hif8 tests in this class.
    x1 = torch.randint(0, 255, (224, 64), dtype=torch.uint8)
    x2 = torch.randint(0, 255, (64, 8360), dtype=torch.uint8)
    x1_clone = x1.clone()
    x2_clone = x2.clone()
    # The decoder is a method of this class, so it has to be reached through self.
    x1 = torch.from_numpy(self.trans_np_hifuint8_tensor_to_float32(x1.numpy()))
    x2 = torch.from_numpy(self.trans_np_hifuint8_tensor_to_float32(x2.numpy()))
    scale = torch.randn(8360, dtype=torch.float32)
    # Clone from `scale`; the previous self-referencing clone read an unbound local.
    scale_clone = scale.clone()
    supported_output = torch.matmul(x1.to(torch.float32), x2.to(torch.float32)) * scale
    custom_output = torch_npu.npu_quant_matmul(x1_clone.npu(), x2_clone.npu(), scale_clone.npu(), x1_dtype=torch_npu.hifloat8, x2_dtype=torch_npu.hifloat8, output_dtype=torch.bfloat16)
    self.assertRtolEqual(supported_output.float().cpu().numpy(), custom_output.float().cpu().numpy(), 0.001)
@SupportedDevices(['Ascend950'])
def test_npu_quant_matmul_int8_pertoken(self):
    """int8 quant matmul with per-tensor scale plus per-token scale, bf16 output.

    x2 is generated as (N, K) and handed to the op as a transposed view, so the
    CPU reference transposes it as well.
    """
    x1 = torch.randint(-5, 5, (4183, 1088), dtype=torch.int8)
    x2 = torch.randint(-5, 5, (960, 1088), dtype=torch.int8)
    x1_clone = x1.clone()
    x2_clone = x2.clone()
    scale = torch.randn(1, dtype=torch.float32)
    pertoken_scale = torch.randn(4183, dtype=torch.float32)
    # Clone from `scale`; the previous self-referencing clone read an unbound local.
    scale_clone = scale.clone()
    pertoken_scale_clone = pertoken_scale.clone()
    # (4183, 1088) @ (960, 1088) is not a valid product; mirror the .t() used in the NPU call.
    supported_output = torch.matmul(x1.to(torch.int32), x2.t().to(torch.int32)) * scale * pertoken_scale.unsqueeze(-1)
    custom_output = torch_npu.npu_quant_matmul(x1_clone.npu(), x2_clone.npu().t(), scale_clone.npu(), pertoken_scale=pertoken_scale_clone.npu(), output_dtype=torch.bfloat16)
    self.assertRtolEqual(supported_output.float().cpu().numpy(), custom_output.float().cpu().numpy(), 0.001)
@SupportedDevices(['Ascend950'])
def test_npu_quant_matmul_fp8_pertoken(self):
    """fp8 e4m3 quant matmul with per-channel scale and per-token scale, fp16 output."""
    # NOTE(review): confirm torch.randint accepts float8 dtypes on CPU in the targeted torch version.
    x1 = torch.randint(-5, 5, (8192, 320), dtype=torch.float8_e4m3fn)
    x2 = torch.randint(-5, 5, (320, 2560), dtype=torch.float8_e4m3fn)
    x1_clone = x1.clone()
    x2_clone = x2.clone()
    scale = torch.randn(2560, dtype=torch.float32)
    # One scale per row of x1; a length of 4183 cannot broadcast against the (8192, 2560) product.
    pertoken_scale = torch.randn(x1.shape[0], dtype=torch.float32)
    # Clone from `scale`; the previous self-referencing clone read an unbound local.
    scale_clone = scale.clone()
    pertoken_scale_clone = pertoken_scale.clone()
    supported_output = torch.matmul(x1.to(torch.float32), x2.to(torch.float32)) * scale * pertoken_scale.unsqueeze(-1)
    custom_output = torch_npu.npu_quant_matmul(x1_clone.npu(), x2_clone.npu(), scale_clone.npu(), pertoken_scale=pertoken_scale_clone.npu(), output_dtype=torch.float16)
    self.assertRtolEqual(supported_output.float().cpu().numpy(), custom_output.float().cpu().numpy(), 0.001)
@SupportedDevices(['Ascend950'])
def test_npu_quant_matmul_hif8_pertoken(self):
    """hifloat8 quant matmul with per-channel scale and per-token scale, float32 output.

    The raw bytes are sent to the NPU as-is; the CPU reference decodes them via
    trans_np_hifuint8_tensor_to_float32.
    """
    x1 = torch.randint(0, 255, (8192, 320), dtype=torch.uint8)
    x2 = torch.randint(0, 255, (320, 2560), dtype=torch.uint8)
    x1_clone = x1.clone()
    x2_clone = x2.clone()
    # The decoder is a method of this class, so it has to be reached through self.
    x1 = torch.from_numpy(self.trans_np_hifuint8_tensor_to_float32(x1.numpy()))
    x2 = torch.from_numpy(self.trans_np_hifuint8_tensor_to_float32(x2.numpy()))
    scale = torch.randn(2560, dtype=torch.float32)
    # One scale per row of x1; a length of 4183 cannot broadcast against the (8192, 2560) product.
    pertoken_scale = torch.randn(x1.shape[0], dtype=torch.float32)
    # Clone from `scale`; the previous self-referencing clone read an unbound local.
    scale_clone = scale.clone()
    pertoken_scale_clone = pertoken_scale.clone()
    supported_output = torch.matmul(x1.to(torch.float32), x2.to(torch.float32)) * scale * pertoken_scale.unsqueeze(-1)
    custom_output = torch_npu.npu_quant_matmul(x1_clone.npu(), x2_clone.npu(), scale_clone.npu(), pertoken_scale=pertoken_scale_clone.npu(), x1_dtype=torch_npu.hifloat8, x2_dtype=torch_npu.hifloat8, output_dtype=torch.float32)
    self.assertRtolEqual(supported_output.float().cpu().numpy(), custom_output.float().cpu().numpy(), 0.0001)
@SupportedDevices(['Ascend950'])
def test_npu_quant_matmul_fp8_doublescale(self):
    """fp8 e5m2 quant matmul with two scalar scales (scale and pertoken_scale), float32 output.

    x1 is generated as (K, M) and handed to the op as a transposed view, so the
    CPU reference transposes it as well.
    """
    # NOTE(review): confirm torch.randint accepts float8 dtypes on CPU in the targeted torch version.
    x1 = torch.randint(-5, 5, (544, 236), dtype=torch.float8_e5m2)
    x2 = torch.randint(-5, 5, (544, 568), dtype=torch.float8_e5m2)
    x1_clone = x1.clone()
    x2_clone = x2.clone()
    scale = torch.randn(1, dtype=torch.float32)
    pertoken_scale = torch.randn(1, dtype=torch.float32)
    # Clone from `scale`; the previous self-referencing clone read an unbound local.
    scale_clone = scale.clone()
    pertoken_scale_clone = pertoken_scale.clone()
    double_scale = scale * pertoken_scale
    # (544, 236) @ (544, 568) is not a valid product; mirror the .t() used in the NPU call.
    supported_output = torch.matmul(x1.t().to(torch.float32), x2.to(torch.float32)) * double_scale
    custom_output = torch_npu.npu_quant_matmul(x1_clone.npu().t(), x2_clone.npu(), scale_clone.npu(), pertoken_scale=pertoken_scale_clone.npu(), output_dtype=torch.float32)
    self.assertRtolEqual(supported_output.float().cpu().numpy(), custom_output.float().cpu().numpy(), 0.0001)
@SupportedDevices(['Ascend950'])
def test_npu_quant_matmul_hif8_doublescale(self):
    """hifloat8 quant matmul with two scalar scales (scale and pertoken_scale), bf16 output.

    The raw bytes are sent to the NPU as-is; the CPU reference decodes them via
    trans_np_hifuint8_tensor_to_float32.
    """
    x1 = torch.randint(0, 255, (30, 32), dtype=torch.uint8)
    x2 = torch.randint(0, 255, (32, 8), dtype=torch.uint8)
    x1_clone = x1.clone()
    x2_clone = x2.clone()
    # The decoder is a method of this class, so it has to be reached through self.
    x1 = torch.from_numpy(self.trans_np_hifuint8_tensor_to_float32(x1.numpy()))
    x2 = torch.from_numpy(self.trans_np_hifuint8_tensor_to_float32(x2.numpy()))
    scale = torch.randn(1, dtype=torch.float32)
    pertoken_scale = torch.randn(1, dtype=torch.float32)
    # Clone from `scale`; the previous self-referencing clone read an unbound local.
    scale_clone = scale.clone()
    pertoken_scale_clone = pertoken_scale.clone()
    double_scale = scale * pertoken_scale
    supported_output = torch.matmul(x1.to(torch.float32), x2.to(torch.float32)) * double_scale
    custom_output = torch_npu.npu_quant_matmul(x1_clone.npu(), x2_clone.npu(), scale_clone.npu(), pertoken_scale=pertoken_scale_clone.npu(), x1_dtype=torch_npu.hifloat8, x2_dtype=torch_npu.hifloat8, output_dtype=torch.bfloat16)
    self.assertRtolEqual(supported_output.float().cpu().numpy(), custom_output.float().cpu().numpy(), 0.001)
@SupportedDevices(['Ascend950'])
def test_npu_quant_matmul_mxfp8_odd_scaleK(self):
    """MXFP8 quant matmul (x2 passed transposed, e8m0 scale dtypes, fp16 output) vs. a NumPy-built reference.

    Both scale tensors have 16 columns; the reference drops the last one before
    broadcasting each scale 32 times.
    """
    # NOTE(review): confirm torch.randint / Tensor.numpy() support float8 dtypes in the targeted torch version.
    x1 = torch.randint(-5, 5, (185, 480), dtype=torch.float8_e4m3fn)
    x2 = torch.randint(-5, 5, (1880, 480), dtype=torch.float8_e4m3fn)
    x1_clone = x1.clone()
    x2_clone = x2.clone()
    scale = torch.randint(-5, 5, (1880, 16), dtype=torch.int8)
    pertoken_scale = torch.randint(-5, 5, (185, 16), dtype=torch.int8)
    # Clone from `scale`; the previous self-referencing clone read an unbound local.
    scale_clone = scale.clone()
    pertoken_scale_clone = pertoken_scale.clone()
    x1 = x1.numpy()
    x2 = x2.numpy()
    scale = scale.numpy().transpose()
    pertoken_scale = pertoken_scale.numpy()
    scale = scale[:-1, :]
    pertoken_scale = pertoken_scale[:, :-1]
    scale_broadcast = np.repeat(scale, 32, axis=-1)
    pertoken_scale_broadcast = np.repeat(pertoken_scale, 32, axis=-2)
    # NOTE(review): with the shapes above x1_pad_len comes out negative (the repeat runs along
    # axis -2, so shape[-1] stays at the sliced column count) and np.pad rejects negative
    # widths -- confirm the intended repeat axes and operand layout for this reference.
    x1_pad_len = pertoken_scale_broadcast.shape[-1] - x1.shape[-1]
    # Compare dimension sizes; indexing the array itself yielded a row, not a length.
    x2_pad_len = scale_broadcast.shape[-2] - x2.shape[-2]
    x1_dims = len(x1.shape)
    x2_dims = len(x2.shape)
    x1 = np.pad(x1, [(0, 0)] * (x1_dims - 1) + [(0, x1_pad_len)], mode='constant', constant_values=0)
    x2 = np.pad(x2, [(0, 0)] * (x2_dims - 2) + [(0, x2_pad_len)] + [(0, 0)], mode='constant', constant_values=0)
    x1 = x1 * pertoken_scale_broadcast
    x2 = x2 * scale_broadcast
    # x1/x2 are NumPy arrays at this point and must be wrapped before using Tensor.to().
    supported_output = torch.matmul(torch.from_numpy(x1).to(torch.float32), torch.from_numpy(x2).to(torch.float32))
    custom_output = torch_npu.npu_quant_matmul(x1_clone.npu(), x2_clone.npu().t(), scale_clone.npu(), pertoken_scale=pertoken_scale_clone.npu(), scale_dtype=torch_npu.float8_e8m0fnu, pertoken_scale_dtype=torch_npu.float8_e8m0fnu, output_dtype=torch.float16)
    self.assertRtolEqual(supported_output.float().cpu().numpy(), custom_output.float().cpu().numpy(), 0.001)
@SupportedDevices(['Ascend950'])
def test_npu_quant_matmul_mxfp8_even_scaleK(self):
    """MXFP8 quant matmul (x2 passed transposed, e8m0 scale dtypes, bf16 output) vs. a NumPy-built reference.

    Both scale tensors have 30 columns; the reference drops the last one before
    broadcasting each scale 32 times.
    """
    # NOTE(review): confirm torch.randint / Tensor.numpy() support float8 dtypes in the targeted torch version.
    x1 = torch.randint(-5, 5, (112, 944), dtype=torch.float8_e4m3fn)
    x2 = torch.randint(-5, 5, (184, 944), dtype=torch.float8_e4m3fn)
    x1_clone = x1.clone()
    x2_clone = x2.clone()
    scale = torch.randint(-5, 5, (184, 30), dtype=torch.int8)
    pertoken_scale = torch.randint(-5, 5, (112, 30), dtype=torch.int8)
    # Clone from `scale`; the previous self-referencing clone read an unbound local.
    scale_clone = scale.clone()
    pertoken_scale_clone = pertoken_scale.clone()
    x1 = x1.numpy()
    x2 = x2.numpy()
    scale = scale.numpy().transpose()
    pertoken_scale = pertoken_scale.numpy()
    scale = scale[:-1, :]
    pertoken_scale = pertoken_scale[:, :-1]
    scale_broadcast = np.repeat(scale, 32, axis=-1)
    pertoken_scale_broadcast = np.repeat(pertoken_scale, 32, axis=-2)
    # NOTE(review): with the shapes above x1_pad_len comes out negative (the repeat runs along
    # axis -2, so shape[-1] stays at the sliced column count) and np.pad rejects negative
    # widths -- confirm the intended repeat axes and operand layout for this reference.
    x1_pad_len = pertoken_scale_broadcast.shape[-1] - x1.shape[-1]
    # Compare dimension sizes; indexing the array itself yielded a row, not a length.
    x2_pad_len = scale_broadcast.shape[-2] - x2.shape[-2]
    x1_dims = len(x1.shape)
    x2_dims = len(x2.shape)
    x1 = np.pad(x1, [(0, 0)] * (x1_dims - 1) + [(0, x1_pad_len)], mode='constant', constant_values=0)
    x2 = np.pad(x2, [(0, 0)] * (x2_dims - 2) + [(0, x2_pad_len)] + [(0, 0)], mode='constant', constant_values=0)
    x1 = x1 * pertoken_scale_broadcast
    x2 = x2 * scale_broadcast
    # x1/x2 are NumPy arrays at this point and must be wrapped before using Tensor.to().
    supported_output = torch.matmul(torch.from_numpy(x1).to(torch.float32), torch.from_numpy(x2).to(torch.float32))
    custom_output = torch_npu.npu_quant_matmul(x1_clone.npu(), x2_clone.npu().t(), scale_clone.npu(), pertoken_scale=pertoken_scale_clone.npu(), scale_dtype=torch_npu.float8_e8m0fnu, pertoken_scale_dtype=torch_npu.float8_e8m0fnu, output_dtype=torch.bfloat16)
    self.assertRtolEqual(supported_output.float().cpu().numpy(), custom_output.float().cpu().numpy(), 0.001)
@SupportedDevices(['Ascend950'])
def test_npu_quant_matmul_mxfp8_with_batch(self):
    """MXFP8 quant matmul with a 3-D x1 of shape (1, 112, 944), bf16 output, vs. a NumPy-built reference.

    For the reference x1 is reshaped to (112, 944); both scale tensors have 30
    columns and the last one is dropped before broadcasting each scale 32 times.
    """
    # NOTE(review): confirm torch.randint / Tensor.numpy() support float8 dtypes in the targeted torch version.
    x1 = torch.randint(-5, 5, (1, 112, 944), dtype=torch.float8_e4m3fn)
    x2 = torch.randint(-5, 5, (184, 944), dtype=torch.float8_e4m3fn)
    x1_clone = x1.clone()
    x2_clone = x2.clone()
    scale = torch.randint(-5, 5, (184, 30), dtype=torch.int8)
    pertoken_scale = torch.randint(-5, 5, (112, 30), dtype=torch.int8)
    # Clone from `scale`; the previous self-referencing clone read an unbound local.
    scale_clone = scale.clone()
    pertoken_scale_clone = pertoken_scale.clone()
    x1 = x1.numpy().reshape(112, 944)
    x2 = x2.numpy()
    scale = scale.numpy().transpose()
    pertoken_scale = pertoken_scale.numpy()
    scale = scale[:-1, :]
    pertoken_scale = pertoken_scale[:, :-1]
    scale_broadcast = np.repeat(scale, 32, axis=-1)
    pertoken_scale_broadcast = np.repeat(pertoken_scale, 32, axis=-2)
    # NOTE(review): with the shapes above x1_pad_len comes out negative (the repeat runs along
    # axis -2, so shape[-1] stays at the sliced column count) and np.pad rejects negative
    # widths -- confirm the intended repeat axes and operand layout for this reference.
    x1_pad_len = pertoken_scale_broadcast.shape[-1] - x1.shape[-1]
    # Compare dimension sizes; indexing the array itself yielded a row, not a length.
    x2_pad_len = scale_broadcast.shape[-2] - x2.shape[-2]
    x1_dims = len(x1.shape)
    x2_dims = len(x2.shape)
    x1 = np.pad(x1, [(0, 0)] * (x1_dims - 1) + [(0, x1_pad_len)], mode='constant', constant_values=0)
    x2 = np.pad(x2, [(0, 0)] * (x2_dims - 2) + [(0, x2_pad_len)] + [(0, 0)], mode='constant', constant_values=0)
    x1 = x1 * pertoken_scale_broadcast
    x2 = x2 * scale_broadcast
    # x1/x2 are NumPy arrays at this point and must be wrapped before using Tensor.to().
    supported_output = torch.matmul(torch.from_numpy(x1).to(torch.float32), torch.from_numpy(x2).to(torch.float32))
    custom_output = torch_npu.npu_quant_matmul(x1_clone.npu(), x2_clone.npu().t(), scale_clone.npu(), pertoken_scale=pertoken_scale_clone.npu(), scale_dtype=torch_npu.float8_e8m0fnu, pertoken_scale_dtype=torch_npu.float8_e8m0fnu, output_dtype=torch.bfloat16)
    self.assertRtolEqual(supported_output.float().cpu().numpy(), custom_output.float().cpu().numpy(), 0.001)
@SupportedDevices(['Ascend950'])
def test_npu_quant_matmul_mxfp4(self):
    """MXFP4 quant matmul (int8 storage declared as float4_e2m1fn_x2, x2 passed transposed, bf16 output).

    Both scale tensors have 30 columns; the reference drops the last one before
    broadcasting each scale 32 times.
    """
    x1 = torch.randint(-5, 5, (112, 944), dtype=torch.int8)
    x2 = torch.randint(-5, 5, (184, 944), dtype=torch.int8)
    x1_clone = x1.clone()
    x2_clone = x2.clone()
    scale = torch.randint(-5, 5, (184, 30), dtype=torch.int8)
    pertoken_scale = torch.randint(-5, 5, (112, 30), dtype=torch.int8)
    # Clone from `scale`; the previous self-referencing clone read an unbound local.
    scale_clone = scale.clone()
    pertoken_scale_clone = pertoken_scale.clone()
    x1 = x1.numpy()
    x2 = x2.numpy()
    scale = scale.numpy().transpose()
    pertoken_scale = pertoken_scale.numpy()
    scale = scale[:-1, :]
    pertoken_scale = pertoken_scale[:, :-1]
    scale_broadcast = np.repeat(scale, 32, axis=-1)
    pertoken_scale_broadcast = np.repeat(pertoken_scale, 32, axis=-2)
    # NOTE(review): with the shapes above x1_pad_len comes out negative (the repeat runs along
    # axis -2, so shape[-1] stays at the sliced column count) and np.pad rejects negative
    # widths -- confirm the intended repeat axes and operand layout for this reference.
    x1_pad_len = pertoken_scale_broadcast.shape[-1] - x1.shape[-1]
    # Compare dimension sizes; indexing the array itself yielded a row, not a length.
    x2_pad_len = scale_broadcast.shape[-2] - x2.shape[-2]
    x1_dims = len(x1.shape)
    x2_dims = len(x2.shape)
    x1 = np.pad(x1, [(0, 0)] * (x1_dims - 1) + [(0, x1_pad_len)], mode='constant', constant_values=0)
    x2 = np.pad(x2, [(0, 0)] * (x2_dims - 2) + [(0, x2_pad_len)] + [(0, 0)], mode='constant', constant_values=0)
    x1 = x1 * pertoken_scale_broadcast
    x2 = x2 * scale_broadcast
    # x1/x2 are NumPy arrays at this point and must be wrapped before using Tensor.to().
    supported_output = torch.matmul(torch.from_numpy(x1).to(torch.float32), torch.from_numpy(x2).to(torch.float32))
    custom_output = torch_npu.npu_quant_matmul(x1_clone.npu(), x2_clone.npu().t(), scale_clone.npu(), pertoken_scale=pertoken_scale_clone.npu(), x1_dtype=torch_npu.float4_e2m1fn_x2, x2_dtype=torch_npu.float4_e2m1fn_x2, scale_dtype=torch_npu.float8_e8m0fnu, pertoken_scale_dtype=torch_npu.float8_e8m0fnu, output_dtype=torch.bfloat16)
    self.assertRtolEqual(supported_output.float().cpu().numpy(), custom_output.float().cpu().numpy(), 0.001)
@SupportedDevices(['Ascend950'])
def test_npu_quant_matmul_mxfp4_x1trans(self):
    """MXFP4 quant matmul with x1 passed as a transposed view and x2 untransposed, bf16 output.

    pertoken_scale is generated as (30, 112) and transposed for the reference;
    the last scale group is dropped before broadcasting each scale 32 times.
    """
    x1 = torch.randint(-5, 5, (944, 112), dtype=torch.int8)
    x2 = torch.randint(-5, 5, (944, 184), dtype=torch.int8)
    x1_clone = x1.clone()
    x2_clone = x2.clone()
    pertoken_scale = torch.randint(-5, 5, (30, 112), dtype=torch.int8)
    scale = torch.randint(-5, 5, (30, 184), dtype=torch.int8)
    # Clone from `scale`; the previous self-referencing clone read an unbound local.
    scale_clone = scale.clone()
    pertoken_scale_clone = pertoken_scale.clone()
    x1 = x1.numpy()
    x2 = x2.numpy()
    scale = scale.numpy()
    pertoken_scale = pertoken_scale.numpy().transpose()
    scale = scale[:-1, :]
    pertoken_scale = pertoken_scale[:, :-1]
    scale_broadcast = np.repeat(scale, 32, axis=-1)
    pertoken_scale_broadcast = np.repeat(pertoken_scale, 32, axis=-2)
    # NOTE(review): with the shapes above x1_pad_len comes out negative (the repeat runs along
    # axis -2, so shape[-1] stays at the sliced column count) and np.pad rejects negative
    # widths -- confirm the intended repeat axes and operand layout for this reference.
    x1_pad_len = pertoken_scale_broadcast.shape[-1] - x1.shape[-1]
    # Compare dimension sizes; indexing the array itself yielded a row, not a length.
    x2_pad_len = scale_broadcast.shape[-2] - x2.shape[-2]
    x1_dims = len(x1.shape)
    x2_dims = len(x2.shape)
    x1 = np.pad(x1, [(0, 0)] * (x1_dims - 1) + [(0, x1_pad_len)], mode='constant', constant_values=0)
    x2 = np.pad(x2, [(0, 0)] * (x2_dims - 2) + [(0, x2_pad_len)] + [(0, 0)], mode='constant', constant_values=0)
    x1 = x1 * pertoken_scale_broadcast
    x2 = x2 * scale_broadcast
    # x1/x2 are NumPy arrays at this point and must be wrapped before using Tensor.to().
    supported_output = torch.matmul(torch.from_numpy(x1).to(torch.float32), torch.from_numpy(x2).to(torch.float32))
    custom_output = torch_npu.npu_quant_matmul(x1_clone.npu().t(), x2_clone.npu(), scale_clone.npu(), pertoken_scale=pertoken_scale_clone.npu(), x1_dtype=torch_npu.float4_e2m1fn_x2, x2_dtype=torch_npu.float4_e2m1fn_x2, scale_dtype=torch_npu.float8_e8m0fnu, pertoken_scale_dtype=torch_npu.float8_e8m0fnu, output_dtype=torch.bfloat16)
    self.assertRtolEqual(supported_output.float().cpu().numpy(), custom_output.float().cpu().numpy(), 0.001)
@SupportedDevices(['Ascend950'])
def test_npu_quant_matmul_mxfp4_x1trans_x2trans(self):
    """MXFP4 quant matmul with both x1 and x2 passed as transposed views, bf16 output.

    Both scales are transposed for the reference; the last scale group is
    dropped before broadcasting each scale 32 times.
    """
    x1 = torch.randint(-5, 5, (944, 112), dtype=torch.int8)
    x2 = torch.randint(-5, 5, (184, 944), dtype=torch.int8)
    x1_clone = x1.clone()
    x2_clone = x2.clone()
    pertoken_scale = torch.randint(-5, 5, (30, 112), dtype=torch.int8)
    scale = torch.randint(-5, 5, (184, 30), dtype=torch.int8)
    # Clone from `scale`; the previous self-referencing clone read an unbound local.
    scale_clone = scale.clone()
    pertoken_scale_clone = pertoken_scale.clone()
    x1 = x1.numpy()
    x2 = x2.numpy()
    scale = scale.numpy().transpose()
    pertoken_scale = pertoken_scale.numpy().transpose()
    scale = scale[:-1, :]
    pertoken_scale = pertoken_scale[:, :-1]
    scale_broadcast = np.repeat(scale, 32, axis=-1)
    pertoken_scale_broadcast = np.repeat(pertoken_scale, 32, axis=-2)
    # NOTE(review): with the shapes above x1_pad_len comes out negative (the repeat runs along
    # axis -2, so shape[-1] stays at the sliced column count) and np.pad rejects negative
    # widths -- confirm the intended repeat axes and operand layout for this reference.
    x1_pad_len = pertoken_scale_broadcast.shape[-1] - x1.shape[-1]
    # Compare dimension sizes; indexing the array itself yielded a row, not a length.
    x2_pad_len = scale_broadcast.shape[-2] - x2.shape[-2]
    x1_dims = len(x1.shape)
    x2_dims = len(x2.shape)
    x1 = np.pad(x1, [(0, 0)] * (x1_dims - 1) + [(0, x1_pad_len)], mode='constant', constant_values=0)
    x2 = np.pad(x2, [(0, 0)] * (x2_dims - 2) + [(0, x2_pad_len)] + [(0, 0)], mode='constant', constant_values=0)
    x1 = x1 * pertoken_scale_broadcast
    x2 = x2 * scale_broadcast
    # x1/x2 are NumPy arrays at this point and must be wrapped before using Tensor.to().
    supported_output = torch.matmul(torch.from_numpy(x1).to(torch.float32), torch.from_numpy(x2).to(torch.float32))
    custom_output = torch_npu.npu_quant_matmul(x1_clone.npu().t(), x2_clone.npu().t(), scale_clone.npu(), pertoken_scale=pertoken_scale_clone.npu(), x1_dtype=torch_npu.float4_e2m1fn_x2, x2_dtype=torch_npu.float4_e2m1fn_x2, scale_dtype=torch_npu.float8_e8m0fnu, pertoken_scale_dtype=torch_npu.float8_e8m0fnu, output_dtype=torch.bfloat16)
    self.assertRtolEqual(supported_output.float().cpu().numpy(), custom_output.float().cpu().numpy(), 0.001)
@SupportedDevices(['Ascend950'])
def test_npu_quant_matmul_mxfp4_notrans(self):
    """MXFP4 quant matmul with neither operand transposed, bf16 output, vs. a NumPy-built reference.

    The last scale group is dropped from both scales before broadcasting each
    scale 32 times.
    """
    x1 = torch.randint(-5, 5, (112, 944), dtype=torch.int8)
    x2 = torch.randint(-5, 5, (944, 184), dtype=torch.int8)
    x1_clone = x1.clone()
    x2_clone = x2.clone()
    pertoken_scale = torch.randint(-5, 5, (112, 30), dtype=torch.int8)
    scale = torch.randint(-5, 5, (30, 184), dtype=torch.int8)
    # Clone from `scale`; the previous self-referencing clone read an unbound local.
    scale_clone = scale.clone()
    pertoken_scale_clone = pertoken_scale.clone()
    x1 = x1.numpy()
    x2 = x2.numpy()
    scale = scale.numpy()
    pertoken_scale = pertoken_scale.numpy()
    scale = scale[:-1, :]
    pertoken_scale = pertoken_scale[:, :-1]
    scale_broadcast = np.repeat(scale, 32, axis=-1)
    pertoken_scale_broadcast = np.repeat(pertoken_scale, 32, axis=-2)
    # NOTE(review): with the shapes above x1_pad_len comes out negative (the repeat runs along
    # axis -2, so shape[-1] stays at the sliced column count) and np.pad rejects negative
    # widths -- confirm the intended repeat axes and operand layout for this reference.
    x1_pad_len = pertoken_scale_broadcast.shape[-1] - x1.shape[-1]
    # Compare dimension sizes; indexing the array itself yielded a row, not a length.
    x2_pad_len = scale_broadcast.shape[-2] - x2.shape[-2]
    x1_dims = len(x1.shape)
    x2_dims = len(x2.shape)
    x1 = np.pad(x1, [(0, 0)] * (x1_dims - 1) + [(0, x1_pad_len)], mode='constant', constant_values=0)
    x2 = np.pad(x2, [(0, 0)] * (x2_dims - 2) + [(0, x2_pad_len)] + [(0, 0)], mode='constant', constant_values=0)
    x1 = x1 * pertoken_scale_broadcast
    x2 = x2 * scale_broadcast
    # x1/x2 are NumPy arrays at this point and must be wrapped before using Tensor.to().
    supported_output = torch.matmul(torch.from_numpy(x1).to(torch.float32), torch.from_numpy(x2).to(torch.float32))
    custom_output = torch_npu.npu_quant_matmul(x1_clone.npu(), x2_clone.npu(), scale_clone.npu(), pertoken_scale=pertoken_scale_clone.npu(), x1_dtype=torch_npu.float4_e2m1fn_x2, x2_dtype=torch_npu.float4_e2m1fn_x2, scale_dtype=torch_npu.float8_e8m0fnu, pertoken_scale_dtype=torch_npu.float8_e8m0fnu, output_dtype=torch.bfloat16)
    self.assertRtolEqual(supported_output.float().cpu().numpy(), custom_output.float().cpu().numpy(), 0.001)
@SupportedDevices(['Ascend950'])
def test_npu_quant_matmul_hif8_perblock(self):
    """hifloat8 quant matmul with 2-D scale (3, 20) and pertoken_scale (64, 3), bf16 output.

    The raw bytes are sent to the NPU as-is; the CPU reference decodes them via
    trans_np_hifuint8_tensor_to_float32.
    """
    x1 = torch.randint(0, 255, (8192, 320), dtype=torch.uint8)
    x2 = torch.randint(0, 255, (320, 2560), dtype=torch.uint8)
    x1_clone = x1.clone()
    x2_clone = x2.clone()
    # The decoder is a method of this class, so it has to be reached through self.
    x1 = torch.from_numpy(self.trans_np_hifuint8_tensor_to_float32(x1.numpy()))
    x2 = torch.from_numpy(self.trans_np_hifuint8_tensor_to_float32(x2.numpy()))
    scale = torch.randint(-5, 5, (3, 20), dtype=torch.float32)
    pertoken_scale = torch.randint(-5, 5, (64, 3), dtype=torch.float32)
    # Clone from `scale`; the previous self-referencing clone read an unbound local.
    scale_clone = scale.clone()
    pertoken_scale_clone = pertoken_scale.clone()
    # NOTE(review): shapes (3, 20) and (64, 3) do not broadcast under `*`, and the product
    # would not match the (8192, 2560) matmul result either -- the per-block reference
    # scaling likely needs to expand both scales to element granularity; TODO confirm.
    double_scale = scale * pertoken_scale
    supported_output = torch.matmul(x1.to(torch.float32), x2.to(torch.float32)) * double_scale
    custom_output = torch_npu.npu_quant_matmul(x1_clone.npu(), x2_clone.npu(), scale_clone.npu(), pertoken_scale=pertoken_scale_clone.npu(), x1_dtype=torch_npu.hifloat8, x2_dtype=torch_npu.hifloat8, output_dtype=torch.bfloat16)
    self.assertRtolEqual(supported_output.float().cpu().numpy(), custom_output.float().cpu().numpy(), 0.001)
@SupportedDevices(['Ascend950'])
def test_npu_quant_matmul_fp8_pertile(self):
    """fp8 e5m2 quant matmul with 2-D scale (3, 20) and pertoken_scale (8192, 3), float32 output."""
    # NOTE(review): confirm torch.randint accepts float8 dtypes on CPU in the targeted torch version.
    x1 = torch.randint(-5, 5, (8192, 320), dtype=torch.float8_e5m2)
    x2 = torch.randint(-5, 5, (320, 2560), dtype=torch.float8_e5m2)
    x1_clone = x1.clone()
    x2_clone = x2.clone()
    scale = torch.randint(-5, 5, (3, 20), dtype=torch.float32)
    pertoken_scale = torch.randint(-5, 5, (8192, 3), dtype=torch.float32)
    # Clone from `scale`; the previous self-referencing clone read an unbound local.
    scale_clone = scale.clone()
    pertoken_scale_clone = pertoken_scale.clone()
    # NOTE(review): shapes (3, 20) and (8192, 3) do not broadcast under `*`; the per-tile
    # reference scaling likely needs to expand both scales to element granularity. Also the
    # NPU call transposes x1 while this reference does not -- TODO confirm the intended layout.
    double_scale = scale * pertoken_scale
    supported_output = torch.matmul(x1.to(torch.float32), x2.to(torch.float32)) * double_scale
    custom_output = torch_npu.npu_quant_matmul(x1_clone.npu().t(), x2_clone.npu(), scale_clone.npu(), pertoken_scale=pertoken_scale_clone.npu(), output_dtype=torch.float32)
    self.assertRtolEqual(supported_output.float().cpu().numpy(), custom_output.float().cpu().numpy(), 0.0001)
@SupportedDevices(['Ascend950'])
def test_npu_quant_matmul_fp8_pertile_invaild_trans(self):
    """The op must raise RuntimeError when x1 is untransposed and x2 is a transposed view with these scales.

    Uses assertRaises so the test fails if no exception is raised at all; a bare
    try/except would pass silently in that case.
    """
    x1 = torch.randint(-5, 5, (10, 256), dtype=torch.float8_e5m2)
    x2 = torch.randint(-5, 5, (512, 256), dtype=torch.float8_e5m2)
    scale = torch.randint(-5, 5, (4, 2), dtype=torch.float32)
    pertoken_scale = torch.randint(-5, 5, (10, 2), dtype=torch.float32)
    with self.assertRaises(RuntimeError) as raised:
        torch_npu.npu_quant_matmul(x1.npu(), x2.npu().t(), scale.npu(), pertoken_scale=pertoken_scale.npu(), output_dtype=torch.float32)
    self.assertTrue("transpose are not same, please check input" in str(raised.exception))
def a8w4_quant_golden(self, x, weight, scale, perTokenScale, groupList, bias, output_dtype):
    """NumPy reference for an A8W4 quant matmul treating all rows as a single group.

    Each int8 activation is split into a high part (``x // 16``) and a low
    part (``(x & 0x0F) - 8``). Both parts are multiplied with the weight per
    quantization group along K, scaled by that group's ``scale`` row and
    accumulated in float16. The halves are recombined as ``high * 16 + low``,
    the bias is added, and the result is multiplied by ``perTokenScale``.

    Args:
        x: CPU tensor of shape (m, k).
        weight: tensor of shape (k, n); cast to int8 for the reference.
        scale: CPU tensor whose first dimension is the number of K-groups.
        perTokenScale: CPU tensor multiplied onto the (m, n) result.
        groupList: not used by the computation; kept for interface compatibility.
        bias: CPU tensor reshaped to (1, -1); row 0 is added to the result.
        output_dtype: torch dtype of the returned tensor.

    Returns:
        torch.Tensor of shape (m, n) in ``output_dtype``.
    """
    m, k = x.shape
    k, n = weight.shape
    x = x.numpy()
    weight_int8 = weight.cpu().numpy().astype(np.int8)
    scale = scale.numpy()
    perTokenScale = perTokenScale.numpy()
    bias = bias.numpy().reshape(1, -1)
    # Interleave high/low parts row-wise: row 2*i is the high part of token i, row 2*i+1 the low part.
    high = x.reshape(m, 1, k) // 16
    low = (x.reshape(m, 1, k) & 0x0F) - 8
    x = np.concatenate([high, low], axis=1).reshape(m * 2, k)
    quantGroupNum = scale.shape[0]
    group_k = k // quantGroupNum
    weightGroup = weight_int8.reshape(quantGroupNum, group_k, n).astype(np.int32)
    xi = x.reshape(-1, quantGroupNum, group_k).astype(np.int32)
    atomic = np.float16
    acc = np.zeros([xi.shape[0], n], dtype=atomic)
    for j in range(quantGroupNum):
        mm = np.matmul(xi[:, j, :], weightGroup[j])
        mm = mm.astype(np.float32) * scale[j, :].reshape(1, -1)
        # Accumulation is deliberately rounded to float16 after every group.
        acc = (acc + mm.astype(atomic)).astype(atomic)
    acc = acc.reshape(-1, 2, n).astype(np.float32)
    golden = acc[:, 0, :] * 16 + acc[:, 1, :] + bias[0].reshape(1, n)
    golden = golden * perTokenScale
    return torch.from_numpy(golden).to(output_dtype)
def get_eb(self, golden: torch.Tensor, actual: torch.Tensor):
    """Check that the mean relative error of ``actual`` against ``golden`` is small.

    The element-wise error ``actual - golden`` is divided by
    ``max(|golden|, 1)`` and averaged. The magnitude of that mean must not
    exceed ``2 ** -10`` when ``golden`` is float16, or ``2 ** -7`` for any
    other dtype.

    Args:
        golden: reference tensor; its dtype selects the threshold.
        actual: tensor under test, same shape as ``golden``.

    Returns:
        0-dim bool tensor, True when the mean relative error is within bounds.
    """
    # Capture the dtype before the float32 upcast; checking it afterwards
    # would make the float16 threshold unreachable.
    golden_dtype = golden.dtype
    golden = golden.to(torch.float32)
    golden_nmax = torch.clamp(torch.abs(golden), min=1)
    actual_error = actual.to(torch.float32) - golden
    EB = torch.mean(actual_error / golden_nmax)
    threshold = 2 ** (-10) if golden_dtype == torch.float16 else 2 ** (-7)
    # Compare the magnitude so that a large negative bias is rejected as well.
    return torch.abs(EB) <= threshold
def ref_compare(self, golden: torch.Tensor, actual: torch.Tensor, thresh: float):
    """Element-wise tolerance check of ``actual`` against ``golden``.

    Both tensors are compared in float32. Each element may deviate by at most
    ``thresh * max(|golden|, 1)``.

    Returns:
        0-dim bool tensor, True only if every element is within tolerance.
    """
    reference = golden.to(torch.float32)
    deviation = (actual.to(torch.float32) - reference).abs()
    # Relative tolerance for large references, absolute (thresh * 1) for small ones.
    tolerance = thresh * reference.abs().clamp(min=1)
    return torch.all(deviation <= tolerance)
def golden_compare(self, out_tensor, golden_out_tensor, ksize):
    """Combine the mean-error check and the element-wise check for one result.

    The element-wise tolerance depends on the dtype of ``out_tensor`` and on
    whether ``ksize`` is below 2048:
    float16 uses 2**-8 / 2**-7, every other dtype uses 2**-7 / 2**-6.

    Returns:
        The result of ``get_eb`` if it is falsy, otherwise the result of
        ``ref_compare``.
    """
    mean_error_ok = self.get_eb(golden_out_tensor, out_tensor)
    if out_tensor.dtype == torch.float16:
        tolerance = 2 ** -8 if ksize < 2048 else 2**-7
    else:
        tolerance = 2 ** -7 if ksize < 2048 else 2**-6
    elementwise_ok = self.ref_compare(golden_out_tensor, out_tensor, tolerance)
    return mean_error_ok and elementwise_ok
def trans_np_hifuint8_tensor_to_float32(self, in_tensor):
    """Decode an array of 8-bit codes into their float32 values.

    Judging by the method name the codes are HiFloat8 bytes -- confirm against
    the format specification. Each code is decoded as
    ``sign * 2 ** exponent * mantissa``, where bit 7 gives the sign and bits
    3..6 select how the remaining bits are divided between exponent and
    fraction. Codes 0 and 128 decode to 0, code 111 to 32768 and code 239 to
    -32768.

    The input array is not modified. Only the low 8 bits of every element are
    used, so an int8 view of the same bytes decodes identically.

    Args:
        in_tensor: NumPy integer array (any shape) holding the raw codes.

    Returns:
        A new float32 NumPy array with the same shape as ``in_tensor``.
    """

    def decode(code):
        # Decode a single code given as a plain Python int in [0, 255].
        if code == 0 or code == 128:
            return 0.0
        if code == 239:
            return -32768.0
        if code == 111:
            return 32768.0
        sign = -1.0 if code >= 128 else 1.0
        dot_4_value = (code & 120) >> 3
        if dot_4_value >= 12:
            exponent_int = (code & 30) >> 1
            exponent_value = -exponent_int if exponent_int >= 8 else exponent_int + 8
            m_value = 1.0 + (code & 1) * 0.5
        elif dot_4_value >= 8:
            exponent_int = (code & 28) >> 2
            exponent_value = -exponent_int if exponent_int >= 4 else exponent_int + 4
            m_value = 1.0 + (code & 3) * 0.25
        elif dot_4_value >= 4:
            exponent_int = (code & 24) >> 3
            exponent_value = -exponent_int if exponent_int >= 2 else exponent_int + 2
            m_value = 1.0 + (code & 7) * 0.125
        elif dot_4_value >= 2:
            exponent_value = -1 if (code & 8) >> 3 >= 1 else 1
            m_value = 1.0 + (code & 7) * 0.125
        elif dot_4_value == 1:
            exponent_value = 0
            m_value = 1.0 + (code & 7) * 0.125
        else:
            # dot_4_value == 0: no fraction bits, the low three bits pick the exponent.
            m_value = 1.0
            exponent_value = (code & 7) - 23
        # The sign multiplies the magnitude (it must not be added to it).
        return sign * pow(2.0, exponent_value) * m_value

    # Decode with Python ints (no unsigned wrap-around) into a 256-entry table,
    # then map the whole array in one indexing step instead of writing floats
    # back into the integer input.
    lookup = np.array([decode(code) for code in range(256)], dtype=np.float32)
    codes = np.asarray(in_tensor).astype(np.int64) & 0xFF
    return lookup[codes]
@SupportedDevices(['Ascend910B'])
def test_npu_quant_matmul_A8W4_float16(self):
    """Compare npu_quant_matmul with an int4-packed weight against the CPU golden (float16 output)."""
    out_dtype = torch.float16
    rows, depth, cols = 1, 8192, 128 * 8
    group = 256
    num_groups = depth // group

    # Host inputs; the random draws stay in the original order.
    activation = torch.randint(-128, 128, [rows, depth], dtype=torch.int8)
    weight = torch.randint(1, 8, (depth, cols), dtype=torch.int32)
    offset = torch.randn([cols, ], dtype=torch.float32)
    token_scale = torch.randn([rows, 1], dtype=torch.float32) * 0.01
    weight_scale = torch.randn([num_groups, cols], dtype=torch.float32).uniform_(0, 1)
    group_sizes = [0, 0, group]

    # Device inputs: pack the weight and convert the per-group scale.
    activation_npu = activation.npu()
    packed_weight = torch_npu.npu_convert_weight_to_int4pack(weight.npu())
    offset_npu = offset.npu()
    token_scale_npu = token_scale.npu()
    flat_scale = weight_scale.npu().reshape([num_groups * cols, ])
    scale_npu = torch_npu.npu_trans_quant_param(flat_scale).reshape([num_groups, cols])

    device_result = torch_npu.npu_quant_matmul(
        activation_npu, packed_weight, scale=scale_npu,
        offset=offset_npu, pertoken_scale=token_scale_npu,
        bias=None, output_dtype=out_dtype, group_sizes=group_sizes)
    expected = self.a8w4_quant_golden(
        x=activation, weight=weight, scale=weight_scale, perTokenScale=token_scale,
        groupList=group_sizes, bias=offset, output_dtype=out_dtype)
    self.assertTrue(self.golden_compare(device_result.cpu(), expected, depth))
@SupportedDevices(['Ascend910B'])
def test_npu_quant_matmul_A8W4_bfloat16(self):
    """Compare npu_quant_matmul with an int4-packed weight against the CPU golden (bfloat16 output)."""
    result_dtype = torch.bfloat16
    m_dim = 1
    k_dim = 8192
    n_dim = 128 * 8
    k_group = 256
    group_count = k_dim // k_group

    # CPU-side tensors, generated in the same order as before.
    lhs = torch.randint(-128, 128, [m_dim, k_dim], dtype=torch.int8)
    rhs = torch.randint(1, 8, (k_dim, n_dim), dtype=torch.int32)
    out_offset = torch.randn([n_dim, ], dtype=torch.float32)
    lhs_scale = torch.randn([m_dim, 1], dtype=torch.float32) * 0.01
    rhs_scale = torch.randn([group_count, n_dim], dtype=torch.float32).uniform_(0, 1)
    sizes = [0, 0, k_group]

    # NPU-side tensors.
    lhs_dev = lhs.npu()
    rhs_dev = torch_npu.npu_convert_weight_to_int4pack(rhs.npu())
    out_offset_dev = out_offset.npu()
    lhs_scale_dev = lhs_scale.npu()
    rhs_scale_flat = rhs_scale.npu().reshape([group_count * n_dim, ])
    rhs_scale_dev = torch_npu.npu_trans_quant_param(rhs_scale_flat).reshape([group_count, n_dim])

    actual = torch_npu.npu_quant_matmul(
        lhs_dev, rhs_dev, scale=rhs_scale_dev,
        offset=out_offset_dev, pertoken_scale=lhs_scale_dev,
        bias=None, output_dtype=result_dtype, group_sizes=sizes)
    reference = self.a8w4_quant_golden(
        x=lhs, weight=rhs, scale=rhs_scale, perTokenScale=lhs_scale,
        groupList=sizes, bias=out_offset, output_dtype=result_dtype)
    self.assertTrue(self.golden_compare(actual.cpu(), reference, k_dim))
# Script entry point: when the file is executed directly, hand control to the
# run_tests helper imported from torch_npu.testing.testcase.
if __name__ == "__main__":
    run_tests()