"""
Add validation cases for torch.fx.Tracer/Transformer APIs on NPU:

1. The upstream PyTorch test/test_fx.py lacks direct test cases for these APIs:
   - torch.fx.Tracer.trace
   - torch.fx.Tracer.path_of_module
   - torch.fx.Tracer.iter
   - torch.fx.Tracer.keys
   - torch.fx.Tracer.proxy
   - torch.fx.Tracer.to_bool
   - torch.fx.Transformer.call_function
   - torch.fx.Transformer.call_module
   - torch.fx.Transformer.get_attr
   - torch.fx.Transformer.placeholder
   - torch.fx.Tracer.getattr

2. This file validates the core functionality of these APIs in an NPU environment.
"""

import torch

from torch.fx import Tracer, symbolic_trace, Transformer, GraphModule
from torch.fx.proxy import Proxy, TraceError
from torch.testing._internal.common_utils import TestCase, run_tests


class TestTracerTrace(TestCase):
    """Test torch.fx.Tracer.trace method."""

    def test_trace_function(self):
        """Verify trace can symbolically trace a function."""
        def fn(x, y):
            return x + y

        tracer = Tracer()
        graph = tracer.trace(fn)

        graph.lint()
        self.assertIsNotNone(graph)

    def test_trace_module(self):
        """Verify trace can symbolically trace a Module."""
        class MyModule(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        tracer = Tracer()
        graph = tracer.trace(MyModule())

        graph.lint()
        self.assertIsNotNone(graph)


class TestTracerPathOfModule(TestCase):
    """Test torch.fx.Tracer.path_of_module method."""

    def test_path_of_module(self):
        """Verify path_of_module can get submodule qualified path."""
        class SubModule(torch.nn.Module):
            def forward(self, x):
                return x

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.sub = SubModule()

            def forward(self, x):
                return self.sub(x)

        tracer = Tracer()
        mod = MyModule()
        tracer.root = mod

        tracer.submodule_paths = {}
        for name, submod in mod.named_modules():
            if submod is not mod:
                tracer.submodule_paths[submod] = name

        path = tracer.path_of_module(mod.sub)
        self.assertEqual(path, "sub")


class TestTracerIter(TestCase):
    """Test torch.fx.Tracer.iter method."""

    def test_iter_raises_trace_error(self):
        """Verify iter raises TraceError for Proxy objects."""
        graph = torch.fx.Graph()
        node = graph.placeholder("x")

        tracer = Tracer()
        tracer.graph = graph
        proxy = tracer.proxy(node)

        self.assertRaises(TraceError, tracer.iter, proxy)


class TestTracerKeys(TestCase):
    """Test torch.fx.Tracer.keys method."""

    def test_keys_method_callable(self):
        """Verify keys method exists and is callable on Tracer."""
        tracer = Tracer()
        self.assertTrue(hasattr(tracer, 'keys'))
        self.assertTrue(callable(tracer.keys))


class TestTracerProxy(TestCase):
    """Test torch.fx.Tracer.proxy method."""

    def test_proxy_creates_proxy_object(self):
        """Verify proxy wraps Node into a Proxy object."""
        graph = torch.fx.Graph()
        node = graph.placeholder("x")

        tracer = Tracer()
        tracer.graph = graph

        proxy = tracer.proxy(node)
        self.assertIsInstance(proxy, Proxy)
        self.assertEqual(proxy.node, node)


class TestTracerToBool(TestCase):
    """Test torch.fx.Tracer.to_bool method."""

    def test_to_bool_raises_trace_error(self):
        """Verify to_bool raises TraceError for Proxy objects."""
        graph = torch.fx.Graph()
        node = graph.placeholder("x")

        tracer = Tracer()
        tracer.graph = graph
        proxy = tracer.proxy(node)

        self.assertRaises(TraceError, tracer.to_bool, proxy)


class TestTransformerCallFunction(TestCase):
    """Test torch.fx.Transformer.call_function method."""

    def test_call_function(self):
        """Verify Transformer handles call_function nodes correctly."""
        class MyModule(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        gm = symbolic_trace(MyModule())
        transformed = Transformer(gm).transform()

        input = torch.randn(4, 4, device="npu")
        result = transformed(input)
        self.assertEqual(result.shape, input.shape)


class TestTransformerCallModule(TestCase):
    """Test torch.fx.Transformer.call_module method."""

    def test_call_module(self):
        """Verify Transformer handles call_module nodes correctly."""
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, x):
                return self.linear(x)

        mod = MyModule().to("npu")
        gm = symbolic_trace(mod)
        transformed = Transformer(gm).transform()

        input = torch.randn(2, 4, device="npu")
        self.assertEqual(transformed(input).shape, input.shape)


class TestTransformerGetAttr(TestCase):
    """Test torch.fx.Transformer.get_attr method."""

    def test_get_attr(self):
        """Verify Transformer handles get_attr nodes correctly."""
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.ones(4, 4, device="npu"))

            def forward(self, x):
                return x + self.weight

        tracer = Tracer()
        graph = tracer.trace(MyModule())
        gm = GraphModule(tracer.root, graph)
        transformed = Transformer(gm).transform()

        input = torch.randn(4, 4, device="npu")
        result = transformed(input)
        self.assertEqual(result.shape, input.shape)


class TestTransformerPlaceholder(TestCase):
    """Test torch.fx.Transformer.placeholder method."""

    def test_placeholder(self):
        """Verify Transformer handles placeholder nodes correctly."""
        class MyModule(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        gm = symbolic_trace(MyModule())
        transformed = Transformer(gm).transform()

        x = torch.randn(4, 4, device="npu")
        y = torch.randn(4, 4, device="npu")
        result = transformed(x, y)
        self.assertEqual(result.shape, x.shape)


class TestTracerGetAttr(TestCase):
    """Test torch.fx.Tracer.getattr method."""

    def test_getattr_parameter(self):
        """Verify getattr correctly captures nn.Parameter in graph and computes correctly on NPU."""
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.ones(4, 4, device="npu"))

            def forward(self, x):
                p = self.param
                return x + p

        mod = MyModule()
        tracer = Tracer()
        graph = tracer.trace(mod)
        gm = GraphModule(tracer.root, graph)

        # Verify graph structure
        get_attr_nodes = [n for n in graph.nodes if n.op == 'get_attr']
        self.assertEqual(len(get_attr_nodes), 1)
        self.assertEqual(get_attr_nodes[0].target, 'param')

        # Verify computation on NPU
        input = torch.randn(4, 4, device="npu")
        expected = mod(input)
        actual = gm(input)
        self.assertEqual(actual.shape, expected.shape)
        self.assertTrue(torch.allclose(actual, expected))

    def test_getattr_buffer(self):
        """Verify getattr correctly captures buffer in graph and computes correctly on NPU."""
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer('buf', torch.zeros(4, 4, device="npu"))

            def forward(self, x):
                b = self.buf
                return x - b

        mod = MyModule()
        tracer = Tracer()
        graph = tracer.trace(mod)
        gm = GraphModule(tracer.root, graph)

        # Verify graph structure
        get_attr_nodes = [n for n in graph.nodes if n.op == 'get_attr']
        self.assertEqual(len(get_attr_nodes), 1)
        self.assertEqual(get_attr_nodes[0].target, 'buf')

        # Verify computation on NPU
        input = torch.randn(4, 4, device="npu")
        expected = mod(input)
        actual = gm(input)
        self.assertEqual(actual.shape, expected.shape)
        self.assertTrue(torch.allclose(actual, expected))


# When executed as a script, hand control to run_tests(), the entry point
# imported from torch.testing._internal.common_utils.
if __name__ == "__main__":
    run_tests()