import math
import unittest
import copy
import struct
from struct import pack, unpack
import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
from torch.testing import assert_close
class TestDynamicQuant(TestCase):
    """Compare ``torch_npu.npu_dynamic_quant`` against Python reference implementations."""

    # Shape shared by every test case below.
    INPUT_SHAPE = (2, 32, 256)

    def supported_op_exec(self, input_tensor, smooth_scales=None, group_index=None, dst_type=torch.int8):
        """Reference per-row dynamic quantization along the last dimension.

        The input is optionally multiplied by ``smooth_scales``. Each row of the
        last dimension is then divided by ``abs_max / 127`` (int8) or
        ``abs_max / 7`` (any other ``dst_type``) and rounded.

        Args:
            input_tensor: tensor to quantize. Computation happens on a float32 copy,
                so the caller's tensor is never modified.
            smooth_scales: optional smoothing factors. When ``group_index`` is given,
                ``smooth_scales[g]`` is applied to the rows of group ``g``. Otherwise
                the tensor is multiplied directly (broadcast) against the input.
            group_index: optional 1-D tensor of cumulative row end offsets, one per
                group, over the input viewed as ``(rows, last_dim)``. Groups whose
                end does not exceed the previous end are skipped.
            dst_type: ``torch.int8`` produces int8 output. Any other value uses
                the +/-7 range and produces int32 output.

        Returns:
            ``[quantized, scale]``. ``scale`` is float32 with the last dimension
            squeezed away.
        """
        # Work on a private float32 copy. The grouped path writes in place, and
        # ``.float()`` alone would return the caller's tensor when it is already
        # float32.
        input_tensor = input_tensor.to(torch.float32, copy=True)
        if group_index is not None:
            rows = input_tensor.view(-1, input_tensor.shape[-1])
            start = 0
            # Bring the offsets to the host once instead of syncing on every element.
            for index, end in enumerate(group_index.tolist()):
                if end > start:
                    rows[start:end] = rows[start:end] * smooth_scales[index]
                start = end
            # ``rows`` is a view, so the writes above already updated ``input_tensor``.
        elif smooth_scales is not None:
            input_tensor = input_tensor * smooth_scales

        input_max = input_tensor.abs().max(dim=-1, keepdim=True).values
        quant_max = 127 if dst_type == torch.int8 else 7
        scale = input_max / quant_max
        # Tensor division does not raise on zero. An all-zero row gives scale 0
        # and NaN quotients rather than an exception.
        quantized = (input_tensor / scale).round()
        out_dtype = torch.int8 if dst_type == torch.int8 else torch.int32
        return [quantized.to(out_dtype), scale.squeeze(-1).float()]

    def custom_op_exec(self, input_tensor, smooth_scales=None, group_index=None, dst_type=torch.int8):
        """Run the NPU kernel under test with the given options."""
        return torch_npu.npu_dynamic_quant(input_tensor, smooth_scales=smooth_scales, group_index=group_index, dst_type=dst_type)

    def dynamic_quant_perchannel_int8(self, input_tensor, smooth_scales=None):
        """Reference per-channel int8 quantization, reducing over dim -2.

        Unlike ``supported_op_exec``, the absolute maximum is taken over dim -2.
        This yields one scale per position of the last dimension.

        Args:
            input_tensor: tensor to quantize. It is cast to float32.
            smooth_scales: optional factors, cast to float32 and multiplied
                (broadcast) against the input before quantizing.

        Returns:
            ``[int8 quantized tensor, scale]``, with dim -2 squeezed out of ``scale``.
        """
        scaled = input_tensor.to(torch.float32)
        if smooth_scales is not None:
            scaled = scaled * smooth_scales.to(torch.float32)
        input_max = scaled.abs().max(dim=-2, keepdim=True).values
        # Multiply by the reciprocal; the result is compared against the kernel
        # with a tolerance.
        scale = input_max * (1.0 / 127.0)
        quantized = torch.round(scaled / scale).to(torch.int8)
        return [quantized, scale.squeeze(-2)]

    def dynamic_quant_perchannel_impl(self, input_tensor, smooth_scales=None, group_index=None, dst_type=torch.int8, quant_mode="perchannel"):
        """Run the NPU kernel under test with an explicit ``quant_mode``."""
        return torch_npu.npu_dynamic_quant(input_tensor, smooth_scales=smooth_scales, group_index=group_index, dst_type=dst_type, quant_mode=quant_mode)

    def generate_input(self, input_shape, dtype="float16", use_smooth=False, group_num=1):
        """Build random test inputs.

        Args:
            input_shape: shape of the tensor to quantize.
            dtype: ``"float16"``; any other value selects bfloat16.
            use_smooth: when False, no smooth scales or group index are produced.
            group_num: when greater than 1 (and ``use_smooth`` is set), produce one
                smooth row per group plus a sorted int32 ``group_index`` of
                cumulative row ends. The last end always equals the total row count.

        Returns:
            ``(input_tensor, smooth_scales or None, group_index or None)``.
        """
        data_type = torch.float16 if dtype == "float16" else torch.bfloat16
        input_tensor = torch.randn(input_shape, dtype=data_type)
        smooth_scales = None
        group_index = None
        if not use_smooth:
            return input_tensor, smooth_scales, group_index
        if group_num > 1:
            smooth_scales = torch.randn(group_num, input_shape[-1], dtype=data_type)
            row_num = input_tensor.numel() // input_tensor.shape[-1]
            ends = sorted(np.random.randint(0, row_num, size=group_num).tolist())
            # The final group must cover the remaining rows.
            ends[-1] = row_num
            group_index = torch.tensor(ends, dtype=torch.int32)
        else:
            smooth_scales = torch.randn(input_shape[-1], dtype=data_type)
        return input_tensor, smooth_scales, group_index

    def generate_input_perchannel(self, input_shape, dtype="float16"):
        """Return a random input and a ``(input_shape[-2], 1)`` smooth-scale column."""
        data_type = torch.float16 if dtype == "float16" else torch.bfloat16
        input_tensor = torch.randn(input_shape, dtype=data_type)
        smooth_scales = torch.randn((input_shape[-2], 1), dtype=data_type)
        return input_tensor, smooth_scales

    def convert_int4_to_int8(self, x):
        """Unpack two int4 values per byte into sign-extended int8 values.

        The bytes of ``x`` are reinterpreted as uint8. For each byte the output
        row holds ``[low_nibble, high_nibble]``. Returns a contiguous
        ``(num_bytes, 2)`` int8 tensor, which callers reshape.
        """
        packed = x.view(torch.uint8).view(-1, 1)
        # Move each nibble into the top 4 bits and reinterpret as int8. The
        # arithmetic right shift then sign-extends the int4 value.
        high = (packed & 0xF0).view(torch.int8) >> 4
        low = ((packed & 0x0F) << 4).view(torch.int8) >> 4
        return torch.cat([low, high], dim=-1).contiguous()

    @staticmethod
    def _maybe_to(tensor, device):
        """Move ``tensor`` to ``device``, passing ``None`` through unchanged."""
        return None if tensor is None else tensor.to(device)

    @staticmethod
    def _maybe_clone(tensor):
        """Clone ``tensor``, passing ``None`` through unchanged."""
        return None if tensor is None else tensor.clone()

    def _check_dynamic_quant(self, dtype, use_smooth, group_num, dst_type, device):
        """Compare the kernel against ``supported_op_exec`` for one configuration."""
        input_tensor, smooth_scales, group_index = self.generate_input(
            input_shape=self.INPUT_SHAPE, dtype=dtype, use_smooth=use_smooth, group_num=group_num)
        input_tensor = input_tensor.to(device)
        smooth_scales = self._maybe_to(smooth_scales, device)
        group_index = self._maybe_to(group_index, device)

        supported_output = self.supported_op_exec(
            input_tensor.clone(), self._maybe_clone(smooth_scales), self._maybe_clone(group_index), dst_type=dst_type)
        custom_output = self.custom_op_exec(
            input_tensor.clone(), self._maybe_clone(smooth_scales), self._maybe_clone(group_index), dst_type)

        quantized = custom_output[0]
        if dst_type == torch.quint4x2:
            # Unpack the kernel's int4 output and widen it to match the reference's int32.
            quantized = self.convert_int4_to_int8(quantized).view(self.INPUT_SHAPE).to(torch.int32)
        assert_close(supported_output[0], quantized, atol=1.0, rtol=0.0)
        assert_close(supported_output[1], custom_output[1], atol=0.0, rtol=0.0001)

    @SupportedDevices(['Ascend910B'])
    def test_npu_dynamic_quant_fp16_input(self, device="npu"):
        self._check_dynamic_quant("float16", use_smooth=False, group_num=1, dst_type=torch.int8, device=device)

    @SupportedDevices(['Ascend910B'])
    def test_npu_dynamic_quant_bf16_input(self, device="npu"):
        self._check_dynamic_quant("bfloat16", use_smooth=False, group_num=1, dst_type=torch.int8, device=device)

    @SupportedDevices(['Ascend910B'])
    def test_npu_dynamic_quant_fp16_input_smooth_group(self, device="npu"):
        self._check_dynamic_quant("float16", use_smooth=True, group_num=64, dst_type=torch.int8, device=device)

    @SupportedDevices(['Ascend910B'])
    def test_npu_dynamic_quant_bfp16_input_smooth_group(self, device="npu"):
        self._check_dynamic_quant("bfloat16", use_smooth=True, group_num=64, dst_type=torch.int8, device=device)

    @SupportedDevices(['Ascend910B'])
    def test_npu_dynamic_quant_int4_fp16_input(self, device="npu"):
        self._check_dynamic_quant("float16", use_smooth=False, group_num=1, dst_type=torch.quint4x2, device=device)

    @SupportedDevices(['Ascend910B'])
    def test_npu_dynamic_quant_int4_bf16_input(self, device="npu"):
        self._check_dynamic_quant("bfloat16", use_smooth=False, group_num=1, dst_type=torch.quint4x2, device=device)

    @SupportedDevices(['Ascend910B'])
    def test_npu_dynamic_quant_int4_fp16_input_smooth_group(self, device="npu"):
        self._check_dynamic_quant("float16", use_smooth=True, group_num=64, dst_type=torch.quint4x2, device=device)

    @SupportedDevices(['Ascend910B'])
    def test_npu_dynamic_quant_int4_bf16_input_smooth_group(self, device="npu"):
        self._check_dynamic_quant("bfloat16", use_smooth=True, group_num=64, dst_type=torch.quint4x2, device=device)

    @SupportedDevices(['Ascend950'])
    def test_npu_dynamic_quant_int8_bf16_input_smooth_perchannel(self, device="npu"):
        input_tensor, smooth_scales = self.generate_input_perchannel(input_shape=self.INPUT_SHAPE, dtype="bfloat16")
        input_tensor, smooth_scales = input_tensor.to(device), smooth_scales.to(device)
        # The reference broadcasts the (rows, 1) column; the kernel call is given
        # the 1-D form. Squeeze only the trailing axis so a single row stays 1-D.
        kernel_smooth_scales = smooth_scales.squeeze(-1)
        supported_output = self.dynamic_quant_perchannel_int8(input_tensor.clone(), smooth_scales.clone())
        custom_output = self.dynamic_quant_perchannel_impl(
            input_tensor.clone(), kernel_smooth_scales.clone(), None, torch.int8, "perchannel")
        y = custom_output[0].view(self.INPUT_SHAPE)
        assert_close(supported_output[0], y, atol=0.1, rtol=0.01)
        assert_close(supported_output[1], custom_output[1], atol=0.01, rtol=0.001)
if __name__ == "__main__":
    # When executed directly as a script, hand control to the torch_npu test runner.
    run_tests()