import numpy as np
import torch
import torch.nn.functional as F
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices

# Module-level switch, applied once at import time before any test in this file runs.
# NOTE(review): presumably stops the NPU backend from using its private/internal
# tensor formats for these tests — confirm against the torch_npu documentation.
torch.npu.config.allow_internal_format = False


class TestSwigluGroupQuantBackward(TestCase):
    """Compare ``torch_npu.npu_swiglu_group_quant_backward`` with a PyTorch reference."""

    def golden_swiglu_group_quant_backward(self, grad_y, x, weight=None, y_origin=None, group_index=None, clamp_limit=0.0):
        """Reference backward pass used as the expected result in the tests.

        All arithmetic is done in float32. ``x`` is flattened to 2-D over its
        last dimension and split in the middle of that dimension: the first
        half goes through ``v * sigmoid(v)`` and is multiplied by the second
        half.

        Args:
            grad_y: Incoming gradient; reshaped to ``[-1, x.shape[-1] // 2]``.
            x: Forward input whose last dimension holds both halves.
            weight: Optional scale applied to ``grad_y``. When given,
                ``y_origin`` is required, because ``grad_weight`` is
                ``sum(grad_y * y_origin)`` over the last dimension
                (``keepdim=True``).
            y_origin: Tensor multiplied with ``grad_y`` to form ``grad_weight``.
            group_index: Optional iterable whose entries are summed; every
                flattened row whose index is not below that sum gets a zero
                gradient (in both ``grad_x`` and ``grad_weight``).
            clamp_limit: When non-zero, the first half of ``x`` is clamped from
                above at ``clamp_limit`` and the second half to
                ``[-clamp_limit, clamp_limit]``; gradients are zeroed wherever
                the unclamped value is not strictly inside those bounds.

        Returns:
            ``(grad_x, grad_weight)``. ``grad_x`` has the shape and dtype of
            ``x``. ``grad_weight`` is float32, or ``None`` when ``weight`` is
            ``None``.
        """
        result_dtype = x.dtype
        upstream = grad_y.float()
        x_fp32 = x.float()
        y_origin_fp32 = None if y_origin is None else y_origin.float()

        # Sum of the group_index entries; rows at or beyond this index are zeroed.
        valid_rows = 0
        if group_index is not None:
            for entry in group_index:
                valid_rows += entry

        def leading_rows_mask(reference):
            """Float column mask: 1.0 for rows below ``valid_rows``, else 0.0."""
            row_ids = torch.arange(reference.shape[0], device=reference.device)
            return (row_ids < valid_rows).unsqueeze(-1).float()

        if weight is None:
            grad_weight = None
            scaled_upstream = upstream
        else:
            grad_weight = torch.sum(upstream * y_origin_fp32, dim=-1, keepdim=True)
            if group_index is not None:
                # Mask on a [rows, 1] view, then restore the reduced shape.
                weight_grad_shape = grad_weight.shape
                weight_grad_rows = grad_weight.reshape([-1, 1])
                weight_grad_rows = weight_grad_rows * leading_rows_mask(weight_grad_rows)
                grad_weight = weight_grad_rows.reshape(weight_grad_shape)
            scaled_upstream = upstream * weight.float()

        input_shape = x_fp32.shape
        flat_x = x_fp32.reshape([-1, input_shape[-1]])
        half = flat_x.shape[-1] // 2
        gate_raw, up_raw = flat_x[:, :half], flat_x[:, half:]

        clamp_enabled = clamp_limit != 0
        if clamp_enabled:
            gate = torch.clamp(gate_raw, max=clamp_limit)
            up = torch.clamp(up_raw, -clamp_limit, clamp_limit)
        else:
            gate, up = gate_raw, up_raw

        gate_sigmoid = torch.sigmoid(gate)
        silu = gate * gate_sigmoid
        # d/dv [v * sigmoid(v)] = sigmoid(v) * (1 + v * (1 - sigmoid(v)))
        silu_derivative = gate_sigmoid * (1 + gate * (1 - gate_sigmoid))

        upstream_rows = scaled_upstream.reshape([-1, half])
        grad_gate = upstream_rows * up * silu_derivative
        grad_up = upstream_rows * silu

        if clamp_enabled:
            # The masks are built from the unclamped values.
            gate_pass = (gate_raw < clamp_limit).float()
            up_pass = ((up_raw > -clamp_limit) & (up_raw < clamp_limit)).float()
            grad_gate = grad_gate * gate_pass
            grad_up = grad_up * up_pass

        if group_index is not None:
            keep_rows = leading_rows_mask(grad_gate)
            grad_gate = grad_gate * keep_rows
            grad_up = grad_up * keep_rows

        grad_x = torch.cat([grad_gate, grad_up], dim=-1).reshape(input_shape).to(result_dtype)
        return grad_x, grad_weight

    @SupportedDevices(['Ascend950'])
    def test_swiglu_group_quant_backward(self):
        """Check the NPU op against the reference for float32 inputs.

        Uses random 4-row inputs with ``weight`` and ``y_origin`` supplied,
        no ``group_index`` and ``clamp_limit=5.0``.
        """
        rows, half = 4, 8
        grad_y = torch.randn(rows, half, dtype=torch.float32)
        x = torch.randn(rows, 2 * half, dtype=torch.float32)
        weight = torch.randn(rows, 1, dtype=torch.float32)
        y_origin = torch.randn(rows, half, dtype=torch.float32)
        options = {"group_index": None, "clamp_limit": 5.0}

        expected_grad_x, expected_grad_weight = self.golden_swiglu_group_quant_backward(
            grad_y, x, weight=weight, y_origin=y_origin, **options)

        device_grad_x, device_grad_weight = torch_npu.npu_swiglu_group_quant_backward(
            grad_y.npu(), x.npu(), weight=weight.npu(), y_origin=y_origin.npu(), **options)
        # Bring both outputs back to the host before any comparison.
        actual_grad_x = device_grad_x.cpu()
        actual_grad_weight = device_grad_weight.cpu()

        comparisons = (
            (expected_grad_x, actual_grad_x),
            (expected_grad_weight, actual_grad_weight),
        )
        for expected, actual in comparisons:
            self.assertRtolEqual(expected.type(torch.float32), actual.type(torch.float32))


# Script entry point: when this file is executed directly, hand control to
# run_tests (imported from torch_npu.testing.testcase above).
if __name__ == "__main__":
    run_tests()