import os
import unittest
import torch

from mindiesd.compilation import MindieSDBackend
from tests.compilation.test_bench_utils import benchmark


class GeluPatternModel(torch.nn.Module):
    def __init__(self, approximate="tanh"):
        super().__init__()
        self.gelu = torch.nn.GELU(approximate=approximate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.gelu(x)


@unittest.skipIf(os.environ.get("MINDIE_TEST_MODE", "ALL") == "CPU", "Skip NPU-dependent tests when MINDIE_TEST_MODE is CPU.")
class TestGeluCompilationCase(unittest.TestCase):
    def _run_test_and_measure_time(self, model, x):
        compiled_model = torch.compile(model, backend=MindieSDBackend())
        compiled_model(x)
        torch.npu.synchronize()

        compiled_time = benchmark(compiled_model, (x,))
        original_time = benchmark(model, (x,))

        output_compiled = compiled_model(x)
        output_original = model(x)

        output_compiled = output_compiled.reshape(1, -1).to(torch.float32)
        output_original = output_original.reshape(1, -1).to(torch.float32)
        self.assertGreater(torch.cosine_similarity(output_compiled, output_original)[0], 2**-7, msg="模式替换后输出不一致!")
        self.assertLess(compiled_time, original_time, msg="compiled={:.6f}s >= original={:.6f}s".format(compiled_time, original_time))

    def test_gelu_pattern_tanh_approx_bfloat16(self):
        model = GeluPatternModel(approximate="tanh")
        x = torch.randn(4, 4608, 12288, dtype=torch.bfloat16, device="npu")

        self._run_test_and_measure_time(model, x)

if __name__ == "__main__":
    unittest.main()