import math
import unittest
import numpy as np
import torch

import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices


class TestNPUWeightQuantBatchMatmul(TestCase):
    """Accuracy tests for ``torch_npu.npu_weight_quant_batchmatmul``.

    Every test builds its inputs, computes a reference result with plain
    torch ops (``supported_op_exec``), runs the fused NPU operator
    (``custom_op_exec``) and compares both with ``assertRtolEqual``.
    """

    def supported_op_exec(self, x, weight, antiquant_scale, antiquant_offset=None):
        """Compute the reference result ``x @ ((weight + offset) * scale)``.

        Args:
            x: Activation tensor (left matmul operand).
            weight: Weight tensor; may be an integer or floating tensor.
            antiquant_scale: Scale multiplied into the weight; must broadcast
                against ``weight``.
            antiquant_offset: Optional offset added to ``weight`` before
                scaling. Skipped when ``None``.

        Returns:
            The matmul result, in ``x``'s dtype.
        """
        if antiquant_offset is not None:
            weight = weight + antiquant_offset
        dequantized = weight * antiquant_scale
        # torch.matmul needs both operands in the same dtype. A float32 weight
        # multiplied by a float16 scale is promoted to float32, which would
        # not match a float16 ``x``; for integer weights this cast is a no-op.
        return torch.matmul(x, dequantized.to(x.dtype))

    def custom_op_exec(self, x, weight, antiquant_scale, antiquant_offset=None, weight_dtype=None):
        """Run the fused NPU operator under test.

        Args:
            x: Activation tensor.
            weight: Quantized weight tensor.
            antiquant_scale: Scale passed through to the operator.
            antiquant_offset: Optional offset passed through to the operator.
            weight_dtype: Optional dtype hint forwarded as the ``weight_dtype``
                keyword (used by the hifloat8 test).

        Returns:
            Whatever ``torch_npu.npu_weight_quant_batchmatmul`` returns.
        """
        return torch_npu.npu_weight_quant_batchmatmul(
            x, weight, antiquant_scale, antiquant_offset, weight_dtype=weight_dtype)

    @SupportedDevices(['Ascend310P'])
    def test_npu_weight_quant_batchmatmul2(self, device="npu"):
        """Batched fp16 activations x int8 weights with scale and offset."""
        torch.manual_seed(0)
        x = torch.randn((4, 32, 1024, 128), dtype=torch.float16).npu()
        # torch.randn only supports floating/complex dtypes and raises for
        # int8, so integer weights are drawn with randint instead. The range
        # is kept small on purpose; NOTE(review): adjust if the operator
        # should be exercised over the full int8 range.
        weight = torch.randint(
            low=-8, high=8, size=(4, 32, 128, 1024), dtype=torch.int8).npu()
        antiquant_scale = torch.randn((1, 1024), dtype=torch.float16).npu()
        antiquant_offset = torch.randn((1, 1024), dtype=torch.float16).npu()

        # The fused op gets its own copies so both paths see identical,
        # independent inputs.
        x_clone = x.clone()
        weight_clone = weight.clone()
        antiquant_scale_clone = antiquant_scale.clone()
        antiquant_offset_clone = antiquant_offset.clone()

        supported_output = self.supported_op_exec(
            x, weight, antiquant_scale, antiquant_offset)
        custom_output = self.custom_op_exec(
            x_clone, weight_clone, antiquant_scale_clone, antiquant_offset_clone)

        self.assertRtolEqual(supported_output, custom_output, 0.001)

    @SupportedDevices(['Ascend950'])
    def test_npu_weight_quant_batchmatmul2_with_hifloat8(self, device="npu"):
        """fp16 activations x hifloat8 weights, scale only (no offset)."""
        torch.manual_seed(0)
        x = torch.randn((96, 320), dtype=torch.float16).npu()
        weight = torch.randn((320, 256), dtype=torch.float32).npu()
        antiquant_scale = torch.randn((1, 256), dtype=torch.float16).npu()
        weight_hif8 = torch_npu.npu_dtype_cast(weight, torch_npu.hifloat8)

        x_clone = x.clone()
        weight_hif8_clone = weight_hif8.clone()
        antiquant_scale_clone = antiquant_scale.clone()

        # NOTE(review): the reference uses the original float32 weight, not
        # the hifloat8-rounded one, so the comparison also absorbs the cast
        # error. Confirm the 0.001 tolerance is intended to cover that.
        supported_output = self.supported_op_exec(x, weight, antiquant_scale)
        custom_output = self.custom_op_exec(
            x_clone, weight_hif8_clone, antiquant_scale_clone, None, torch_npu.hifloat8)

        self.assertRtolEqual(supported_output, custom_output, 0.001)

    @SupportedDevices(['Ascend950'])
    def test_npu_weight_quant_batchmatmul2_with_A16W4_nz_perchannel(self, device="npu"):
        """fp16 activations x int4-packed weights with one scale per column."""
        torch.manual_seed(0)
        m = 1
        k = 128
        n = 256
        cpu_x = torch.randn((m, k), dtype=torch.float16)
        # randint's ``high`` is exclusive, so every weight element equals 3.
        cpu_weight = torch.randint(low=3, high=4, size=(k, n), dtype=torch.int32)
        cpu_antiquant_scale = torch.randn((1, n), dtype=torch.float16)

        npu_x = cpu_x.clone().npu()
        # 29 is presumably the FRACTAL_NZ format id (the test name says "nz")
        # -- TODO confirm against the torch_npu format table.
        npu_weight = torch_npu.npu_format_cast(
            cpu_weight.npu(), 29, customize_dtype=cpu_x.dtype)
        npu_weight = torch_npu.npu_convert_weight_to_int4pack(npu_weight)
        npu_antiquant_scale = cpu_antiquant_scale.clone().npu()

        # The reference is computed on CPU from the unpacked int32 weights.
        supported_output = self.supported_op_exec(cpu_x, cpu_weight, cpu_antiquant_scale)
        custom_output = self.custom_op_exec(npu_x, npu_weight, npu_antiquant_scale, None, None)

        self.assertRtolEqual(supported_output, custom_output, 0.001)


# Run the tests through torch_npu's run_tests helper when this file is
# executed directly as a script.
if __name__ == "__main__":
    run_tests()