import math
import unittest
import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
class TestNPUConvertWeightToINT4Pack(TestCase):
    """Tests for ``torch_npu.npu_convert_weight_to_int4pack``.

    Each test packs weights on the NPU and checks the result either directly
    against a CPU packing reference or indirectly through
    ``torch_npu.npu_weight_quant_batchmatmul`` against a CPU float32 matmul.
    """

    def supported_op_exec(self, x, weight, antiquant_scale, antiquant_offset):
        """CPU reference for the weight-dequantized matmul.

        Computes ``x @ ((weight + antiquant_offset) * antiquant_scale)``. The
        offset term is skipped when ``antiquant_offset`` is None. When an offset
        is given, ``weight`` is first cast to float16.

        The dequantized weight is formed in the promoted dtype of its operands
        (float16 in the int4 test) and only then cast to float32. ``x`` is
        cast to float32 up front.

        Returns:
            float32 tensor holding the matmul result.
        """
        x = x.to(torch.float32)
        if antiquant_offset is not None:
            weight = weight.to(torch.float16) + antiquant_offset
        return torch.matmul(x, (weight * antiquant_scale).to(torch.float32))

    def custom_op_exec(self, x, weight, antiquant_scale, antiquant_offset, antiquant_group_size=0):
        """Run ``torch_npu.npu_weight_quant_batchmatmul`` with the given arguments.

        ``antiquant_group_size`` is forwarded as a keyword argument; the
        default of 0 is what the per-channel test relies on.
        """
        return torch_npu.npu_weight_quant_batchmatmul(x, weight, antiquant_scale, antiquant_offset,
                                                      antiquant_group_size=antiquant_group_size)

    def pack_int4_3dim(self, weight_unpack, g, k, n):
        """CPU reference: pack 4-bit values along the last axis, 8 per int32 word.

        Element ``8 * j + i`` of each ``(gid, kid)`` row is stored in bits
        ``[4 * i, 4 * i + 4)`` of output word ``j``, low nibble first. Only the
        low 4 bits of each input are kept, so negative values are stored in
        two's-complement form (e.g. -8 -> 0b1000). Columns beyond
        ``(n // 8) * 8`` are ignored.

        Args:
            weight_unpack: integer torch tensor or numpy array with at least
                ``(g, k, n)`` elements along its three axes.
            g, k, n: the extents to pack.

        Returns:
            ``torch.int32`` tensor of shape ``(g, k, n // 8)``.
        """
        if isinstance(weight_unpack, torch.Tensor):
            weight_unpack = weight_unpack.detach().cpu().numpy()
        words = n // 8
        # Work in int64 so the shifted nibbles never overflow; the final sum
        # fits in 32 unsigned bits because the nibbles occupy disjoint bits.
        values = np.asarray(weight_unpack, dtype=np.int64)[:g, :k, :words * 8]
        nibbles = (values & 0x0F).reshape(g, k, words, 8)
        shifts = np.arange(8, dtype=np.int64) * 4
        packed = (nibbles << shifts).sum(axis=-1)
        # Reinterpret the unsigned 32-bit pattern as int32 (the packed dtype).
        packed = packed.astype(np.uint32).view(np.int32)
        return torch.from_numpy(packed).to(torch.int32)

    @SupportedDevices(['Ascend910B', 'Ascend950'])
    def test_npu_convert_weight_to_int4pack(self, device="npu"):
        """Pack int32 weights in [-8, 7] to int4 on the NPU.

        The packed weight is then fed to ``npu_weight_quant_batchmatmul`` with
        a per-column scale/offset, and the result is compared with the CPU
        reference.
        """
        torch.manual_seed(0)
        m = 128
        k = 64
        n = 32
        trans_weight = False
        cpu_x = torch.randn((m, k), dtype=torch.float16)
        if trans_weight:
            cpu_weight = torch.randint(low=-8, high=8, size=(n, k), dtype=torch.int32)
            cpu_antiquantscale = torch.randn((n, 1), dtype=torch.float16)
            cpu_antiquantoffset = torch.randn((n, 1), dtype=torch.float16)
        else:
            cpu_weight = torch.randint(low=-8, high=8, size=(k, n), dtype=torch.int32)
            cpu_antiquantscale = torch.randn((1, n), dtype=torch.float16)
            cpu_antiquantoffset = torch.randn((1, n), dtype=torch.float16)
        weight_int4 = torch_npu.npu_convert_weight_to_int4pack(cpu_weight.npu())
        if trans_weight:
            # Present the (n, k) storage as a (k, n) view for the matmul.
            cpu_weight = cpu_weight.transpose(-1, -2)
            weight_int4 = weight_int4.transpose(-1, -2)
            cpu_antiquantscale = cpu_antiquantscale.transpose(-1, -2)
            cpu_antiquantoffset = cpu_antiquantoffset.transpose(-1, -2)
        supported_output = self.supported_op_exec(
            cpu_x, cpu_weight, cpu_antiquantscale, cpu_antiquantoffset)
        custom_output = self.custom_op_exec(
            cpu_x.npu(), weight_int4.npu(), cpu_antiquantscale.npu(), cpu_antiquantoffset.npu())
        self.assertRtolEqual(supported_output.to(torch.float16), custom_output, 0.001)

    # NOTE: previously this method reused the name test_npu_convert_weight_to_int4pack,
    # which silently replaced the int4 test above so it never ran.
    @SupportedDevices(['Ascend950'])
    def test_npu_convert_weight_to_int4pack_fp4_per_group(self, device="npu"):
        """Pack float32 weights on the NPU and check a per-group-scaled matmul.

        The weights are small integers in [-3, 2]. The output is compared with
        a CPU float32 reference.
        """
        torch.manual_seed(0)
        m = 128
        k = 64
        n = 64
        group_size = 32
        trans_weight = False
        cpu_x = torch.randn((m, k), dtype=torch.float16)
        if trans_weight:
            cpu_weight = torch.randint(low=-3, high=3, size=(n, k), dtype=torch.float32)
            cpu_antiquantscale = torch.randint(low=124, high=130, size=(n, k // group_size), dtype=torch.uint8)
        else:
            cpu_weight = torch.randint(low=-3, high=3, size=(k, n), dtype=torch.float32)
            cpu_antiquantscale = torch.randint(low=124, high=130, size=(k // group_size, n), dtype=torch.uint8)
        weight_fp4 = torch_npu.npu_convert_weight_to_int4pack(cpu_weight.npu())
        if trans_weight:
            cpu_weight = cpu_weight.transpose(-1, -2)
            weight_fp4 = weight_fp4.transpose(-1, -2)
            cpu_antiquantscale = cpu_antiquantscale.transpose(-1, -2)
        # Build the CPU scale only after the optional transpose, so it is always
        # (k // group_size, n) here and dim 0 is the grouped k axis. Each uint8
        # code decodes to 2 ** (code - 127) (presumably an exponent-only scale
        # format -- confirm against the op spec).
        cpu_antiquantscale_ref = 2 ** (cpu_antiquantscale.to(torch.float64) - 127)
        cpu_antiquantscale_ref = torch.repeat_interleave(cpu_antiquantscale_ref, group_size, dim=0)
        supported_output = self.supported_op_exec(
            cpu_x, cpu_weight, cpu_antiquantscale_ref.to(torch.float32), None)
        custom_output = self.custom_op_exec(
            cpu_x.npu(), weight_fp4.npu(), cpu_antiquantscale.npu(), None, antiquant_group_size=group_size)
        self.assertRtolEqual(supported_output.to(torch.float16), custom_output, 0.001)

    @SupportedDevices(['Ascend950'])
    def test_npu_convert_weight_to_int4pack_3dim_nd(self, device="npu"):
        """Check NPU int4 packing of a 3-D tensor bit-for-bit against the CPU packer."""
        torch.manual_seed(0)
        g = 2
        k = 8
        n = 16
        out_shape = (g, k, n // 8)
        # torch.randint's upper bound is exclusive: 8 covers the full int4 range [-8, 7].
        weight_unpk = torch.randint(-8, 8, (g, k, n), dtype=torch.int32)
        weight_npu_packed = torch_npu.npu_convert_weight_to_int4pack(weight_unpk.npu())
        weight_cpu_packed = self.pack_int4_3dim(weight_unpk, g, k, n)
        self.assertEqual(tuple(weight_npu_packed.shape), out_shape)
        # Packed words are bit patterns, so they must match exactly; a relative
        # tolerance would hide wrong low-order nibbles in large words.
        self.assertTrue(
            torch.equal(weight_npu_packed.cpu().to(torch.int32), weight_cpu_packed),
            "NPU int4 packing differs from the CPU reference")
if __name__ == "__main__":
    # Executed directly as a script: hand control to the torch_npu test runner.
    run_tests()