import os
import unittest
from unittest.mock import MagicMock, patch
from msprobe.visualization.builder.graph_builder import GraphBuilder, Graph, GraphExportConfig
from msprobe.visualization.graph.node_op import NodeOp
from msprobe.visualization.graph.base_node import BaseNode


class TestGraphBuilder(unittest.TestCase):

    def setUp(self):
        self.construct_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "construct.json")
        self.construct_path_empty = os.path.join(os.path.dirname(os.path.realpath(__file__)), "construct_empty.json")
        self.data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "dump.json")
        self.stack_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "stack.json")
        self.model_name = "TestModel"
        self.graph = Graph(self.model_name)
        self.graph_b = Graph(self.model_name)
        self.config = GraphExportConfig(self.graph, self.graph_b)
        self.construct_dict = {
            "Tensor1": "Module1",
            "Module1": None
        }
        self.data_dict = {
            "Module1": {"data": "data for Module1"},
            "Tensor1": {"data": "data for Tensor1"}
        }
        self.stack_dict = {}

    def test_build(self):
        graph = GraphBuilder.build(self.construct_path, self.data_path, self.stack_path, self.model_name)
        self.assertIsNotNone(graph)
        self.assertIsInstance(graph, Graph)
        self.assertEqual(len(graph.node_map), 3)

        with self.assertRaises(RuntimeError):
            GraphBuilder.build(self.construct_path_empty, self.data_path, self.stack_path, self.model_name)

    @patch('msprobe.visualization.graph.node_op.NodeOp.get_node_op')
    @patch('msprobe.visualization.builder.msprobe_adapter.get_input_output', return_value=([], []))
    def test__init_nodes(self, mock_get_input_output, mock_get_node_op):
        GraphBuilder._init_nodes(self.graph, self.construct_dict, self.data_dict, self.stack_dict)
        mock_get_node_op.assert_any_call("Tensor1")
        mock_get_node_op.assert_any_call("Module1")
        self.assertIs(self.graph.root, self.graph.get_node("TestModel"))

    def test__create_or_get_node(self):
        node_op = MagicMock()
        data_dict = {"node1": {}}
        stack_dict = {}
        node = GraphBuilder._create_or_get_node(self.graph, [data_dict, stack_dict], node_op, "node1")
        self.assertIn("node1", self.graph.node_map)
        self.assertEqual(node.input_data, {})
        self.assertEqual(node.output_data, {})

    def test__handle_backward_upnode_missing(self):
        construct_dict = {'Module.module.a.forward.0': 'Module.root.forward.0', 'Module.module.a.backward.0': None,
                          'Module.root.forward.0': None, 'Module.root.backward.0': None,
                          'Module.module.b.forward.0': 'Module.root.forward.0',
                          'Module.module.b.backward.0': 'Module.root.backward.0', 'Module.module.c.backward.0': None}
        node_id_a = GraphBuilder._handle_backward_upnode_missing(construct_dict, 'Module.module.a.backward.0', None)
        self.assertEqual(node_id_a, 'Module.root.backward.0')
        node_id_b = GraphBuilder._handle_backward_upnode_missing(construct_dict, 'Module.module.b.backward.0',
                                                                 'Module.root.backward.0')
        self.assertEqual(node_id_b, 'Module.root.backward.0')
        node_id_c = GraphBuilder._handle_backward_upnode_missing(construct_dict, 'Module.module.c.backward.0', None)
        self.assertIsNone(node_id_c)
        construct_dict = {'Module.module.a.forward': 'Module.root.forward', 'Module.module.a.backward': None,
                          'Module.root.forward': None, 'Module.root.backward': None,
                          'Module.module.b.forward': 'Module.root.forward',
                          'Module.module.b.backward': 'Module.root.backward', 'Module.module.c.backward': None}
        node_id_a = GraphBuilder._handle_backward_upnode_missing(construct_dict, 'Module.module.a.backward', None)
        self.assertEqual(node_id_a, 'Module.root.backward')
        node_id_b = GraphBuilder._handle_backward_upnode_missing(construct_dict, 'Module.module.b.backward',
                                                                 'Module.root.backward')
        self.assertEqual(node_id_b, 'Module.root.backward')
        node_id_c = GraphBuilder._handle_backward_upnode_missing(construct_dict, 'Module.module.c.backward', None)
        self.assertIsNone(node_id_c)

    def test__collect_apis_between_modules_only_apis(self):
        graph = Graph('TestNet')
        graph.root.subnodes = [BaseNode(NodeOp.function_api, 'Tensor.a.0'), BaseNode(NodeOp.function_api, 'Tensor.b.0')]
        GraphBuilder._collect_apis_between_modules(graph)
        self.assertEqual(len(graph.root.subnodes), 1)
        self.assertEqual(graph.root.subnodes[0].op, NodeOp.api_collection)
        self.assertEqual(len(graph.root.subnodes[0].subnodes), 2)
        self.assertEqual(graph.root.subnodes[0].id, 'Apis_Between_Modules.0')

    def test__collect_apis_between_modules_mixed_nodes(self):
        graph = Graph('TestNet')
        graph.root.subnodes = [BaseNode(NodeOp.function_api, 'Tensor.a.0'), BaseNode(NodeOp.module, 'Module.a.0'),
                               BaseNode(NodeOp.module, 'Module.b.0'), BaseNode(NodeOp.function_api, 'Tensor.b.0'),
                               BaseNode(NodeOp.function_api, 'Tensor.c.0'), BaseNode(NodeOp.module, 'Module.a.1')]
        GraphBuilder._collect_apis_between_modules(graph)
        self.assertEqual(len(graph.root.subnodes), 5)
        self.assertEqual(graph.root.subnodes[0].op, NodeOp.function_api)
        self.assertEqual(graph.root.subnodes[1].op, NodeOp.module)
        self.assertEqual(graph.root.subnodes[3].op, NodeOp.api_collection)
        self.assertEqual(len(graph.root.subnodes[3].subnodes), 2)
        self.assertEqual(graph.root.subnodes[3].id, 'Apis_Between_Modules.0')

    def test__collect_apis_between_modules_only_modules(self):
        graph = Graph('TestNet')
        graph.root.subnodes = [BaseNode(NodeOp.module, 'Module.a.0'), BaseNode(NodeOp.module, 'Module.b.0'),
                               BaseNode(NodeOp.module, 'Module.a.1')]
        GraphBuilder._collect_apis_between_modules(graph)
        self.assertEqual(len(graph.root.subnodes), 3)
        self.assertEqual(graph.root.subnodes[0].op, NodeOp.module)
        self.assertEqual(graph.root.subnodes[1].op, NodeOp.module)
        self.assertEqual(graph.root.subnodes[2].op, NodeOp.module)
        self.assertEqual(len(graph.root.subnodes[0].subnodes), 0)
        self.assertEqual(graph.root.subnodes[0].id, 'Module.a.0')

    def test_add_parameters_grad(self):
        graph = Graph('TestNet')
        graph.add_node(NodeOp.module, 'Module.a.backward.0', graph.root)
        graph.add_node(NodeOp.module, 'Module.b.backward.0', graph.root)
        graph.add_node(NodeOp.module, 'Module.a.backward.1', graph.root)
        graph.add_node(NodeOp.module, 'Module.aa.backward.0', graph.get_node('Module.a.backward.0'))
        graph.add_node(NodeOp.module, 'Module.aaa.backward.0', graph.get_node('Module.a.backward.0'))
        graph.add_node(NodeOp.module, 'Module.aa.backward.1', graph.get_node('Module.a.backward.1'))
        graph.add_node(NodeOp.module, 'Module.aaa.backward.1', graph.get_node('Module.a.backward.1'))

        data_dict = {'Module.a.parameters_grad': {}, 'Module.aaa.parameters_grad': {}}
        GraphBuilder._add_parameters_grad(graph, data_dict)
        root_nodes_id = [node.id for node in graph.get_node('TestNet').subnodes]
        sub_nodes_id0 = [node.id for node in graph.get_node('Module.a.backward.0').subnodes]
        sub_nodes_id1 = [node.id for node in graph.get_node('Module.a.backward.1').subnodes]

        self.assertEqual(root_nodes_id[-1], 'Module.a.backward.1')
        self.assertEqual(sub_nodes_id0[-1], 'Module.aaa.backward.0')
        self.assertEqual(sub_nodes_id1[-1], 'Module.a.parameters_grad')

    def test_handle_backward_inplace(self):
        construct_dict = {'Module.module.Float16Model.forward.0': None,
                          'Module.module.layer1.BasicBlock.forward.0': 'Module.module.Float16Model.forward.0',
                          'Module.module.layer2.BasicBlock.forward.0': 'Module.module.Float16Model.forward.0',
                          'Module.module.conv.Conv2d.forward.0': 'Module.module.Float16Model.forward.0',
                          'Module.module.Float16Model.backward.0': None,
                          'Module.module.layer1.BasicBlock.backward.0': 'Module.module.Float16Model.backward.0',
                          'Module.module.layer2.BasicBlock.backward.0': 'Module.module.conv.Conv2d.backward.0',
                          'Module.module.conv.Conv2d.backward.0': 'Module.module.Float16Model.backward.0'
                          }
        up_node_id = GraphBuilder._handle_backward_inplace(construct_dict,
                                                           'Module.module.layer2.BasicBlock.backward.0',
                                                           'Module.module.conv.Conv2d.backward.0')
        self.assertEqual(up_node_id, 'Module.module.Float16Model.backward.0')
        up_node_id = GraphBuilder._handle_backward_inplace(construct_dict,
                                                           'Module.module.layer1.BasicBlock.backward.0',
                                                           'Module.module.Float16Model.backward.0')
        self.assertEqual(up_node_id, 'Module.module.Float16Model.backward.0')

    def test_is_valid_batch_p2p_output(self):
        self.assertFalse(GraphBuilder._is_valid_batch_p2p_output('a'))
        self.assertFalse(GraphBuilder._is_valid_batch_p2p_output([]))
        self.assertTrue(GraphBuilder._is_valid_batch_p2p_output([['a']]))

    def test_extract_batch_p2p_info(self):
        node_data = {
            "output": [[{'a': 1}], [{'b': 1}]]
        }
        GraphBuilder._extract_batch_p2p_info(self.graph.root, node_data)
        self.assertEqual(self.graph.root.batch_p2p_info, [{'group_id': None, 'op': None, 'peer': None}])

    def test_is_recompute_by_stack_torch(self):
        stack_list = [
            "File /root/.local/conda/envs/mova/lib/python3.11/site-packages/torch/nn/modules/module.py, line 1829, "
            "in inner, \n result = forward_call(*args, **kwargs)",
            "File /root/.local/conda/envs/mova/lib/python3.11/site-packages/torch/nn/modules/module.py, line 1881, "
            "in _call_impl, \n return inner()",
            "File /root/.local/conda/envs/mova/lib/python3.11/site-packages/torch/nn/modules/module.py, line 1775, "
            "in _wrapped_call_impl, \n return self._call_impl(*args, **kwargs)",
            "File /root/work/filestorage/gh/code/MOVA-feat-npu-dai/mova/diffusion/pipelines/mova_train.py, line 1105, "
            "in _fn, \n return module(*inputs)",
            "File /root/.local/conda/envs/mova/lib/python3.11/site-packages/torch/utils/checkpoint.py, line 1555, "
            "in recompute_fn, \n fn(*args, **kwargs)",
            "File /root/.local/conda/envs/mova/lib/python3.11/site-packages/torch/utils/checkpoint.py, line 1124, "
            "in _run_fn_with_dynamo_disabled, \n return fn(*args, **kwargs)",
            "File /root/.local/conda/envs/mova/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py, line 1044, "
            "in _fn, \n return fn(*args, **kwargs)",
            "File /root/.local/conda/envs/mova/lib/python3.11/site-packages/torch/_compile.py, line 53, in inner, "
            "\n return disable_fn(*args, **kwargs)",
            "File /root/.local/conda/envs/mova/lib/python3.11/site-packages/torch/utils/checkpoint.py, line 1154, "
            "in unpack_hook, \n _run_fn_with_dynamo_disabled(frame.recompute_fn, *args)",
            "File /root/.local/conda/envs/mova/lib/python3.11/site-packages/torch/utils/checkpoint.py, line 1147, "
            "in unpack_hook, \n args = ctx.get_args(ctx.saved_tensors)"
        ]
        self.assertTrue(GraphBuilder._is_recompute_by_stack_torch(stack_list))
        stack_list1 = [
            "File /root/.local/conda/envs/mova/lib/python3.11/site-packages/torch/nn/modules/module.py, line 1840, "
            "in inner, \n hook_result = hook(self, args, kwargs, result)",
            "File /root/.local/conda/envs/mova/lib/python3.11/site-packages/torch/nn/modules/module.py, line 1881, "
            "in _call_impl, \n return inner()",
            "File /root/.local/conda/envs/mova/lib/python3.11/site-packages/torch/nn/modules/module.py, line 1775, "
            "in _wrapped_call_impl, \n return self._call_impl(*args, **kwargs)",
            "File /root/work/filestorage/gh/code/MOVA-feat-npu-dai/mova/diffusion/models/wan_video_dit.py, line 242, "
            "in forward, \n v = self.v(ctx)",
            "File /root/.local/conda/envs/mova/lib/python3.11/site-packages/torch/nn/modules/module.py, line 1829, "
            "in inner, \n result = forward_call(*args, **kwargs)",
            "File /root/.local/conda/envs/mova/lib/python3.11/site-packages/torch/nn/modules/module.py, line 1881, "
            "in _call_impl, \n return inner()",
            "File /root/.local/conda/envs/mova/lib/python3.11/site-packages/torch/nn/modules/module.py, line 1775, "
            "in _wrapped_call_impl, \n return self._call_impl(*args, **kwargs)",
            "File /root/.local/conda/envs/mova/lib/python3.11/site-packages/torch/utils/checkpoint.py, line 1555, "
            "in recompute_fn, \n fn(*args, **kwargs)",
            "File /root/.local/conda/envs/mova/lib/python3.11/site-packages/torch/utils/checkpoint.py, line 1124, "
            "in _run_fn_with_dynamo_disabled, \n return fn(*args, **kwargs)",
            "File /root/.local/conda/envs/mova/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py, line 1044, "
            "in _fn, \n return fn(*args, **kwargs)",
            "File /root/.local/conda/envs/mova/lib/python3.11/site-packages/torch/_compile.py, line 53, in inner, "
            "\n return disable_fn(*args, **kwargs)",
            "File /root/.local/conda/envs/mova/lib/python3.11/site-packages/torch/utils/checkpoint.py, line 1154, "
            "in unpack_hook, \n _run_fn_with_dynamo_disabled(frame.recompute_fn, *args)"
        ]
        self.assertTrue(GraphBuilder._is_recompute_by_stack_torch(stack_list1))
        stack_list2 = [
            "File /root/.local/conda/envs/mova/lib/python3.11/site-packages/torch/nn/modules/module.py, line 1840, "
            "in inner, \n hook_result = hook(self, args, kwargs, result)",
            "File /root/.local/conda/envs/mova/lib/python3.11/site-packages/torch/nn/modules/module.py, line 1881, "
            "in _call_impl, \n return inner()",
            "File /root/.local/conda/envs/mova/lib/python3.11/site-packages/torch/nn/modules/module.py, line 1775, "
            "in _wrapped_call_impl, \n return self._call_impl(*args, **kwargs)",
            "File /root/.local/lib/python3.11/site-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py, "
            "line 599, in forward, \n x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)",
            "File /root/.local/conda/envs/mova/lib/python3.11/site-packages/torch/nn/modules/module.py, line 1829, "
            "in inner, \n result = forward_call(*args, **kwargs)",
            "File /root/.local/conda/envs/mova/lib/python3.11/site-packages/torch/nn/modules/module.py, line 1881, "
            "in _call_impl, \n return inner()",
            "File /root/.local/conda/envs/mova/lib/python3.11/site-packages/torch/nn/modules/module.py, line 1775, "
            "in _wrapped_call_impl, \n return self._call_impl(*args, **kwargs)",
            "File /root/.local/lib/python3.11/site-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py, "
            "line 1142, in _encode, \n out_ = self.encoder(",
            "File /root/.local/lib/python3.11/site-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py, "
            "line 1173, in encode, \n h = self._encode(x)",
            "File /root/.local/lib/python3.11/site-packages/diffusers/utils/accelerate_utils.py, line 46, in wrapper, "
            "\n return method(self, *args, **kwargs)",
            "File /root/work/filestorage/gh/code/MOVA-feat-npu-dai/mova/diffusion/pipelines/mova_train.py, line 1344, "
            "in training_step, \n video_latents = self.video_vae.encode(video).latent_dist.mode()",
            "File /root/work/filestorage/gh/code/MOVA-feat-npu-dai/mova/diffusion/pipelines/mova_train.py, line 1279, "
            "in forward, \n return self.training_step(*args, cp_mesh=cp_mesh, **kwargs)",
            "File /root/.local/conda/envs/mova/lib/python3.11/site-packages/torch/nn/modules/module.py, line 1829, "
            "in inner, \n result = forward_call(*args, **kwargs)",
            "File /root/.local/conda/envs/mova/lib/python3.11/site-packages/torch/nn/modules/module.py, line 1881, "
            "in _call_impl, \n return inner()",
            "File /root/.local/conda/envs/mova/lib/python3.11/site-packages/torch/nn/modules/module.py, line 1775, "
            "in _wrapped_call_impl, \n return self._call_impl(*args, **kwargs)",
            "File /root/work/filestorage/gh/code/MOVA-feat-npu-dai/mova/engine/trainer/accelerate/accelerate_trainer.py"
            ", line 414, in train, \n loss_dict = self.model(",
            "File /root/work/filestorage/gh/code/MOVA-feat-npu-dai/scripts/training_scripts/accelerate_train.py, "
            "line 180, in main, \n trainer.train()",
            "File /root/work/filestorage/gh/code/MOVA-feat-npu-dai/scripts/training_scripts/accelerate_train.py, "
            "line 184, in <module>, \n main()"
        ]
        self.assertFalse(GraphBuilder._is_recompute_by_stack_torch(stack_list2))