import os
import shutil
import unittest

import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices


class TestNPUSwigluQuantV2(TestCase):
    def golden_swiglu_quant_torch(
        self,
        x,
        smooth_scales,
        offsets,
        group_index,
        activate_left,
        quant_mode,
        group_list_type,
        dst_type
    ):
        x1, x2 = torch.chunk(x, 2, dim=-1)
        y = torch.nn.functional.silu(x1) * x2 if activate_left else x1 * torch.nn.functional.silu(x2)
        dst_type_scale = 127.0 if dst_type in (torch.int8, None) else 7.0

        if group_index is not None:
            begin_index = 0
            for i in range(group_index.shape[0]):
                end_index = group_index[i] if group_list_type == 0 else begin_index + group_index[i]
                y_slice = y[begin_index:end_index]
                scale_slice = smooth_scales[i]
                offset_slice = offsets[i] if offsets is not None else 0
                y[begin_index:end_index] = y_slice * scale_slice + (offset_slice if quant_mode == 0 else 0)
                begin_index = end_index
        else:
            y = y * smooth_scales + (offsets if quant_mode == 0 else 0)

        scale = None
        if quant_mode == 1:
            scale = dst_type_scale / torch.max(torch.abs(y), dim=1)[0]
            y = y * scale[:, None]
        y = torch.round(y)
        y = torch.clamp(y, -1 - dst_type_scale, dst_type_scale).to(torch.int8)

        return y, scale

    @SupportedDevices(["Ascend910B", "Ascend910C"])
    def test_npu_swiglu_quant(self, device="npu"):
        batch_size = 4608
        hidden_size = 2048
        x_shape = (batch_size, hidden_size)
        input_data = np.random.randn(*x_shape).astype(np.float32)

        quant_mode = 1
        group_list_type = 0
        dst_type = torch.int8
        activate_left = False
        num_groups = 8
        offsets = np.random.randn(num_groups, hidden_size // 2).astype(np.float32)
        group_size = batch_size // num_groups
        group_index = [(i + 1) * group_size for i in range(num_groups)]
        smooth_scales = np.random.randn(num_groups, hidden_size // 2).astype(np.float32)

        device = "npu"
        npu_x = torch.tensor(input_data, dtype=torch.float32, device=device)
        npu_group_index = torch.tensor(group_index, dtype=torch.int32, device=device)
        npu_smooth_scales = torch.tensor(smooth_scales, dtype=torch.float32, device=device)
        npu_offsets = torch.tensor(offsets, dtype=torch.float32, device=device)
        result = torch_npu.npu_swiglu_quant(
            npu_x,
            smooth_scales=npu_smooth_scales,
            offsets=npu_offsets,
            group_index=npu_group_index,
            activate_left=activate_left,
            quant_mode=quant_mode,
            group_list_type=group_list_type,
            dst_type=dst_type
        )

        device = "cpu"
        cpu_x = torch.tensor(input_data, dtype=torch.float32, device=device)
        cpu_group_index = torch.tensor(group_index, dtype=torch.int32, device=device)
        cpu_smooth_scales = torch.tensor(smooth_scales, dtype=torch.float32, device=device)
        cpu_out = self.golden_swiglu_quant_torch(
            cpu_x,
            smooth_scales=cpu_smooth_scales,
            offsets=offsets,
            group_index=cpu_group_index,
            activate_left=activate_left,
            quant_mode=quant_mode,
            group_list_type=group_list_type,
            dst_type=dst_type
        )

        self.assertRtolEqual(cpu_out[0].numpy(), result[0].cpu().numpy())
        if quant_mode == 1:
            self.assertRtolEqual(cpu_out[1].numpy(), result[1].cpu().numpy())

if __name__ == "__main__":
    run_tests()