import numpy as np
import torch
import torch.nn.functional as F
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices

torch.npu.config.allow_internal_format = False


class TestSwiGluBackward(TestCase):

    def swish(self, beta, x):
        return x * torch.sigmoid(beta * x)

    def swish_backward(self, beta, x):
        return torch.sigmoid(beta * x) + x * (1 - torch.sigmoid(beta * x)) * torch.sigmoid(beta * x) * beta

    def get_golden(self, tensor_gradout, input_self_tensor, dim):

        def swiglu_backward_v1(x):
            """0.1版本,FP32格式运算,最后输出转成BF16"""
            beta_value = 1.0
            inTensors = torch.chunk(x, 2, dim=dim)
            tensor_self_float = inTensors[0].type(torch.float)
            tensor_other_float = inTensors[1].type(torch.float)
            tensor_gradout_float = tensor_gradout.type(torch.float)
            torch.mul(torch.relu(tensor_self_float), tensor_other_float)
            tensor_out1 = torch.mul(torch.mul(tensor_other_float, self.swish_backward(beta_value, tensor_self_float)),
                                    tensor_gradout_float)
            tensor_out2 = torch.mul(tensor_gradout_float, self.swish(beta_value, tensor_self_float))
            tensor_out_float = torch.cat((tensor_out1, tensor_out2), dim=-1)
            return tensor_out_float.type(torch.bfloat16)

        output = swiglu_backward_v1(input_self_tensor)
        return output

    @SupportedDevices(['Ascend910B'])
    def test_swiglu_backward(self):
        shape = [8192, 1, 3904 * 2]
        grad_shape = [8192, 1, 3904]
        dim = -1
        grad_out = torch.rand(grad_shape, device='cpu', dtype=torch.bfloat16)
        input_self_tensor = torch.rand(shape, device='cpu', dtype=torch.bfloat16)

        golden = self.get_golden(grad_out, input_self_tensor, dim)
        torch.npu.synchronize()

        input_self_tensor_npu = input_self_tensor.npu()
        input_self_tensor_npu.requires_grad_(True)
        input_self_tensor_npu.retain_grad()
        grad_out_npu = grad_out.npu()

        output_forward = torch_npu.npu_swiglu(input_self_tensor_npu, dim)
        output_forward.backward([grad_out_npu])
        result = input_self_tensor_npu.grad.cpu()
        torch.npu.synchronize()

        self.assertRtolEqual(golden.type(torch.float32), result.type(torch.float32))


if __name__ == "__main__":
    run_tests()