import numpy as np
import torch
import torch.nn.functional as F
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices

torch.npu.config.allow_internal_format = False


class TestSwiGlu(TestCase):
    def get_golden(self, input_self_tensor, dim):
        def swiglu_v1(x):
            """0.1版本,FP32格式运算,最后输出转成BF16"""
            x = torch.chunk(x, 2, dim=dim)
            self_tensor = x[0].type(torch.float32)
            other = x[1].type(torch.float32)
            output = F.silu(self_tensor.npu()) * other.npu()
            return output.type(torch.bfloat16)

        output = swiglu_v1(input_self_tensor)
        return output

    @SupportedDevices(['Ascend910B'])
    def test_swiglu(self):
        shape = [8192, 1, 3904 * 2]
        dim = -1

        input_self_tensor = torch.rand(shape, device='cpu', dtype=torch.bfloat16).npu()
        torch.npu.synchronize()
        output = torch_npu.npu_swiglu(input_self_tensor, dim)
        torch.npu.synchronize()

        golden = self.get_golden(input_self_tensor, dim)
        self.assertRtolEqual(output.type(torch.float32), golden.type(torch.float32))

if __name__ == "__main__":
    run_tests()