import unittest
import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.contrib.module import LinearWeightQuant
from torch_npu.testing.common_utils import SupportedDevices
# First 10 characters of the name reported for NPU device 0; compared against
# 'Ascend910B' to decide whether device-specific tests are skipped.
DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10]
def f32_2_s9(array):
    """Round ``array`` to the nearest integer values and clamp to [-256, 255].

    The bounds are the signed 9-bit integer range. The result keeps whatever
    type ``np.round`` / ``np.clip`` produce for the input; no integer cast is
    performed here.
    """
    lower_bound, upper_bound = -256, 255
    return np.clip(np.round(array), lower_bound, upper_bound)
class TestLinearWeightQuant(TestCase):
    """Tests comparing ``LinearWeightQuant`` output against reference results."""

    def cpu_linear_weight_quant(self, weight_cpu, x_cpu, antiquant_scale_cpu, antiquant_offset_cpu):
        """Compute the CPU reference ``x @ ((weight + offset) * scale)``.

        Args:
            weight_cpu: Weight tensor, already laid out so that
                ``torch.matmul(x_cpu, weight_cpu)`` is valid.
            x_cpu: Activation tensor; converted to float32 before the matmul.
            antiquant_scale_cpu: Scale multiplied onto the (offset) weight.
            antiquant_offset_cpu: Offset added to the weight, or ``None`` to
                skip the addition.

        Returns:
            The matmul result as a float16 numpy array.
        """
        x_cpu = x_cpu.to(torch.float32)
        if antiquant_offset_cpu is not None:
            weight_cpu = weight_cpu + antiquant_offset_cpu
        antiquant_weight = weight_cpu * antiquant_scale_cpu
        antiquant_weight = antiquant_weight.to(torch.float32)
        cpu_out = torch.matmul(x_cpu, antiquant_weight).numpy()
        cpu_out = cpu_out.astype("float16")
        return cpu_out

    def npu_linear_weight_quant(self, in_features, out_features, antiquant_scale, weight, x, antiquant_offset=None,
                                weight_dtype=None):
        """Build a ``LinearWeightQuant`` module on ``npu:0``, load the given tensors and run it on ``x``.

        Args:
            in_features: Forwarded to ``LinearWeightQuant``.
            out_features: Forwarded to ``LinearWeightQuant``.
            antiquant_scale: Tensor assigned to ``model.antiquant_scale.data``.
            weight: Tensor assigned to ``model.weight.data``.
            x: Input tensor; its dtype is also used as the module dtype.
            antiquant_offset: Optional tensor assigned to
                ``model.antiquant_offset.data``. When ``None`` the module is
                built without an antiquant offset and nothing is assigned.
            weight_dtype: Forwarded to ``LinearWeightQuant``.

        Returns:
            The module output for ``x``.
        """
        # Only request an offset parameter when one is actually supplied;
        # assigning ``None`` to a parameter's ``.data`` is not valid.
        has_offset = antiquant_offset is not None
        model = LinearWeightQuant(in_features,
                                  out_features,
                                  bias=False,
                                  device=torch.device('npu:0'),
                                  dtype=x.dtype,
                                  antiquant_offset=has_offset,
                                  quant_scale=False,
                                  quant_offset=False,
                                  antiquant_group_size=0,
                                  weight_dtype=weight_dtype
                                  )
        model = model.npu()
        model.weight.data = weight
        model.antiquant_scale.data = antiquant_scale
        if has_offset:
            model.antiquant_offset.data = antiquant_offset
        npu_out = model(x)
        return npu_out

    @unittest.skipIf(DEVICE_NAME != 'Ascend910B',
                     "OP `WeightQuantBatchMatmulV2` is only supported on 910B, skip this ut for this device type!")
    def test_npu_linear_weight_quant(self):
        """Compare the module output with the CPU reference for int8 weights with scale and offset."""
        m = 1024
        k = 11264
        n = 1664
        x_cpu = torch.randn((m, k), dtype=torch.float16)
        weight_cpu = torch.randn((n, k), dtype=torch.float16)
        # Round-trip through int8 so the CPU reference uses exactly the
        # integer values that are handed to the NPU module.
        weight_cpu = weight_cpu.to(torch.int8)
        weight_cpu = weight_cpu.to(torch.float16)
        weight_cpu_trans = weight_cpu.transpose(0, 1)
        antiquant_scale_cpu = torch.randn((n), dtype=torch.float16)
        antiquant_offset_cpu = torch.randn((n), dtype=torch.float16)
        x_npu = x_cpu.npu()
        weight_npu = weight_cpu.to(torch.int8).npu()
        antiquant_scale_npu = antiquant_scale_cpu.npu()
        antiquant_offset_npu = antiquant_offset_cpu.npu()
        npu_out = self.npu_linear_weight_quant(k, n, antiquant_scale_npu, weight_npu, x_npu, antiquant_offset_npu)
        cpu_out = self.cpu_linear_weight_quant(weight_cpu_trans, x_cpu, antiquant_scale_cpu, antiquant_offset_cpu)
        npu_out = npu_out.cpu()
        self.assertRtolEqual(cpu_out, npu_out.numpy(), 0.01)

    @SupportedDevices(['Ascend950'])
    def test_npu_linear_weight_quant_weight_dtype_hif8(self):
        """Compare the module output with a direct ``npu_weight_quant_batchmatmul`` call for hifloat8 weights."""
        m = 2
        k = 64
        n = 128
        x_npu = torch.randn((m, k), dtype=torch.float16).npu()
        # Draw raw byte values in [0, 255) as uint8 and reinterpret them as
        # int8: requesting that range directly with dtype=int8 exceeds the
        # int8 value range.
        # NOTE(review): this weight is (k, n) while the int8 test passes an
        # (n, k) weight to the same helper - confirm the expected layout.
        weight_npu = torch.randint(0, 255, (k, n), dtype=torch.uint8).view(torch.int8).npu()
        antiquant_scale_npu = torch.randn((n), dtype=torch.float16).npu()
        npu_out = self.npu_linear_weight_quant(k, n, antiquant_scale_npu, weight_npu, x_npu, None,
                                               weight_dtype=torch_npu.hifloat8)
        supported_output = torch_npu.npu_weight_quant_batchmatmul(x_npu, weight_npu, antiquant_scale_npu, None, None,
                                                                  None, None, 0, weight_dtype=torch_npu.hifloat8)
        npu_out = npu_out.cpu()
        # Bring the reference to host memory as numpy so both operands of the
        # comparison have the same type.
        self.assertRtolEqual(supported_output.cpu().numpy(), npu_out.numpy(), 0.01)
# Run the tests through torch_npu's run_tests when executed as a script.
if __name__ == "__main__":
    run_tests()