import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices

# Module-level switch applied once at import time, before any test runs.
# NOTE(review): presumably this keeps NPU tensors out of device-private
# (internal) storage formats so results compare cleanly against CPU tensors
# -- confirm against the torch_npu configuration documentation.
torch.npu.config.allow_internal_format = False


class TestNPUGelu(TestCase):
    """Check ``torch_npu.npu_gelu_mul`` against a CPU reference built on ``torch.nn.GELU``."""

    def get_golden(self, input_tensor, approximate="none"):
        """Compute the reference ("golden") result on the given tensor.

        The last dimension is split into two equal halves; GELU is applied to
        the first half and the result is multiplied element-wise by the
        second half.

        Args:
            input_tensor: Tensor whose last dimension has an even length.
            approximate: Approximation mode forwarded to ``torch.nn.GELU``
                (the test below exercises ``"none"`` and ``"tanh"``).

        Returns:
            Tensor with the same leading dimensions as ``input_tensor`` and a
            last dimension of half the original length.

        Raises:
            ValueError: If the last dimension has an odd length and therefore
                cannot be split into two equal halves.
        """
        last_dim = input_tensor.shape[-1]
        if last_dim % 2 != 0:
            # Fail loudly here instead of handing a sentinel string to the
            # comparison helper, where the real cause would be obscured.
            raise ValueError(
                f"shape error: last dimension must be even, got {last_dim}"
            )

        half = last_dim // 2
        activated = torch.nn.GELU(approximate)(input_tensor[..., :half])
        return activated * input_tensor[..., half:]

    @SupportedDevices(['Ascend910B'])
    def test_npu_gelu_all_modes(self):
        """Compare NPU output with the CPU reference for every mode/dtype pair."""
        shape = [100, 400]
        test_combinations = [
            ("none", torch.float16),
            ("none", torch.float32),
            ("tanh", torch.float16),
            ("tanh", torch.float32),
        ]

        for mode, dtype in test_combinations:
            # subTest reports which combination failed and lets the remaining
            # combinations run even after one of them fails.
            with self.subTest(approximate=mode, dtype=dtype):
                input_tensor = torch.rand(shape, dtype=dtype).npu()
                output = torch_npu.npu_gelu_mul(input_tensor, approximate=mode)
                golden = self.get_golden(input_tensor.cpu(), mode)
                self.assertRtolEqual(golden, output.cpu())

# Run the test suite through torch_npu's runner when executed as a script.
if __name__ == "__main__":
    run_tests()