from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.npu import (
    NpuGraphOpHandler,
    register_npu_graph_handler,
)
from torch_npu.npu._npugraph_handlers.npugraph_handler import _NPU_GRAPH_OP_HANDLERS


class TestNpuGraphHandlerRegistry(TestCase):
    """Tests that ``register_npu_graph_handler`` populates the global registry.

    Each test registers throwaway op names. The registry is snapshotted in
    ``setUp`` and restored in ``tearDown``, so the tests neither leak
    entries into each other nor disturb the built-in registrations.
    """

    def setUp(self):
        # A shallow copy is enough: only the name -> handler mapping is restored.
        self._snapshot = dict(_NPU_GRAPH_OP_HANDLERS)

    def tearDown(self):
        # Restore in place rather than rebinding. The name was imported from
        # the handler module, so rebinding it here would leave the real
        # registry (the one the decorator writes to) unchanged.
        _NPU_GRAPH_OP_HANDLERS.clear()
        _NPU_GRAPH_OP_HANDLERS.update(self._snapshot)

    def test_register_single_name(self):
        """A single string name maps to the decorated handler class."""
        @register_npu_graph_handler("test_op_single")
        class _H(NpuGraphOpHandler):
            pass

        self.assertIn("test_op_single", _NPU_GRAPH_OP_HANDLERS)
        self.assertIs(_NPU_GRAPH_OP_HANDLERS["test_op_single"], _H)

    def test_register_multiple_names(self):
        """A list of names maps every name to the same handler class."""
        names = ("test_op_a", "test_op_a.default")

        # Pass a fresh list, matching the argument type the decorator received
        # originally, so the tuple iterated below cannot be affected by it.
        @register_npu_graph_handler(list(names))
        class _H(NpuGraphOpHandler):
            pass

        # Check identity for every registered name, not only the first one.
        for name in names:
            with self.subTest(name=name):
                self.assertIn(name, _NPU_GRAPH_OP_HANDLERS)
                self.assertIs(_NPU_GRAPH_OP_HANDLERS[name], _H)


class TestNpuGraphHandlerBuiltinRegistration(TestCase):
    """Checks the handlers that are present in the registry for built-in ops."""

    # All overloads of npu_fused_infer_attention_score (v1) that must be handled.
    _IFA_OPS = (
        "npu_fused_infer_attention_score",
        "npu_fused_infer_attention_score.default",
        "npu_fused_infer_attention_score.out",
    )

    # All overloads of npu_fused_infer_attention_score_v2 that must be handled.
    _IFA_V2_OPS = (
        "npu_fused_infer_attention_score_v2",
        "npu_fused_infer_attention_score_v2.default",
        "npu_fused_infer_attention_score_v2.out",
    )

    # Every op name expected in the registry. It is built from the tuples
    # above so the spec tests and the presence test cannot drift apart.
    # Order matches the original list.
    EXPECTED_OPS = _IFA_OPS + _IFA_V2_OPS + (
        "_npu_paged_attention.default",
        "npu_multi_head_latent_attention.out",
    )

    def _assert_update_specs(self, op_names, expected_specs):
        """Assert each op's handler reports every entry of *expected_specs*.

        Args:
            op_names: registry keys whose handlers are checked.
            expected_specs: entries that must appear in the value returned by
                ``get_update_specs(op_name)``.
        """
        for op_name in op_names:
            with self.subTest(op_name=op_name):
                # Check membership first so a missing handler is reported as
                # a test failure with a clear message instead of a KeyError.
                self.assertIn(
                    op_name,
                    _NPU_GRAPH_OP_HANDLERS,
                    f"Expected handler for '{op_name}' not found in registry",
                )
                specs = _NPU_GRAPH_OP_HANDLERS[op_name].get_update_specs(op_name)
                for spec in expected_specs:
                    self.assertIn(spec, specs)

    def test_all_expected_ops_registered(self):
        """Every op in EXPECTED_OPS has an entry in the handler registry."""
        for op_name in self.EXPECTED_OPS:
            with self.subTest(op_name=op_name):
                self.assertIn(
                    op_name,
                    _NPU_GRAPH_OP_HANDLERS,
                    f"Expected handler for '{op_name}' not found in registry",
                )

    def test_ifa_update_specs_cover_actual_seq_args(self):
        """v1 handlers list both actual-sequence-length arguments in their specs."""
        # Spec tuples look like (kind, positional index, parameter name);
        # presumed from the literals, confirm against the handler implementation.
        self._assert_update_specs(
            self._IFA_OPS,
            (
                ("arg", 5, "actual_seq_lengths"),
                ("arg", 6, "actual_seq_lengths_kv"),
            ),
        )

    def test_ifa_v2_update_specs_cover_actual_seq_args(self):
        """v2 handlers list both actual-sequence-length arguments in their specs."""
        self._assert_update_specs(
            self._IFA_V2_OPS,
            (
                ("arg", 7, "actual_seq_qlen"),
                ("arg", 8, "actual_seq_kvlen"),
            ),
        )


if __name__ == "__main__":
    # Run through torch_npu's test runner (imported from
    # torch_npu.testing.testcase) rather than plain unittest.main().
    run_tests()