import os
import unittest
from typing import Any
from dataclasses import dataclass
from unittest.mock import patch
from unittest.mock import MagicMock
from msprobe.visualization.compare.graph_comparator import GraphComparator
from msprobe.visualization.graph.graph import Graph, BaseNode, NodeOp
from msprobe.visualization.utils import GraphConst
@dataclass
class Args:
input_path: str = None
output_path: str = None
layer_mapping: Any = None
framework: str = None
overflow_check: bool = False
fuzzy_match: bool = False
class TestGraphComparator(unittest.TestCase):
def setUp(self):
self.current_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
self.input = os.path.join(self.current_path, "input_format_correct")
self.output = os.path.join(self.current_path, 'output')
self.dump_path_param = {
'npu_path': os.path.join(self.input, 'step0', 'rank0', 'dump.json'),
'bench_path': os.path.join(self.input, 'step0', 'rank0', 'dump.json'),
'stack_json_path': os.path.join(self.input, 'step0', 'rank0', 'stack.json'),
'is_print_compare_log': True
}
self.graphs = [Graph("model1"), Graph("model2")]
self.output_path = "output/output.vis"
@patch('msprobe.visualization.compare.graph_comparator.get_compare_mode')
def test_compare(self, mock_get_compare_mode):
mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE
comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path), False)
comparator._compare_nodes = MagicMock()
comparator._postcompare = MagicMock()
comparator.compare()
comparator._compare_nodes.assert_called_once()
comparator._postcompare.assert_called_once()
@patch('msprobe.visualization.compare.graph_comparator.get_compare_mode')
def test_add_compare_result_to_node(self, mock_get_compare_mode):
mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE
node = MagicMock()
compare_result_list = [("output1", "data1"), ("input1", "data2")]
comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path), False)
comparator.ma = MagicMock()
comparator.ma.prepare_real_data.return_value = True
comparator.add_compare_result_to_node(node, compare_result_list)
comparator.ma.prepare_real_data.assert_called_once_with(node)
node.data.update.assert_not_called()
@patch('msprobe.visualization.graph.node_colors.NodeColors.get_node_error_status')
@patch('msprobe.visualization.compare.graph_comparator.get_csv_df')
@patch('msprobe.visualization.compare.graph_comparator.run_real_data')
@patch('msprobe.visualization.compare.graph_comparator.get_compare_mode')
def test__postcompare(self, mock_get_compare_mode, mock_run_real_data, mock_get_csv_df, mock_get_node_error_status):
mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE
mock_df = MagicMock()
mock_df.iterrows = MagicMock(return_value=[(None, MagicMock())])
mock_run_real_data.return_value = mock_df
mock_get_csv_df.return_value = mock_df
mock_get_node_error_status.return_value = True
comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path), False)
comparator.ma = MagicMock()
comparator.ma.compare_mode = GraphConst.REAL_DATA_COMPARE
comparator._handle_api_collection_index = MagicMock()
comparator.ma.compare_nodes = [MagicMock()]
comparator.ma.parse_result = MagicMock(return_value=(0.9, None))
comparator._postcompare()
comparator._handle_api_collection_index.assert_called_once()
@patch('msprobe.visualization.compare.graph_comparator.get_compare_mode')
def test__handle_api_collection_index(self, mock_get_compare_mode):
mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE
comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path), False)
apis = BaseNode(NodeOp.api_collection, 'Apis_Between_Modules.0')
api1 = BaseNode(NodeOp.function_api, 'Tensor.a.0')
api1.data = {GraphConst.JSON_INDEX_KEY: 0.9}
api2 = BaseNode(NodeOp.function_api, 'Tensor.b.0')
api2.data = {GraphConst.JSON_INDEX_KEY: 0.6}
apis.subnodes = [api1, api2]
sub_nodes = [BaseNode(NodeOp.module, 'Module.a.0'), apis, BaseNode(NodeOp.module, 'Module.a.1')]
comparator.graph_n.root.subnodes = sub_nodes
comparator._handle_api_collection_index()
self.assertEqual(comparator.graph_n.root.subnodes[1].data.get(GraphConst.JSON_INDEX_KEY), 0.9)
@patch('msprobe.visualization.builder.msprobe_adapter.compare_node')
@patch('msprobe.visualization.graph.graph.Graph.match')
@patch('msprobe.visualization.graph.graph.Graph.mapping_match')
@patch('msprobe.visualization.compare.graph_comparator.get_compare_mode')
def test__compare_nodes(self, mock_get_compare_mode, mock_mapping_match, mock_match, mock_compare_node):
node_n = BaseNode(NodeOp.function_api, 'Tensor.a.0')
node_b = BaseNode(NodeOp.function_api, 'Tensor.b.0')
mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE
mock_mapping_match.return_value = (node_b, [], [])
mock_compare_node.return_value = ['result']
comparator = GraphComparator(self.graphs, self.dump_path_param,
Args(output_path=self.output_path, layer_mapping=True), True)
comparator.mapping_dict = True
comparator._compare_nodes(node_n)
self.assertEqual(node_n.matched_node_link, ['Tensor.b.0'])
self.assertEqual(node_b.matched_node_link, ['Tensor.a.0'])
comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path), False)
comparator.mapping_dict = False
node_n = BaseNode(NodeOp.function_api, 'Tensor.a.0')
node_b = BaseNode(NodeOp.function_api, 'Tensor.a.0')
mock_match.return_value = (node_b, [])
comparator._compare_nodes(node_n)
self.assertEqual(node_n.matched_node_link, ['Tensor.a.0'])
self.assertEqual(node_b.matched_node_link, ['Tensor.a.0'])
@patch('msprobe.visualization.builder.msprobe_adapter.compare_node')
@patch('msprobe.visualization.graph.graph.Graph.match')
@patch('msprobe.visualization.graph.graph.Graph.fuzzy_match')
@patch('msprobe.visualization.compare.graph_comparator.get_compare_mode')
def test_compare_nodes_fuzzy(self, mock_get_compare_mode, mock_fuzzy_match, mock_match, mock_compare_node):
node_n = BaseNode(NodeOp.function_api, 'Tensor.a.0')
node_b = BaseNode(NodeOp.function_api, 'Tensor.b.0')
node_module_n = BaseNode(NodeOp.module, 'Module.a.0')
node_module_n.subnodes = [node_n]
node_n.upnode = node_module_n
node_module_b = BaseNode(NodeOp.module, 'Module.b.0')
node_module_b.subnodes = [node_b]
node_b.upnode = node_module_b
self.graphs[0].node_map[node_n.id] = node_n
self.graphs[1].node_map[node_b.id] = node_b
self.graphs[0].node_map[node_module_n.id] = node_module_n
self.graphs[1].node_map[node_module_b.id] = node_module_b
mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE
mock_fuzzy_match.return_value = (node_module_b, [], [])
mock_compare_node.return_value = ['result']
comparator = GraphComparator(self.graphs, self.dump_path_param,
Args(output_path=self.output_path, layer_mapping=True), True)
comparator.mapping_dict = True
comparator._compare_nodes_fuzzy(node_module_n)
self.assertEqual(node_n.matched_node_link, ['Module.b.0', 'Module.b.0'])
self.assertEqual(node_b.matched_node_link, [])
def test_add_compare_result_node(self):
compare_result_list = [
['Tensor.__truediv__.139.backward.input.0', 'Tensor.__truediv__.139.backward.input.0', 'torch.float32',
'torch.float32', [], [], 'False', 'False', 0.0, 0.0, 0.0, 0.0, '0.0%', '0.0%', '0.0%', '0.0%',
0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, True, '', '', 'None'],
['Tensor.__truediv__.139.backward.output.0', 'Tensor.__truediv__.139.backward.output.0', 'torch.float32',
'torch.float32', [], [], 'False', 'False', 0.0, 0.0, 0.0, 0.0, '0.0%', '0.0%', '0.0%', '0.0%',
0.25, 0.000244140625, 0.000244140625, 0.000244140625, 0.000244140625, 0.000244140625,
0.000244140625, 0.000244140625, True, '', '', 'None']
]
node = BaseNode(NodeOp.module, 'Module.module.Float16Module.forward.0')
comparator = GraphComparator(self.graphs, self.dump_path_param, Args(output_path=self.output_path), False)
comparator.add_compare_result_to_node(node, compare_result_list)
self.assertEqual(node.data, {'precision_index': 0})