import unittest
from unittest.mock import MagicMock
from msprobe.visualization.graph.graph import Graph, NodeOp
from msprobe.visualization.graph.base_node import BaseNode
from msprobe.visualization.utils import GraphConst
class TestGraph(unittest.TestCase):
    """Unit tests for Graph node registration, matching and micro-step paging."""

    def setUp(self):
        self.graph = Graph("model_name")
        self.node_id = "node_id"
        self.node_op = NodeOp.module

    @staticmethod
    def _build_micro_step_nodes():
        """Return a fresh list of top-level nodes that splits into two micro steps."""
        return [
            BaseNode(NodeOp.module, 'a.forward.0'), BaseNode(NodeOp.module, 'a.backward.0'),
            BaseNode(NodeOp.api_collection, 'apis.0'), BaseNode(NodeOp.module, 'a.forward.1'),
            BaseNode(NodeOp.module, 'b.forward.0'), BaseNode(NodeOp.module, 'b.backward.0'),
            BaseNode(NodeOp.module, 'a.backward.1'), BaseNode(NodeOp.api_collection, 'apis.1'),
        ]

    def test_add_node_and_get_node(self):
        """add_node registers a node retrievable by id; id_accumulation suffixes repeats."""
        self.graph.add_node(self.node_op, self.node_id)
        node = self.graph.get_node(self.node_id)
        self.assertIsNotNone(node)
        self.assertIn(self.node_id, self.graph.node_map)

        node_id = "api"
        graph = Graph("model_name")
        for _ in range(9):
            graph.add_node(NodeOp.function_api, node_id, id_accumulation=True)
        # 9 accumulated "api.N" entries plus one pre-existing entry
        # (presumably the root node registered by the Graph constructor).
        self.assertEqual(len(graph.node_map), 10)
        self.assertIn("api.0", graph.node_map)
        self.assertIn("api.8", graph.node_map)
        # With id_accumulation the bare id is never stored.
        self.assertNotIn("api", graph.node_map)

    def test_str(self):
        """str(graph) mentions the ids of nodes that were added."""
        self.graph.add_node(self.node_op, self.node_id)
        self.assertIn(self.node_id, str(self.graph))

    def test_match(self):
        """match finds nothing for an unknown id and matches nodes with equal ancestry."""
        graph_a = Graph("model_name_a")
        graph_b = Graph("model_name_b")
        node_a = BaseNode(self.node_op, self.node_id)
        graph_a.add_node(NodeOp.module, "node_id_a")
        graph_b.add_node(NodeOp.module, "node_id_b")
        # node_a's id is absent from graph_b, so no match is expected.
        matched_node, ancestors = Graph.match(graph_a, node_a, graph_b)
        self.assertIsNone(matched_node)
        self.assertEqual(ancestors, [])

        graph_b.add_node(NodeOp.module, "node_id_a")
        graph_a.add_node(NodeOp.module, "node_id_a_1", graph_a.get_node("node_id_a"))
        # Attach to graph_b's own parent so each graph's hierarchy is built independently.
        graph_b.add_node(NodeOp.module, "node_id_a_1", graph_b.get_node("node_id_a"))
        matched_node, ancestors = Graph.match(graph_a, graph_a.get_node("node_id_a_1"), graph_b)
        self.assertIsNotNone(matched_node)
        self.assertEqual(ancestors, ['node_id_a'])

    def test_split_nodes_by_micro_step(self):
        """split_nodes_by_micro_step groups top-level nodes into per-micro-step batches."""
        nodes = self._build_micro_step_nodes()
        result = Graph.split_nodes_by_micro_step(nodes)
        self.assertEqual(len(result), 2)
        self.assertEqual(len(result[0]), 3)

    def test_paging_by_micro_step(self):
        """paging_by_micro_step returns the batch count and tags nodes with micro_step_id."""
        nodes = self._build_micro_step_nodes()
        graph = Graph('Model1')
        graph.root.subnodes = nodes
        graph_other = Graph('Model2')
        # NOTE(review): both roots share the same list object, so the graph_other
        # assertion below also observes updates made through graph; confirm intended.
        graph_other.root.subnodes = nodes
        result = graph.paging_by_micro_step(graph_other)
        self.assertEqual(result, 2)
        self.assertEqual(graph.root.subnodes[0].micro_step_id, 0)
        self.assertEqual(graph_other.root.subnodes[0].micro_step_id, 0)

    def test_mapping_match(self):
        """mapping_match resolves a node via mapping_dict and reports both ancestor chains."""
        graph_a = Graph("model_name_a")
        graph_b = Graph("model_name_b")
        graph_a.add_node(NodeOp.module, "a1", BaseNode(NodeOp.module, "root"))
        graph_b.add_node(NodeOp.module, "b1", BaseNode(NodeOp.module, "root"))
        mapping_dict = {"a1": "b1"}
        node_b, ancestors_n, ancestors_b = Graph.mapping_match(graph_a.get_node("a1"), graph_b, mapping_dict)
        self.assertIsNotNone(node_b)
        self.assertEqual(ancestors_n, ["root"])
        self.assertEqual(ancestors_b, ["root"])