import unittest
import torch
import numpy as np
import torch_npu

from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import get_npu_device, SupportedDevices


class DataInfo:
    """Parameters describing one npu_kronecker_quant test case.

    Holds the value range used to sample the inputs, the shapes of the
    activation ``x`` and the two transform matrices ``p1``/``p2``, and the
    torch dtype the sampled data is cast to. Values are stored as given.
    """

    def __init__(self, min_d, max_d, shape_x, shape_p1, shape_p2, dtype):
        # Lower / upper bound handed to np.random.uniform when sampling inputs.
        self.min_d, self.max_d = min_d, max_d
        # Shapes of x, p1 and p2 respectively.
        self.shape_x, self.shape_p1, self.shape_p2 = shape_x, shape_p1, shape_p2
        # Target dtype the float32 samples are converted to.
        self.dtype = dtype


class TestNPUFlatQuant(TestCase):
    """Compare torch_npu.npu_kronecker_quant against a PyTorch golden model.

    Every test is currently disabled with ``unittest.skip`` and restricted to
    Ascend910B through ``SupportedDevices``.

    NOTE(review): all cases use ``min_d == max_d == 1``, so every sampled
    input is an all-ones tensor.
    """

    # Largest value of a signed 4-bit integer. The golden model maps each
    # slice's absolute maximum onto this value (times clip_ratio).
    _INT4_MAX = 7

    def generate_data_npu_quantize(self, datainfo):
        """Sample x, p1 and p2 and place them on the NPU.

        Each operand is drawn with ``np.random.uniform(min_d, max_d, shape)``
        in float32 (x first, then p1, then p2), cast to ``datainfo.dtype``, and
        moved to device ``"npu"``.

        Returns:
            ``[x, p1, p2]`` as NPU tensors.
        """
        shapes = (datainfo.shape_x, datainfo.shape_p1, datainfo.shape_p2)
        return [
            torch.from_numpy(
                np.random.uniform(datainfo.min_d, datainfo.max_d, shape).astype(np.float32)
            ).to(dtype=datainfo.dtype).to("npu")
            for shape in shapes
        ]

    def golden_op_exec_kronecker_quant(self, input_x, input_p1, input_p2, clip_ratio, dst_type_max):
        """Reference (golden) implementation of the kronecker quantisation.

        Computes ``p1 @ (x @ p2)``, flattens the last two dims of each leading
        slice, and symmetrically quantises every slice to the int4 range with
        a per-slice scale.

        Args:
            input_x: 3-D tensor ``(K, M, N)``.
            input_p1: left transform; ``(M, M)`` in the tests of this class.
            input_p2: right transform; ``(N, N)`` in the tests of this class.
            clip_ratio: multiplier applied to the int4 maximum; ``None`` -> 1.0.
            dst_type_max: accepted to mirror the NPU op's arguments but ignored;
                the int4 maximum (7) is always used.

        Returns:
            ``(out, scale)`` as NumPy arrays. ``out`` is int8 with shape
            ``(K, M, N)``. ``scale`` has shape ``(K,)``, and
            ``out * scale`` approximates the transformed input.
        """
        if clip_ratio is None:
            clip_ratio = 1.0
        # NOTE(review): dst_type_max is deliberately unused. Revisit once tests
        # exercise non-zero values.
        batch, rows, cols = input_x.shape
        transformed = (input_p1 @ (input_x @ input_p2)).flatten(-2, -1)
        # Per-slice absolute maximum, kept as (K, 1) so it broadcasts over rows*cols.
        amax = torch.abs(transformed).max(dim=-1, keepdim=True)[0].to(torch.float)
        ratio = torch.ones_like(amax) * self._INT4_MAX * clip_ratio
        # The half -> int8 cast truncates toward zero rather than rounding.
        # This is kept as-is to preserve the existing golden behaviour.
        golden_out = (transformed.to(torch.float) * (ratio / amax)).to(torch.half).to(torch.int8)
        golden_out = golden_out.reshape(batch, rows, cols)
        golden_scale = (amax / ratio).reshape(batch)
        return golden_out.to("cpu").numpy(), golden_scale.to("cpu").numpy()

    def tensor_int32_to_int8(self, tensor_int32):
        """Unpack int4 values packed eight-per-int32 into an int8 tensor.

        Each 32-bit word is split into eight 4-bit nibbles, lowest nibble
        first, and each nibble is sign-extended from 4 bits (range -8..7).
        All leading dimensions are kept; the last dimension grows by 8x.
        """
        *lead_dims, last_dim = tensor_int32.shape
        words = tensor_int32.reshape(-1).view(torch.int32).cpu().numpy()
        masks = np.array([0xF << (i * 4) for i in range(8)], dtype=np.uint32)
        # int32 & uint32 promotes to int64, so negative words are masked
        # correctly and the shifts cannot overflow.
        nibbles = (words[:, None] & masks) >> np.arange(0, 32, 4)
        signed = np.where(nibbles & 0x8, nibbles - 16, nibbles)
        return torch.tensor(signed.astype(np.int32)).to(torch.int8).reshape(*lead_dims, last_dim * 8)

    def _unpack_npu_outputs(self, out, quant_scale):
        """Move NPU results to CPU as NumPy, unpacking the int32-packed int4 output."""
        return self.tensor_int32_to_int8(out.to("cpu")).numpy(), quant_scale.to("cpu").numpy()

    def npu_op_exec_kronecker_quant(self, input_x, input_p1, input_p2):
        """Run the NPU op with only the three tensor inputs."""
        out, quant_scale = torch_npu.npu_kronecker_quant(input_x, input_p1, input_p2)
        return self._unpack_npu_outputs(out, quant_scale)

    def npu_op_exec_kronecker_quant_ratio(self, input_x, input_p1, input_p2, clip_ratio):
        """Run the NPU op with clip_ratio passed positionally."""
        out, quant_scale = torch_npu.npu_kronecker_quant(input_x, input_p1, input_p2, clip_ratio)
        return self._unpack_npu_outputs(out, quant_scale)

    def npu_op_exec_kronecker_quant_dst_type_max(self, input_x, input_p1, input_p2, clip_ratio, dst_type_max):
        """Run the NPU op with clip_ratio and an explicit dst_type_max keyword."""
        out, quant_scale = torch_npu.npu_kronecker_quant(
            input_x, input_p1, input_p2, clip_ratio, dst_type_max=dst_type_max
        )
        return self._unpack_npu_outputs(out, quant_scale)

    def _check_against_golden(self, datainfo, golden_args, npu_exec, npu_args=()):
        """Generate inputs, run the golden and NPU paths, and compare both outputs.

        Args:
            datainfo: DataInfo describing the inputs to sample.
            golden_args: ``(clip_ratio, dst_type_max)`` for the golden model.
            npu_exec: one of the ``npu_op_exec_*`` helpers.
            npu_args: extra positional arguments passed to ``npu_exec``.
        """
        x, p1, p2 = self.generate_data_npu_quantize(datainfo)
        golden_out, golden_scale = self.golden_op_exec_kronecker_quant(x, p1, p2, *golden_args)
        npu_out, npu_scale = npu_exec(x, p1, p2, *npu_args)
        self.assertRtolEqual(golden_out, npu_out)
        self.assertRtolEqual(golden_scale, npu_scale)

    @unittest.skip("skip test_npu_kronecker_quant_float16 now")
    @SupportedDevices(['Ascend910B'])
    def test_npu_kronecker_quant_float16(self):
        datainfo = DataInfo(1, 1, (16, 7, 16), (7, 7), (16, 16), torch.float16)
        self._check_against_golden(datainfo, (None, None), self.npu_op_exec_kronecker_quant)

    @unittest.skip("skip test_npu_kronecker_quant_bfloat16 now")
    @SupportedDevices(['Ascend910B'])
    def test_npu_kronecker_quant_bfloat16(self):
        datainfo = DataInfo(1, 1, (16, 56, 16), (56, 56), (16, 16), torch.bfloat16)
        self._check_against_golden(datainfo, (None, None), self.npu_op_exec_kronecker_quant)

    @unittest.skip("skip test_npu_kronecker_quant_float16_ratio now")
    @SupportedDevices(['Ascend910B'])
    def test_npu_kronecker_quant_float16_ratio(self):
        datainfo = DataInfo(1, 1, (16, 8, 32), (8, 8), (32, 32), torch.float16)
        self._check_against_golden(
            datainfo, (0.9063, None), self.npu_op_exec_kronecker_quant_ratio, (0.9063,)
        )

    @unittest.skip("skip test_npu_kronecker_quant_bfloat16_ratio now")
    @SupportedDevices(['Ascend910B'])
    def test_npu_kronecker_quant_bfloat16_ratio(self):
        datainfo = DataInfo(1, 1, (16, 3, 64), (3, 3), (64, 64), torch.bfloat16)
        self._check_against_golden(
            datainfo, (0.7848, 0.0), self.npu_op_exec_kronecker_quant_ratio, (0.7848,)
        )

    @unittest.skip("skip test_npu_kronecker_quant_float16_dst_type_max now")
    @SupportedDevices(['Ascend910B'])
    def test_npu_kronecker_quant_float16_dst_type_max(self):
        datainfo = DataInfo(1, 1, (16, 8, 32), (8, 8), (32, 32), torch.float16)
        self._check_against_golden(
            datainfo, (1.0, 0.0), self.npu_op_exec_kronecker_quant_dst_type_max, (1.0, 0.0)
        )

    @unittest.skip("skip test_npu_kronecker_quant_bfloat16_dst_type_max now")
    @SupportedDevices(['Ascend910B'])
    def test_npu_kronecker_quant_bfloat16_dst_type_max(self):
        datainfo = DataInfo(1, 1, (16, 3, 64), (3, 3), (64, 64), torch.bfloat16)
        self._check_against_golden(
            datainfo, (1.0, 0.0), self.npu_op_exec_kronecker_quant_dst_type_max, (1.0, 0.0)
        )

if __name__ == "__main__":
    # When this file is executed directly, run its tests through the
    # torch_npu testing entry point imported above.
    run_tests()