import unittest
import numpy as np
import torch

import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import create_common_tensor, SupportedDevices
from ml_dtypes import int4


def unpack_int4(s32arr):
    """Unpack a tensor of int32 words into signed 4-bit values stored as int8.

    Every 32-bit word is split into eight 4-bit fields, lowest-order nibble
    first, and each field is sign-extended to the range [-8, 7].

    Args:
        s32arr: CPU tensor whose elements are packed 32-bit words. A 0-d
            tensor is treated as a single word.

    Returns:
        An int8 ``torch.Tensor`` with the same leading dimensions as the
        input and a last dimension eight times larger (shape ``(8,)`` for a
        0-d input).
    """
    packed = s32arr.numpy()
    src_shape = packed.shape
    if len(src_shape) == 0:
        dst_shape = (8,)
    else:
        dst_shape = (*src_shape[:-1], src_shape[-1] * 8)

    # Serialise explicitly as little-endian so that byte 0 always holds the
    # two lowest nibbles, independent of the host byte order.
    raw = np.frombuffer(packed.astype("<i4").tobytes(), dtype=np.uint8)

    # Column 0 is the low nibble of each byte, column 1 the high nibble.
    shift = np.array([0, 4], dtype=np.uint8)
    nibbles = (raw.reshape(-1, 1) >> shift) & 0x0F

    # Two's-complement sign extension of a 4-bit field: flip the sign bit,
    # then subtract its weight (maps 0..7 -> 0..7 and 8..15 -> -8..-1).
    signed = (nibbles ^ 0x08).astype(np.int8) - np.int8(8)
    return torch.from_numpy(signed.reshape(dst_shape))


class TestAntiQuant(TestCase):
    """Checks ``torch_npu.npu_anti_quant`` against a CPU reference."""

    def custom_op_exec(self, input_x, scale, offset, dst_dtype, src_dtype):
        """Compute the CPU reference result ``(x + offset) * scale``.

        Args:
            input_x: CPU input tensor. An int32 tensor is first expanded with
                ``unpack_int4`` into eight values per element.
            scale: CPU scale tensor, broadcast to the (unpacked) input shape.
            offset: CPU offset tensor, or ``None`` to use zeros.
            dst_dtype: torch dtype the result is cast to.
            src_dtype: Unused here; accepted so the signature matches
                ``npu_op_exec`` and both can be called with the same
                arguments.

        Returns:
            The detached CPU result tensor in ``dst_dtype``.
        """
        if input_x.dtype == torch.int32:
            input_x = unpack_int4(input_x)
        scale = torch.broadcast_to(scale, input_x.shape)
        if offset is None:
            offset = torch.zeros_like(scale)

        # Do the arithmetic in float32 and cast to the target dtype once.
        x = input_x.to(torch.float32)
        output = (x + offset) * scale

        output = output.to(dst_dtype)
        return output.cpu().detach()

    def npu_op_exec(self, input_x, scale, offset, dst_dtype, src_dtype):
        """Run ``torch_npu.npu_anti_quant`` and return its result on the CPU."""
        output = torch_npu.npu_anti_quant(input_x, scale, offset=offset, dst_dtype=dst_dtype, src_dtype=src_dtype)
        return output.cpu().detach()

    @SupportedDevices(['Ascend910B'])
    def test_anti_quant(self, device="npu"):
        """Compare the NPU operator with the CPU reference over several cases."""
        # Each entry: [input spec, scale spec, offset spec or None, dst_dtype,
        # src_dtype]. The specs are handed to create_common_tensor unchanged.
        # For int32 inputs the scale/offset carry 8x the last input dimension,
        # matching the 8 values the reference unpacks from every element.
        shape_format = [
            [[np.int8, -1, [10, 100]], [np.float32, -1, [100]], None, torch.float16, None],
            [[np.int8, -1, [10, 100]], [np.float32, -1, [100]], [np.float32, -1, [100]], torch.float16, None],
            [[np.int8, -1, [10, 100]], [np.float32, -1, [100]], None, torch.float16, torch.int8],
            [[np.int8, -1, [10, 100]], [np.float32, -1, [100]], None, torch.float16, None],
            [[np.int8, -1, [10, 100]], [np.float32, -1, [100]], None, torch.bfloat16, None],
            [[np.int8, -1, [10, 100]], [np.float32, -1, [100]], [np.float32, -1, [100]], torch.bfloat16, torch.int8],
            [[np.int32, -1, [10, 100]], [np.float32, -1, [800]], None, torch.float16, None],
            [[np.int32, -1, [10, 100]], [np.float32, -1, [800]], [np.float32, -1, [800]], torch.float16, None],
            [[np.int32, -1, [10, 100]], [np.float32, -1, [800]], None, torch.float16, torch.quint4x2],
            [[np.int32, -1, [10, 100]], [np.float32, -1, [800]], None, torch.bfloat16, torch.quint4x2],
        ]

        for item in shape_format:
            cpu_input_x, npu_input_x = create_common_tensor(item[0], -127, 127)
            cpu_scale, npu_scale = create_common_tensor(item[1], -100, 100)
            cpu_offset, npu_offset = (None, None) if item[2] is None else create_common_tensor(item[2], -100, 100)

            npu_output = self.npu_op_exec(npu_input_x, npu_scale, npu_offset, *item[3:])
            custom_output = self.custom_op_exec(cpu_input_x, cpu_scale, cpu_offset, *item[3:])

            # bfloat16 results are widened to float32 before the comparison
            # (presumably because the comparison helper cannot consume
            # bfloat16 directly -- TODO confirm).
            if item[3] == torch.bfloat16:
                npu_output = npu_output.to(torch.float32)
                custom_output = custom_output.to(torch.float32)
            self.assertRtolEqual(npu_output, custom_output)

    @SupportedDevices(['Ascend910B'])
    def test_anti_quant_invalid_dtype(self, device="npu"):
        """Expect a descriptive error when an invalid destination type is given."""
        npu_input_x = torch.tensor([1, 2, 3, 4], dtype=torch.int8).npu()
        npu_scale = torch.tensor([2, 0], dtype=torch.float32).npu()
        npu_offset = torch.tensor([2], dtype=torch.int32).npu()

        invalid_dst_type = -10
        expected_error = "Input dst_type must be valid, but got UNKNOWN_SCALAR"
        # NOTE(review): this call passes `dst_type`, whereas npu_op_exec above
        # passes `dst_dtype` -- confirm which keyword the operator accepts.
        try:
            torch_npu.npu_anti_quant(
                npu_input_x,
                npu_scale,
                offset=npu_offset,
                dst_type=invalid_dst_type,
                src_dtype=torch.int8
            )
        except Exception as e:
            self.assertIn(expected_error, str(e), f"err message mismatch, exception:{e}")
        else:
            # Kept outside the try body so this failure is not swallowed by
            # the broad handler above and misreported as a message mismatch.
            self.fail("Do not throw the expected exception")


# Hand control to the torch_npu test runner when this file is run as a script.
if __name__ == "__main__":
    run_tests()