import math
import unittest
import copy
import struct
from struct import pack, unpack
import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
from torch.testing import assert_close
class TestGroupedDynamicBlockQuant(TestCase):
def custom_op_exec(self, input_tensor, group_list_tensor, min_scale=0.0, round_mode="rint", dst_type=291, row_block_size=1, col_block_size=128, group_list_type=0):
return torch_npu.npu_grouped_dynamic_block_quant(input_tensor,
group_list_tensor,
min_scale=min_scale,
round_mode=round_mode,
dst_type=dst_type,
row_block_size=row_block_size,
col_block_size=col_block_size,
group_list_type=group_list_type)
def supported_op_exec(self, input_tensor):
if torch.all(torch.eq(input_tensor, 0.0)) and input_tensor.shape == torch.Size([1, 2]):
device = input_tensor.device
y = torch.tensor([[0, 0]], dtype=torch.float8_e5m2, device=device)
scale = torch.tensor([[0.0], [0.0]], dtype=torch.float, device=device)
return y, scale
def generate_input(self, input, group_list, input_dtype="float16"):
input_data_type = torch.float16 if input_dtype == "float16" else torch.bfloat16
input_value = 0.0
input_tensor = torch.full(input, input_value, dtype=input_data_type)
group_list_data_type = torch.int32
group_list_value = 1
group_list_tensor = torch.full(group_list, group_list_value, dtype=group_list_data_type)
return input_tensor, group_list_tensor
@SupportedDevices(['Ascend950'])
def test_npu_grouped_dynamic_block_quant(self, device="npu"):
input_tensor, group_list_tensor = self.generate_input(input=[1, 2], group_list=[1], input_dtype="float16")
input_tensor = input_tensor.to(device)
group_list_tensor = group_list_tensor.to(device)
supported_output = self.supported_op_exec(input_tensor.clone())
custom_output = self.custom_op_exec(input_tensor.clone(), group_list_tensor.clone(), 0.0, "rint", 291, 1, 128, 0)
y = custom_output[0].view([1, 2]).view(torch.uint8)
scale = custom_output[1].view([2, 1])
assert torch.all(y == supported_output[0].view(torch.uint8))
assert_close(supported_output[1], scale, atol=0.01, rtol=0.001)
if __name__ == "__main__":
run_tests()