import torch
import numpy as np
import torch_npu

from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import get_npu_device, SupportedDevices


class DataInfo:
    """Parameters describing one randomly generated npu_group_quant test case.

    Attributes:
        min_d, max_d: lower/upper bounds passed to ``np.random.uniform`` when
            generating the input, scale and offset values.
        shape_x: shape of the input tensor ``x``.
        shape_scale: shape of the per-group ``scale`` tensor.
        shape_group_index: shape of ``group_index``; its first entry is the
            number of groups.
        dtype_x: NumPy dtype the input ``x`` is cast to.
        dtype_scale: NumPy dtype ``scale`` and ``offset`` are cast to.
    """

    def __init__(self, min_d, max_d, shape_x, shape_scale, shape_group_index, dtype_x, dtype_scale):
        # Value range for the random generator.
        self.min_d, self.max_d = min_d, max_d
        # Tensor shapes.
        self.shape_x, self.shape_scale, self.shape_group_index = shape_x, shape_scale, shape_group_index
        # Element types for x and for scale/offset.
        self.dtype_x, self.dtype_scale = dtype_x, dtype_scale


class TestNPUGroupQuant(TestCase):
    """Compare ``torch_npu.npu_group_quant`` on NPU against a NumPy reference."""

    def generate_data_npu_quantize(self, datainfo):
        """Build random CPU tensors for one npu_group_quant call.

        Args:
            datainfo: ``DataInfo`` giving the value range, shapes and dtypes.

        Returns:
            ``[x, scale, group_index, offset]`` as CPU tensors created with
            ``torch.from_numpy``. ``x`` has shape ``datainfo.shape_x``, ``scale``
            has shape ``datainfo.shape_scale``, ``offset`` has shape ``(1,)``, and
            ``group_index`` is a sorted int32 vector of length
            ``datainfo.shape_group_index[0]`` whose last entry equals the number
            of rows of ``x``.
        """
        input_x = np.random.uniform(datainfo.min_d, datainfo.max_d, datainfo.shape_x).astype(datainfo.dtype_x)
        scale = np.random.uniform(datainfo.min_d, datainfo.max_d, datainfo.shape_scale).astype(datainfo.dtype_scale)
        offset = np.random.uniform(datainfo.min_d, datainfo.max_d, (1,)).astype(datainfo.dtype_scale)

        num_rows = datainfo.shape_x[0]
        num_groups = datainfo.shape_group_index[0]
        # group_index holds the exclusive end row of each group: group g covers
        # rows [group_index[g - 1], group_index[g]) (starting at 0 for g == 0).
        # The interior boundaries are random and sorted. The last boundary is
        # pinned to num_rows so every row of x belongs to some group.
        boundaries = np.sort(np.random.uniform(0, num_rows, num_groups - 1).astype('int32'))
        # np.append ravels the Python int into a default-int (int64 on most
        # platforms) array, and the concatenation promotes the result to int64.
        # Cast back so the op really receives the int32 indices intended here.
        group_index = np.append(boundaries, num_rows).astype('int32')

        return [
            torch.from_numpy(input_x),
            torch.from_numpy(scale),
            torch.from_numpy(group_index),
            torch.from_numpy(offset),
        ]

    def cpu_op_exec_group_quant(self, input_x, input_scale, input_group_index, input_offset, dtype):
        """NumPy reference: ``int8(clip(round(x * scale[g] + offset), -128, 127))``.

        Row ``r`` of ``input_x`` is scaled by ``input_scale[g]``, where ``g`` is
        the group whose half-open row range ``[group_index[g - 1], group_index[g])``
        (start 0 for ``g == 0``) contains ``r``. Only the first
        ``E = input_scale.shape[0]`` groups are visited, and empty ranges are
        skipped.

        Args:
            input_x: array indexed as ``(rows, H)``.
            input_scale: array whose first axis indexes groups; ``input_scale[g]``
                is broadcast against the rows of its group.
            input_group_index: exclusive end row of each group.
            input_offset: value added after scaling (broadcast).
            dtype: unused. The reference always produces int8; the parameter is
                kept for symmetry with ``npu_op_exec_group_quant``.

        Returns:
            int8 array with ``H`` columns, holding the rows of all visited
            groups in order.
        """
        hidden = input_x.shape[1]
        num_groups = input_scale.shape[0]

        input_x = input_x.astype("float32")
        input_scale = input_scale.astype("float32")
        input_offset = input_offset.astype("float32")

        # Gather per-group slices and concatenate once, instead of re-copying
        # the growing result with one concatenate per group.
        pieces = []
        start = 0
        for group in range(num_groups):
            end = input_group_index[group]
            if start < end:
                pieces.append(input_x[start:end] * input_scale[group] + input_offset)
            # The next group always starts at this group's end boundary.
            start = end

        if pieces:
            y = np.concatenate(pieces, axis=0)
        else:
            y = np.empty(shape=(0, hidden), dtype='float32')

        # np.round rounds exact halves to the nearest even integer.
        y = np.round(y, 0)
        y = np.clip(y, -128, 127).astype("int8")
        return y

    def npu_op_exec_group_quant(self, input_x, input_scale, input_group_index, input_offset, dtype):
        """Run ``torch_npu.npu_group_quant`` on the NPU and return the result as a NumPy array.

        All inputs are moved to ``"npu"``. ``dtype`` is forwarded as ``dst_dtype``,
        and the output is copied back to the CPU.
        """
        input_x = input_x.to("npu")
        input_scale = input_scale.to("npu")
        input_group_index = input_group_index.to("npu")
        input_offset = input_offset.to("npu")
        output = torch_npu.npu_group_quant(input_x, input_scale, input_group_index, offset=input_offset, dst_dtype=dtype)
        return output.to("cpu").numpy()

    @SupportedDevices(['Ascend910B'])
    def test_npu_group_quant(self):
        """Quantize a (16, 128) float32 input split into 5 row groups and compare NPU vs CPU."""
        datainfo = DataInfo(-1, 1, (16, 128), (5, 128), (5,), np.float32, np.float32)
        x, scale, group_index, offset = self.generate_data_npu_quantize(datainfo)
        cpu_output = self.cpu_op_exec_group_quant(x.numpy(), scale.numpy(), group_index.numpy(), offset.numpy(), torch.qint8)
        npu_output = self.npu_op_exec_group_quant(x, scale, group_index, offset, torch.qint8)
        self.assertRtolEqual(cpu_output, npu_output)


if __name__ == "__main__":
    # Entry point when this file is executed directly: hand control to
    # run_tests from torch_npu.testing.testcase.
    run_tests()