import argparse
import os
import re
import shutil
import signal
import unittest
from dataclasses import dataclass
from typing import Optional, Sequence
from unittest.mock import patch

from msprobe.visualization.graph_service import (
    _compare_graph_result,
    _build_graph_result,
    _compare_graph_ranks,
    _compare_graph_steps,
    _build_graph_ranks,
    _build_graph_steps,
    _graph_service_command,
    _graph_service_parser,
)
from msprobe.core.common.utils import CompareException


@dataclass
class Args:
    """Stand-in for the parsed-arguments object handed to the graph service functions.

    Every field has a default so each test only sets what it needs. The attribute
    names presumably mirror the options registered by ``_graph_service_parser``
    (``test_graph_service_parser`` checks target_path, output_path, layer_mapping
    and overflow_check); the remaining fields are not verified here.
    """

    target_path: Optional[str] = None
    golden_path: Optional[str] = None
    output_path: Optional[str] = None
    layer_mapping: Optional[str] = None
    overflow_check: bool = False
    fuzzy_match: bool = False
    is_print_compare_log: bool = True
    parallel_merge: bool = False
    parallel_params: Optional[tuple] = None
    is_print_progress_log: bool = False
    file_type: str = 'db'

    # Tests pass lists here, e.g. rank_size=[2, 2], tp=[2, 2], pp=[1, 1].
    rank_size: Optional[Sequence[int]] = None
    tp: Optional[Sequence[int]] = None
    pp: Optional[Sequence[int]] = None
    # Tuple defaults because @dataclass rejects mutable (list) defaults.
    vpp: Sequence[int] = (1,)
    order: Sequence[str] = ('tp-cp-ep-dp-pp',)


class TestGraphService(unittest.TestCase):
    """Tests for ``msprobe.visualization.graph_service`` driven by the on-disk
    fixtures under ``input_format_correct`` next to this file."""

    def setUp(self):
        """Resolve fixture/output paths and ignore SIGPIPE while the test runs."""
        self.current_path = os.path.dirname(os.path.realpath(__file__))
        self.input = os.path.join(self.current_path, "input_format_correct")
        self.output = os.path.join(self.current_path, 'output')
        self.input_param = {
            'npu_path': os.path.join(self.input, 'step0', 'rank0'),
            'bench_path': os.path.join(self.input, 'step0', 'rank0'),
            'is_print_compare_log': True,
        }
        self.layer_mapping = os.path.join(self.current_path, 'layer_mapping.yaml')
        # Both dots escaped so only literal "<name>.vis.db" file names match.
        self.pattern = r'\b\w+\.vis\.db\b'
        self.pattern_rank = r'\w+\.vis\.db\b'
        # Possible stray compare{i}.json files next to this test; tearDown deletes any that exist.
        self.output_json = [os.path.join(self.current_path, f"compare{i}.json") for i in range(7)]
        # SIGPIPE is not defined on every platform (e.g. Windows), so only touch it when present.
        # NOTE(review): presumably ignored so a closed pipe in the code under test cannot
        # terminate the test process — confirm.
        self.original_sigpipe = None
        if hasattr(signal, 'SIGPIPE'):
            self.original_sigpipe = signal.getsignal(signal.SIGPIPE)
            signal.signal(signal.SIGPIPE, signal.SIG_IGN)

    def assert_log_info(
        self, mock_log_info, log_info='Model graphs compared successfully, the result file is saved in'
    ):
        """Assert the most recent ``logger.info`` message contains ``log_info``.

        If that message names a ``*.vis.db`` file, also assert that the file exists
        under ``self.output``.

        Args:
            mock_log_info: the mock patched over ``logger.info``.
            log_info: substring expected in the last logged message.
        """
        # Fail with a readable message instead of a TypeError on None[0] when nothing was logged.
        self.assertIsNotNone(mock_log_info.call_args, 'logger.info was never called')
        last_message = mock_log_info.call_args[0][0]
        self.assertIn(log_info, last_message)
        matches = re.findall(self.pattern, last_message)
        if matches:
            self.assertTrue(os.path.exists(os.path.join(self.output, matches[0])))

    @patch('msprobe.core.common.log.logger.info')
    def test_compare_graph_result(self, mock_log_info):
        """Compare step0/rank0 against itself with default, layer-mapping and overflow-check args."""
        args = Args(output_path=self.output)
        result = _compare_graph_result(self.input_param, args)
        self.assertEqual(mock_log_info.call_count, 2)
        self.assertIsNotNone(result)

        # NOTE(review): identical to the call above — possibly meant to exercise another
        # option (e.g. fuzzy_match, which no test sets); confirm intent.
        args = Args(output_path=self.output)
        result = _compare_graph_result(self.input_param, args)
        self.assertIsNotNone(result)

        args = Args(output_path=self.output, layer_mapping=self.layer_mapping)
        result = _compare_graph_result(self.input_param, args)
        self.assertIsNotNone(result)

        args = Args(output_path=self.output, overflow_check=True)
        result = _compare_graph_result(self.input_param, args)
        self.assertIsNotNone(result)

    @patch('msprobe.core.common.log.logger.info')
    def test_build_graph_result(self, mock_log_info):
        """Build a graph for step0/rank0 with overflow_check enabled; expects one info log."""
        result = _build_graph_result(os.path.join(self.input, 'step0', 'rank0'), Args(overflow_check=True))
        self.assertEqual(mock_log_info.call_count, 1)
        self.assertIsNotNone(result)

    @patch('msprobe.core.common.log.logger.info')
    def test_compare_graph_ranks(self, mock_log_info):
        """Compare the step0 directory against itself via _compare_graph_ranks."""
        input_param = {
            'npu_path': os.path.join(self.input, 'step0'),
            'bench_path': os.path.join(self.input, 'step0'),
            'is_print_compare_log': True,
        }
        args = Args(output_path=self.output)
        _compare_graph_ranks(input_param, args)
        self.assert_log_info(mock_log_info, 'Successfully exported compare graph results.')

    @patch('msprobe.core.common.log.logger.info')
    def test_compare_graph_steps(self, mock_log_info):
        """Compare the fixture root against itself, then against the 'input' dir (expects CompareException)."""
        input_param = {'npu_path': self.input, 'bench_path': self.input, 'is_print_compare_log': True}
        args = Args(output_path=self.output)
        _compare_graph_steps(input_param, args)
        self.assert_log_info(mock_log_info, 'Successfully exported compare graph results.')

        input_param1 = {
            'npu_path': self.input,
            'bench_path': os.path.join(self.current_path, "input"),
            'is_print_compare_log': True,
        }
        args = Args(output_path=self.output)
        with self.assertRaises(CompareException):
            _compare_graph_steps(input_param1, args)

    @patch('msprobe.core.common.log.logger.info')
    def test_build_graph_ranks(self, mock_log_info):
        """Build graphs for the ranks of step0 under the fixture root."""
        _build_graph_ranks(Args(target_path=self.input, output_path=self.output), 'step0')
        self.assert_log_info(mock_log_info, "Successfully exported build graph results.")

    @patch('msprobe.core.common.log.logger.info')
    def test_build_graph_steps(self, mock_log_info):
        """Build graphs for every step under the fixture root."""
        _build_graph_steps(Args(target_path=self.input, output_path=self.output))
        self.assert_log_info(mock_log_info, "Successfully exported build graph results.")

    @patch('msprobe.core.common.log.logger.info')
    def test_graph_service_command(self, mock_log_info):
        """_graph_service_command with target and golden both at step0/rank0."""
        args = Args(
            target_path=self.input_param.get('npu_path'),
            golden_path=self.input_param.get('bench_path'),
            output_path=self.output,
        )
        _graph_service_command(args)
        self.assert_log_info(mock_log_info, 'Adding index to db file completed.')

    @patch('msprobe.core.common.log.logger.info')
    def test_graph_service_command1(self, mock_log_info):
        """_graph_service_command with only a target at step0/rank0 (no golden)."""
        args = Args(target_path=os.path.join(self.input, 'step0', 'rank0'), output_path=self.output)
        _graph_service_command(args)
        self.assert_log_info(mock_log_info, "Adding index to db file completed.")

    @patch('msprobe.core.common.log.logger.info')
    def test_graph_service_command2(self, mock_log_info):
        """_graph_service_command with target and golden both at step0."""
        args = Args(
            target_path=os.path.join(self.input, 'step0'),
            golden_path=os.path.join(self.input, 'step0'),
            output_path=self.output,
        )
        _graph_service_command(args)
        self.assert_log_info(mock_log_info, 'Adding index to db file completed.')

    @patch('msprobe.core.common.log.logger.info')
    def test_graph_service_command3(self, mock_log_info):
        """_graph_service_command with target and golden both at the fixture root."""
        args = Args(target_path=self.input, golden_path=self.input, output_path=self.output)
        _graph_service_command(args)
        self.assert_log_info(mock_log_info, 'Adding index to db file completed.')

    @patch('msprobe.core.common.log.logger.info')
    def test_graph_service_command4(self, mock_log_info):
        """_graph_service_command with only a target at step0 (no golden)."""
        args = Args(target_path=os.path.join(self.input, 'step0'), output_path=self.output)
        _graph_service_command(args)
        self.assert_log_info(mock_log_info, "Adding index to db file completed.")

    @patch('msprobe.core.common.log.logger.info')
    def test_graph_service_command5(self, mock_log_info):
        """_graph_service_command with only a target at the fixture root (no golden)."""
        args = Args(target_path=self.input, output_path=self.output)
        _graph_service_command(args)
        self.assert_log_info(mock_log_info, "Adding index to db file completed.")

    @patch('msprobe.core.common.log.logger.info')
    def test_graph_service_command6(self, mock_log_info):
        """Target at the fixture root but golden at step0; expects ValueError."""
        args = Args(target_path=self.input, golden_path=os.path.join(self.input, 'step0'), output_path=self.output)
        with self.assertRaises(ValueError):
            _graph_service_command(args)

    @patch('msprobe.core.common.log.logger.info')
    def test_graph_service_command7(self, mock_log_info):
        """_graph_service_command on step0 vs step0 with rank_size/tp/pp provided."""
        args = Args(
            target_path=os.path.join(self.input, 'step0'),
            golden_path=os.path.join(self.input, 'step0'),
            output_path=self.output,
            rank_size=[2, 2],
            tp=[2, 2],
            pp=[1, 1],
        )
        _graph_service_command(args)
        self.assert_log_info(mock_log_info, 'Adding index to db file completed.')

    @patch('msprobe.core.common.log.logger.info')
    def test_graph_service_command8(self, mock_log_info):
        """_graph_service_command with target at step0 and golden at step1."""
        args = Args(
            target_path=os.path.join(self.input, 'step0'),
            golden_path=os.path.join(self.input, 'step1'),
            output_path=self.output,
        )
        _graph_service_command(args)
        self.assert_log_info(mock_log_info, 'Adding index to db file completed.')

    def test_graph_service_parser(self):
        """_graph_service_parser maps -tp/-o/-lm/-oc onto the expected attributes."""
        parser = argparse.ArgumentParser()
        _graph_service_parser(parser)
        args = parser.parse_args(['-tp', 'input.json', '-o', 'output.json'])
        self.assertEqual(args.target_path, 'input.json')
        self.assertEqual(args.output_path, 'output.json')

        args = parser.parse_args(['-tp', 'input.json', '-o', 'output.json', '-lm', 'mapping.json'])
        self.assertEqual(args.layer_mapping, 'mapping.json')

        args = parser.parse_args(['-tp', 'input.json', '-o', 'output.json', '-oc'])
        self.assertTrue(args.overflow_check)

        args = parser.parse_args(['-tp', 'input.json', '-o', 'output.json'])
        self.assertFalse(args.overflow_check)

    def tearDown(self):
        """Restore the saved SIGPIPE handler and remove the output dir and stray compare JSONs."""
        # None when SIGPIPE is unavailable on this platform (nothing was changed in setUp).
        if self.original_sigpipe is not None:
            signal.signal(signal.SIGPIPE, self.original_sigpipe)
        if os.path.exists(self.output):
            shutil.rmtree(self.output)
        for json_path in self.output_json:
            if os.path.exists(json_path):
                os.remove(json_path)


if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()