# Owner(s): ["module: dynamo"]
import dataclasses
import functools
import unittest
from typing import List

import torch
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import same
from torch._functorch.aot_autograd import aot_module_simplified
from torch._inductor.pattern_matcher import Match
from torch._subclasses.fake_tensor import FakeTensorMode

import torch_npu

DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10]


@unittest.skipIf(DEVICE_NAME == "Ascend910A", "capture is not supported on 910A, skip this ut.")
class TestNpuGraphEx(TestCase):
    def test_backend(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return x + y

        model = Model().npu()
        compiled_model = torch.compile(model, backend="npugraph_ex", options={"clone_input": False}, fullgraph=True, dynamic=False)
        x = torch.ones(1, dtype=torch.int32, device="npu")
        y = torch.ones(1, dtype=torch.int32, device="npu")
        z = compiled_model(x, y)
        self.assertEqual(z.item(), 2)

    def test_compile_fx(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return x + y

        def custom_compiler(gm: torch.fx.GraphModule, example_inputs):
            compiled_graph = torch.npu.npugraph_ex.compile_fx(gm, example_inputs)
            return compiled_graph

        def custom_compiler_with_options(gm: torch.fx.GraphModule, example_inputs):
            test_kwargs = {
                "clone_input": False
            }
            compiled_graph = torch.npu.npugraph_ex.compile_fx(gm, example_inputs, test_kwargs)
            return compiled_graph

        def my_backend(gm: torch.fx.GraphModule, example_inputs):
            return aot_module_simplified(gm, example_inputs, fw_compiler=custom_compiler)

        def my_backend_with_options(gm: torch.fx.GraphModule, example_inputs):
            return aot_module_simplified(gm, example_inputs, fw_compiler=custom_compiler_with_options)

        model = Model().npu()
        compiled_model = torch.compile(model, backend=my_backend, fullgraph=True, dynamic=False)
        x = torch.ones(1, dtype=torch.int32, device="npu")
        y = torch.ones(1, dtype=torch.int32, device="npu")
        z = compiled_model(x, y)
        self.assertEqual(z.item(), 2)
        compiled_model = torch.compile(model, backend=my_backend_with_options, fullgraph=True, dynamic=False)
        z = compiled_model(x, y)
        self.assertEqual(z.item(), 2)

    def test_cache_compile(self):
        @dataclasses.dataclass
        class InputMeta:
            data: torch.Tensor
            is_prompt: bool

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = torch.nn.Linear(2, 1)
                self.linear2 = torch.nn.Linear(2, 1)
                for param in self.parameters():
                    torch.nn.init.ones_(param)
                self.cached_prompt = torch.npu.npugraph_ex.inference.cache_compile(self.prompt)
                self.cached_decode = torch.npu.npugraph_ex.inference.cache_compile(self.decode)

            def forward(self, x: InputMeta, kv: List[torch.Tensor]):
                if x.is_prompt:
                    return self.cached_prompt(x, kv)
                return self.cached_decode(x, kv)

            def _forward(self, x, kv):
                return self.linear2(x.data) + self.linear2(kv[0])

            def prompt(self, x, y):
                return self._forward(x, y)

            def decode(self, x, y):
                return self._forward(x, y)

        x = InputMeta(data=torch.ones(2, 2).npu(), is_prompt=True)
        kv = [torch.ones(2, 2).npu()]

        model = Model().npu()
        res_prompt = model(x, kv)

        x.is_prompt = False
        res_decode = model(x, kv)
        res = torch.empty(2, 1).npu().fill_(6.0)
        self.assertTrue(same(res, res_prompt))
        self.assertTrue(same(res, res_decode))

    def test_register_replacement(self):
        def search_fn(x1, x2, gamma):
            x_out = torch.add(x1, x2)
            y, _ = torch_npu.npu_rms_norm(x_out, gamma)
            return y, x_out

        def replace_fn(x1, x2, gamma):
            y, _, x_out = torch_npu.npu_add_rms_norm(x1, x2, gamma)
            return y, x_out

        def extra_check(match: Match):
            x1 = match.kwargs.get("x1")

            if x1 is None:
                return False
            if not hasattr(x1, "meta") or "val" not in x1.meta:
                return False

            a_shape = x1.meta["val"].shape
            return a_shape[-1] == 7168

        fake_mode = FakeTensorMode()
        with fake_mode:
            input_tensor = functools.partial(torch.empty, (1, 1, 2), dtype=torch.float16, device="npu")
            kwargs_tensor = functools.partial(torch.empty, 2, dtype=torch.float16, device="npu")

            torch.npu.npugraph_ex.register_replacement(
                search_fn=search_fn,
                replace_fn=replace_fn,
                example_inputs=(input_tensor(), input_tensor(), kwargs_tensor()),
                extra_check=extra_check
            )

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, data1, data2, gamma):
                x_out = torch.add(data1, data2)
                y, _ = torch_npu.npu_rms_norm(x_out, gamma)

                abs_01 = torch.abs(y)
                sqrt_01 = torch.sqrt(x_out)
                return abs_01, sqrt_01

        model = Model().npu()

        x1 = torch.randn(1, 1, 7168, dtype=torch.float16, device="npu")
        x2 = torch.randn(1, 1, 7168, dtype=torch.float16, device="npu")
        gamma = torch.randn(7168, dtype=torch.float16, device="npu")

        compiled_model = torch.compile(model, backend="npugraph_ex", fullgraph=True, dynamic=False)
        res = compiled_model(x1, x2, gamma)
        self.assertEqual(res[0].shape[2], 7168)
        self.assertEqual(res[1].shape[2], 7168)

    def test_limit_core_num(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, in1, in2, in3, in4):
                with torch.npu.npugraph_ex.scope.limit_core_num(4, 5):
                    mm_result = torch.mm(in3, in4)
                    add_result = torch.add(in1, in2)
                mm1_result = torch.mm(in3, in4)
                return add_result, mm_result, mm1_result

        model = Model().npu()
        compiled_model = torch.compile(model, backend="npugraph_ex", fullgraph=True, dynamic=False)
        in1 = torch.randn(1000, 1000, dtype=torch.float16, device="npu")
        in2 = torch.randn(1000, 1000, dtype=torch.float16, device="npu")
        in3 = torch.randn(1000, 1000, dtype=torch.float16, device="npu")
        in4 = torch.randn(1000, 1000, dtype=torch.float16, device="npu")
        res = compiled_model(in1, in2, in3, in4)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()