"""
Validation cases for torch.fx.Tracer APIs on NPU:
1. Upstream PyTorch has little direct test coverage for fx.Tracer.create_arg, so this file adds it.
2. The current cases target torch.fx.Tracer.create_arg; the file can be extended to other tracer APIs such as call_module and getattr.
"""

import torch
import torch.fx as fx
import torch.nn as nn
import torch_npu
from torch.testing._internal.common_utils import run_tests, TestCase


class TestTracerCreateArg(TestCase):
    """
    Test suite for fx.Tracer.create_arg method.
    Validates that create_arg correctly processes tensors, containers,
    basic types, and NPU tensors during symbolic tracing.
    """

    def _get_placeholder_nodes(self, graph):
        """Helper to extract placeholder nodes from the graph"""
        return [node for node in graph.nodes if node.op == "placeholder"]

    def _get_constant_nodes(self, graph):
        """Helper to extract constant nodes from the graph"""
        return [
            node
            for node in graph.nodes
            if node.op == "get_attr" or "tensor_constant" in node.name
        ]

    def test_create_arg_tensor_to_constant_node(self):
        """Test that a torch.Tensor is converted to a constant graph node (get_attr)"""

        class SimpleModule(nn.Module):
            def forward(self, x):
                return x + 1

        model = SimpleModule()
        tracer = fx.Tracer()
        tracer.trace(model)

        tensor = torch.randn(2, 3).npu()
        result = tracer.create_arg(tensor)

        self.assertIsInstance(result, fx.Node)
        self.assertEqual(result.op, "get_attr")
        self.assertIn("_tensor_constant", result.name)

    def test_create_arg_tensor_node_in_graph(self):
        """Test that the created tensor node is added to the computation graph"""

        class SimpleModule(nn.Module):
            def forward(self, x):
                return x + 1

        model = SimpleModule()
        tracer = fx.Tracer()
        graph = tracer.trace(model)

        tensor = torch.randn(2, 3).npu()
        result = tracer.create_arg(tensor)

        node_names = [node.name for node in graph.nodes]
        self.assertIn(result.name, node_names)

    def test_create_arg_tensor_list_to_node_list(self):
        """Test that a list of Tensors is converted to a list of graph nodes"""

        class SimpleModule(nn.Module):
            def forward(self, x):
                return x

        model = SimpleModule()
        tracer = fx.Tracer()
        tracer.trace(model)

        tensor_list = [torch.randn(2, 2).npu(), torch.randn(2, 2).npu()]
        result = tracer.create_arg(tensor_list)

        self.assertIsInstance(result, list)
        self.assertEqual(len(result), 2)

        for elem in result:
            self.assertIsInstance(elem, fx.Node)
            self.assertEqual(elem.op, "get_attr")

    def test_create_arg_dict_to_node_dict(self):
        """Test that a dict of Tensors is converted to a dict of graph nodes"""

        class SimpleModule(nn.Module):
            def forward(self, x):
                return x

        model = SimpleModule()
        tracer = fx.Tracer()
        tracer.trace(model)

        tensor_dict = {
            "first": torch.randn(2, 2).npu(),
            "second": torch.randn(2, 2).npu()
        }
        result = tracer.create_arg(tensor_dict)

        self.assertIsInstance(result, dict)
        self.assertEqual(set(result.keys()), {"first", "second"})

        for value in result.values():
            self.assertIsInstance(value, fx.Node)
            self.assertEqual(value.op, "get_attr")

    def test_create_arg_tuple_to_node_tuple(self):
        """Test that a tuple of Tensors is converted to a tuple of graph nodes"""

        class SimpleModule(nn.Module):
            def forward(self, x):
                return x

        model = SimpleModule()
        tracer = fx.Tracer()
        tracer.trace(model)

        tensor_tuple = (torch.randn(2, 2).npu(), torch.randn(2, 2).npu())
        result = tracer.create_arg(tensor_tuple)

        self.assertIsInstance(result, tuple)
        self.assertEqual(len(result), 2)

        for elem in result:
            self.assertIsInstance(elem, fx.Node)
            self.assertEqual(elem.op, "get_attr")

    def test_create_arg_deeply_nested_container(self):
        """Test recursive processing of deeply nested Tensor containers"""

        class SimpleModule(nn.Module):
            def forward(self, x):
                return x

        model = SimpleModule()
        tracer = fx.Tracer()
        tracer.trace(model)

        nested_input = [
            {"data": [torch.randn(2, 3).npu(), torch.randn(2, 3).npu()]},
            torch.randn(2, 3).npu(),
        ]
        result = tracer.create_arg(nested_input)

        self.assertIsInstance(result, list)
        self.assertEqual(len(result), 2)

        dict_elem = result[0]
        self.assertIsInstance(dict_elem, dict)
        self.assertIn("data", dict_elem)

        list_elem = dict_elem["data"]
        self.assertIsInstance(list_elem, list)
        self.assertEqual(len(list_elem), 2)

        for node_elem in list_elem:
            self.assertIsInstance(node_elem, fx.Node)

        self.assertIsInstance(result[1], fx.Node)

    def test_create_arg_with_none(self):
        """Test that None is passed through without modification"""

        class SimpleModule(nn.Module):
            def forward(self, x):
                return x

        model = SimpleModule()
        tracer = fx.Tracer()
        tracer.trace(model)

        result = tracer.create_arg(None)
        self.assertIsNone(result)

    def test_create_arg_with_basic_types(self):
        """Test that basic types (int, float, str, bool) are preserved"""

        class SimpleModule(nn.Module):
            def forward(self, x):
                return x

        model = SimpleModule()
        tracer = fx.Tracer()
        tracer.trace(model)

        self.assertEqual(tracer.create_arg(42), 42)
        self.assertEqual(tracer.create_arg(3.14), 3.14)
        self.assertEqual(tracer.create_arg("hello"), "hello")
        self.assertTrue(tracer.create_arg(True))

    def test_create_arg_npu_tensor_consistency(self):
        """Test that NPU tensors are processed the same as CPU tensors"""

        class SimpleModule(nn.Module):
            def forward(self, x):
                return x + 1

        model = SimpleModule()
        tracer = fx.Tracer()
        tracer.trace(model)

        cpu_tensor = torch.randn(2, 3)
        npu_tensor = torch.randn(2, 3).npu()

        cpu_result = tracer.create_arg(cpu_tensor)
        npu_result = tracer.create_arg(npu_tensor)

        self.assertIsInstance(cpu_result, fx.Node)
        self.assertIsInstance(npu_result, fx.Node)
        self.assertEqual(cpu_result.op, npu_result.op)

    def test_create_arg_with_custom_tracer_override(self):
        """Test that custom Tracer can override create_arg for custom types"""

        class CustomType:
            def __init__(self, value):
                self.value = value

        class CustomTracer(fx.Tracer):
            def create_arg(self, a):
                if isinstance(a, CustomType):
                    return super().create_arg(a.value)
                return super().create_arg(a)

        class SimpleModule(nn.Module):
            def forward(self, x):
                return x + 1

        model = SimpleModule()
        tracer = CustomTracer()
        tracer.trace(model)

        custom_input = CustomType(torch.randn(2, 3).npu())
        result = tracer.create_arg(custom_input)
        self.assertIsInstance(result, fx.Node)

    def test_create_arg_list_mixed_types(self):
        """Test mixed-type list processing (Tensor + basic types)"""

        class SimpleModule(nn.Module):
            def forward(self, x):
                return x

        model = SimpleModule()
        tracer = fx.Tracer()
        tracer.trace(model)

        mixed_list = [
            torch.randn(2, 2).npu(),
            42,
            "hello",
            torch.randn(2, 2).npu()
        ]
        result = tracer.create_arg(mixed_list)

        self.assertIsInstance(result, list)
        self.assertEqual(len(result), 4)

        self.assertIsInstance(result[0], fx.Node)
        self.assertEqual(result[1], 42)
        self.assertEqual(result[2], "hello")
        self.assertIsInstance(result[3], fx.Node)

    def test_create_arg_dict_mixed_values(self):
        """Test mixed-type dict processing (Tensor + basic types + nested list)"""

        class SimpleModule(nn.Module):
            def forward(self, x):
                return x

        model = SimpleModule()
        tracer = fx.Tracer()
        tracer.trace(model)

        mixed_dict = {
            "tensor": torch.randn(2, 2).npu(),
            "int": 42,
            "str": "hello",
            "list": [torch.randn(2, 2).npu(), torch.randn(2, 2).npu()],
        }
        result = tracer.create_arg(mixed_dict)

        self.assertIsInstance(result, dict)
        self.assertIsInstance(result["tensor"], fx.Node)
        self.assertEqual(result["int"], 42)
        self.assertEqual(result["str"], "hello")
        self.assertIsInstance(result["list"], list)
        self.assertEqual(len(result["list"]), 2)
        self.assertIsInstance(result["list"][0], fx.Node)


if __name__ == "__main__":
    # Run this module's tests through the common_utils runner imported from torch.testing.
    run_tests()