import unittest
from unittest.mock import patch
from msprobe.visualization.graph.distributed_analyzer import (
CommunicationType, DistributedType, CANNOT_MATCH, DistributedAnalyzer
)
from msprobe.visualization.utils import GraphConst, Const
from msprobe.core.common.log import logger
from msprobe.visualization.graph.graph import Graph, BaseNode, NodeOp
class TestDistributedAnalyzer(unittest.TestCase):
    """Unit tests for ``DistributedAnalyzer`` and the communication enums it uses."""

    @staticmethod
    def _create_base_node(node_id, up_node=None, op=NodeOp.module):
        """Factory for ``BaseNode`` instances used as fixtures.

        The node gets a JSON index of 100, an overflow level of 0.5 and
        ``matched_distributed`` reset to ``None``.
        """
        node = BaseNode(op, node_id, up_node)
        node.data = {
            GraphConst.JSON_INDEX_KEY: 100,
            GraphConst.OVERFLOW_LEVEL: 0.5,
        }
        node.matched_distributed = None
        return node

    @staticmethod
    def _create_graph(rank, nodes=None):
        """Factory for ``Graph`` instances.

        Args:
            rank: rank assigned to ``graph.rank``; also used in the model name.
            nodes: optional mapping of node id -> node copied into ``graph.node_map``.
        """
        graph = Graph(model_name=f"model_rank_{rank}")
        graph.rank = rank
        if nodes:
            for node_id, node in nodes.items():
                graph.node_map[node_id] = node
        return graph

    @staticmethod
    def _make_output_stats():
        """Return a NEW output-statistics dict (float32, shape [1, 2], max/min/mean/norm).

        A fresh dict is built on every call so that no two fixture nodes share
        (and could accidentally mutate) the same statistics object.
        """
        return {
            Const.DTYPE: "float32",
            Const.SHAPE: [1, 2],
            Const.MAX: 1.0,
            Const.MIN: 0.0,
            Const.MEAN: 0.5,
            Const.NORM: 1.0,
        }

    @staticmethod
    def _make_group_kwarg(group_ranks, group_id):
        """Return a NEW ``group`` kwarg payload with the given ranks and id."""
        return {"group_ranks": group_ranks, "group_id": group_id}

    def setUp(self):
        """Build fresh fixtures before every test method."""
        # Exposed as attributes because individual tests create extra nodes with them.
        self.base_node_factory = self._create_base_node
        self.graph_factory = self._create_graph
        self.distributed_info = self._build_test_distributed_info()
        self.analyzer = DistributedAnalyzer(self.distributed_info, overflow_check=False)
        self.analyzer_overflow = DistributedAnalyzer(self.distributed_info, overflow_check=True)

    def _build_test_distributed_info(self):
        """Build the per-rank node dictionaries fed to ``DistributedAnalyzer``.

        Rank "0" holds isend (dst=1), recv (src=1), broadcast, batch_p2p and a
        plain module node; rank "1" holds irecv (src=0), send (dst=0) and a
        broadcast node with the same id and group as rank 0's.
        """
        make_node = self._create_base_node

        isend_node = make_node("Distributed.isend.0.forward")
        isend_node.input_data = {f"{isend_node.id}{GraphConst.INPUT}dst": {"value": 1}}
        isend_node.output_data = {f"{isend_node.id}.output.0": self._make_output_stats()}

        recv_node = make_node("Distributed.recv.0.forward")
        recv_node.input_data = {f"{recv_node.id}{GraphConst.INPUT}src": {"value": 1}}
        recv_node.output_data = {f"{recv_node.id}.output.0": self._make_output_stats()}

        broadcast_node = make_node("Distributed.broadcast.0.forward")
        broadcast_node.input_data = {
            f"{broadcast_node.id}{GraphConst.INPUT}1": {"value": 0},
            f"{broadcast_node.id}{GraphConst.INPUT}group": self._make_group_kwarg([0, 1], "group_0"),
        }

        batch_p2p_node = make_node("Distributed.batch_p2p.0.forward")
        batch_p2p_node.batch_p2p_info = [
            {GraphConst.OP: "isend", GraphConst.PEER: 1, GraphConst.GROUP_ID: "g1"},
            {GraphConst.OP: "recv", GraphConst.PEER: 2, GraphConst.GROUP_ID: "g2"},
        ]

        normal_node = make_node("Module.forward.0", op=NodeOp.module)

        irecv_node = make_node("Distributed.irecv.0.forward")
        irecv_node.input_data = {f"{irecv_node.id}{GraphConst.INPUT}src": {"value": 0}}
        irecv_node.output_data = {f"{irecv_node.id}.output.0": self._make_output_stats()}

        send_node = make_node("Distributed.send.0.forward")
        send_node.input_data = {f"{send_node.id}{GraphConst.INPUT}dst": {"value": 0}}

        broadcast_node_1 = make_node("Distributed.broadcast.0.forward")
        broadcast_node_1.input_data = {
            f"{broadcast_node_1.id}{GraphConst.INPUT}1": {"value": 0},
            f"{broadcast_node_1.id}{GraphConst.INPUT}group": self._make_group_kwarg([0, 1], "group_0"),
        }

        return {
            "0": {
                isend_node.id: isend_node,
                recv_node.id: recv_node,
                broadcast_node.id: broadcast_node,
                batch_p2p_node.id: batch_p2p_node,
                normal_node.id: normal_node,
            },
            "1": {
                irecv_node.id: irecv_node,
                send_node.id: send_node,
                broadcast_node_1.id: broadcast_node_1,
            },
        }

    def test_enum_values(self):
        """The enum members carry the expected string values."""
        self.assertEqual(CommunicationType.SEND.value, "send")
        self.assertEqual(CommunicationType.RECEIVE.value, "receive")
        self.assertEqual(CommunicationType.SEND_RECEIVE.value, "send_receive")
        self.assertEqual(DistributedType.P2P.value, "p2p")
        self.assertEqual(DistributedType.COLLECTIVE.value, "collective")

    def test_get_opposite_communication_type(self):
        """send/receive swap; send_receive and unknown values map to themselves."""
        self.assertEqual(DistributedAnalyzer._get_opposite_communication_type("send"), "receive")
        self.assertEqual(DistributedAnalyzer._get_opposite_communication_type("receive"), "send")
        self.assertEqual(DistributedAnalyzer._get_opposite_communication_type("send_receive"), "send_receive")
        self.assertEqual(DistributedAnalyzer._get_opposite_communication_type("unknown"), "unknown")

    def test__node_output_all_equal(self):
        """Output statistics compare equal only when every field matches and both are non-empty."""
        data1 = self._make_output_stats()
        data2 = data1.copy()
        self.assertTrue(DistributedAnalyzer._node_output_all_equal(data1, data2))

        data3 = data1.copy()
        data3[Const.MAX] = 2.0
        self.assertFalse(DistributedAnalyzer._node_output_all_equal(data1, data3))

        self.assertFalse(DistributedAnalyzer._node_output_all_equal(data1, {}))
        self.assertFalse(DistributedAnalyzer._node_output_all_equal({}, data2))

    def test_get_target_rank(self):
        """The target rank is read from the named input; a missing input logs and returns None."""
        node = self.base_node_factory("test_node")
        node_id = node.id
        node.input_data = {f"{node_id}{GraphConst.INPUT}dst": {"value": 1}}
        self.assertEqual(DistributedAnalyzer._get_target_rank(node, 0, "dst"), 1)

        node.input_data = {}
        with patch.object(logger, "debug") as mock_log:
            self.assertIsNone(DistributedAnalyzer._get_target_rank(node, 0, "dst"))
        mock_log.assert_called_with(
            f'The parameter dst of node {node_id} does not exist, {CANNOT_MATCH}0'
        )

    def test_get_group_info(self):
        """Group ranks and id are read from the ``group`` kwarg; each missing piece is logged."""
        node = self.base_node_factory("test_node")
        node_id = node.id
        group_key = f"{node_id}{GraphConst.INPUT}group"

        node.input_data = {group_key: self._make_group_kwarg([0, 1], "g1")}
        self.assertEqual(DistributedAnalyzer._get_group_info(node, 0), ([0, 1], "g1"))

        # No group kwarg at all.
        node.input_data = {}
        with patch.object(logger, "debug") as mock_log:
            self.assertEqual(DistributedAnalyzer._get_group_info(node, 0), (None, None))
        mock_log.assert_called_with(
            f'The kwarg group of node {node_id} does not exist, {CANNOT_MATCH}0'
        )

        # Group kwarg without group_ranks.
        node.input_data = {group_key: {"group_id": "g1"}}
        with patch.object(logger, "debug") as mock_log:
            self.assertEqual(DistributedAnalyzer._get_group_info(node, 0), (None, None))
        mock_log.assert_called_with(
            f'The group_ranks of node {node_id} does not exist, {CANNOT_MATCH}0'
        )

        # Group kwarg without group_id.
        node.input_data = {group_key: {"group_ranks": [0, 1]}}
        with patch.object(logger, "debug") as mock_log:
            self.assertEqual(DistributedAnalyzer._get_group_info(node, 0), (None, None))
        mock_log.assert_called_with(
            f'The group_id of node {node_id} does not exist, {CANNOT_MATCH}0'
        )

    def test_init(self):
        """The constructor stores the overflow flag, loads the config and builds the group mapping."""
        self.assertEqual(self.analyzer.overflow_check, False)
        self.assertIn("send", self.analyzer.config)
        self.assertIn("broadcast", self.analyzer.config)
        self.assertNotEqual(self.analyzer.group_node_mapping, {})
        self.assertEqual(self.analyzer_overflow.overflow_check, True)

    def test_make_group_node_mapping(self):
        """Rebuilding the mapping links node ids and unique group ids in both directions."""
        self.analyzer.group_node_mapping = {}
        self.analyzer._make_group_node_mapping()
        rank0_mapping = self.analyzer.group_node_mapping.get("0")
        self.assertIsNotNone(rank0_mapping)

        isend_node_id = "Distributed.isend.0.forward"
        unique_group_id = rank0_mapping.get(isend_node_id)
        self.assertIsNotNone(unique_group_id)
        self.assertIn("isendrank1", unique_group_id)
        self.assertEqual(rank0_mapping.get(unique_group_id), isend_node_id)

        broadcast_node_id = "Distributed.broadcast.0.forward"
        broadcast_unique_id = rank0_mapping.get(broadcast_node_id)
        self.assertIsNotNone(broadcast_unique_id)
        self.assertIn("group_0broadcast", broadcast_unique_id)

    def test_make_batch_p2p_mapping(self):
        """Each batch_p2p sub-op gets a unique id; entries lacking op or peer are logged."""
        batch_p2p_node = self.distributed_info["0"]["Distributed.batch_p2p.0.forward"]
        batch_p2p_count = {}
        self.analyzer.group_node_mapping = {0: {}}
        self.analyzer._make_batch_p2p_mapping(batch_p2p_node, 0, batch_p2p_count)

        rank0_mapping = self.analyzer.group_node_mapping[0]
        self.assertIn("isend_rank1_g1_1", rank0_mapping)
        self.assertIn("recv_rank2_g2_1", rank0_mapping)
        self.assertEqual(rank0_mapping[batch_p2p_node.id], ["isend_rank1_g1_1", "recv_rank2_g2_1"])
        self.assertEqual(rank0_mapping["isend_rank1_g1_1"], batch_p2p_node.id)

        invalid_batch_node = self.base_node_factory("Distributed.batch_p2p.1.forward")
        invalid_batch_node.batch_p2p_info = [
            {GraphConst.OP: None, GraphConst.PEER: 1},
            {GraphConst.OP: "recv", GraphConst.PEER: None},
        ]
        with patch.object(logger, "debug") as mock_log:
            self.analyzer._make_batch_p2p_mapping(invalid_batch_node, 0, batch_p2p_count)
        mock_log.assert_called_with('Cannot get param op or peer.')

    def test_get_distributed_name_and_type(self):
        """Node ids are parsed into (api name, DistributedType); malformed ids raise ValueError."""
        api_name, dist_type = self.analyzer._get_distributed_name_and_type("Distributed.isend.0.forward")
        self.assertEqual(api_name, "isend")
        self.assertEqual(dist_type, DistributedType.P2P)

        api_name, dist_type = self.analyzer._get_distributed_name_and_type("Distributed.broadcast.0.forward")
        self.assertEqual(api_name, "broadcast")
        self.assertEqual(dist_type, DistributedType.COLLECTIVE)

        # APIs absent from the config fall back to COLLECTIVE.
        api_name, dist_type = self.analyzer._get_distributed_name_and_type("Distributed.unknown_api.0.forward")
        self.assertEqual(api_name, "unknown_api")
        self.assertEqual(dist_type, DistributedType.COLLECTIVE)

        with self.assertRaises(ValueError) as excinfo:
            self.analyzer._get_distributed_name_and_type("invalid_node_id")
        self.assertIn("Invalid node id invalid_node_id.", str(excinfo.exception))

    def test_get_target_node(self):
        """The peer node is resolved on the target rank; missing rank data or node ids are logged."""
        rank0_isend_unique_id = self.analyzer.group_node_mapping["0"]["Distributed.isend.0.forward"]
        target_node = self.analyzer._get_target_node(
            rank=0,
            unique_group_id=rank0_isend_unique_id,
            api_name="isend",
            target_rank=1,
            target_api_name="irecv",
        )
        self.assertIsNotNone(target_node)
        self.assertEqual(target_node.id, "Distributed.irecv.0.forward")

        # Target rank with no data at all.
        with patch.object(logger, "debug") as mock_log:
            target_node = self.analyzer._get_target_node(
                rank=0,
                unique_group_id="test",
                api_name="isend",
                target_rank=999,
                target_api_name="irecv",
            )
        self.assertIsNone(target_node)
        mock_log.assert_called_with(f'Node data does not exist, {CANNOT_MATCH}999')

        # Existing rank, but no node matching the unique group id.
        with patch.object(logger, "debug") as mock_log:
            target_node = self.analyzer._get_target_node(
                rank=0,
                unique_group_id="invalid_id",
                api_name="isend",
                target_rank=1,
                target_api_name="irecv",
            )
        self.assertIsNone(target_node)
        mock_log.assert_called_with(f'Node does not exist, {CANNOT_MATCH}1')

    def test_add_node_matched_distributed(self):
        """Match info records the communication type and the peer's JSON index and id."""
        source_node = self.base_node_factory("source_node")
        target_node = self.base_node_factory("target_node")
        target_node.data = {GraphConst.JSON_INDEX_KEY: 200}

        self.analyzer._add_node_matched_distributed(source_node, target_node, "isend", 1)
        self.assertEqual(source_node.matched_distributed, {
            "communications_type": "send",
            "nodes_info": {1: ["200", "target_node"]},
        })

        # reversal_type=True flips the communication type.
        self.analyzer._add_node_matched_distributed(source_node, target_node, "broadcast", 1, reversal_type=True)
        self.assertEqual(source_node.matched_distributed["communications_type"], "receive")

    def test_p2p_match(self):
        """isend on rank 0 matches irecv on rank 1; bad peers are skipped or logged."""
        isend_node = self.distributed_info["0"]["Distributed.isend.0.forward"]
        irecv_node = self.distributed_info["1"]["Distributed.irecv.0.forward"]
        dst_key = f"{isend_node.id}{GraphConst.INPUT}dst"
        src_key = f"{irecv_node.id}{GraphConst.INPUT}src"

        self.analyzer._p2p_match(isend_node, 0, "isend")
        self.assertIsNotNone(isend_node.matched_distributed)
        self.assertIsNotNone(irecv_node.matched_distributed)

        # Point dst at a rank with no data: the node must stay unmatched.
        isend_node.matched_distributed = {}
        isend_node.input_data[dst_key]["value"] = 999
        self.analyzer._p2p_match(isend_node, 0, "isend")
        self.assertEqual(isend_node.matched_distributed, {})

        # Restore dst but make the peer's src disagree with the sender's rank.
        isend_node.input_data[dst_key]["value"] = 1
        irecv_node.input_data[src_key]["value"] = 2
        with patch.object(logger, "debug") as mock_log:
            self.analyzer._p2p_match(isend_node, 0, "isend")
        # Fail clearly (instead of a TypeError on call_args=None) if nothing was logged.
        mock_log.assert_called()
        self.assertIn("rank is inconsistent", mock_log.call_args[0][0])

    def test_collective_match(self):
        """broadcast nodes on both ranks match; mismatched inputs or groups prevent matching."""
        broadcast_node = self.distributed_info["0"]["Distributed.broadcast.0.forward"]
        arg1_key = f"{broadcast_node.id}{GraphConst.INPUT}1"
        group_key = f"{broadcast_node.id}{GraphConst.INPUT}group"

        self.analyzer._collective_match(broadcast_node, 0, "broadcast")
        self.assertIsNotNone(broadcast_node.matched_distributed)
        self.assertIsNotNone(self.distributed_info["1"]["Distributed.broadcast.0.forward"].matched_distributed)

        # Change positional input 1 (presumably the broadcast src rank): no match expected.
        broadcast_node.matched_distributed = {}
        broadcast_node.input_data[arg1_key]["value"] = 1
        self.analyzer._collective_match(broadcast_node, 0, "broadcast")
        self.assertEqual(broadcast_node.matched_distributed, {})

        # Restore input 1 but drop the group kwarg: still no match.
        broadcast_node.input_data[arg1_key]["value"] = 0
        broadcast_node.input_data.pop(group_key)
        self.analyzer._collective_match(broadcast_node, 0, "broadcast")
        self.assertEqual(broadcast_node.matched_distributed, {})

        # Same ranks but a different group id than the peer: the mismatch is logged.
        broadcast_node.input_data[group_key] = self._make_group_kwarg([0, 1], "g2")
        with patch.object(logger, "debug") as mock_log:
            self.analyzer._collective_match(broadcast_node, 0, "broadcast")
        # Fail clearly (instead of a TypeError on call_args=None) if nothing was logged.
        mock_log.assert_called()
        self.assertIn("group id of the two nodes are different", mock_log.call_args[0][0])

    def test_edge_cases(self):
        """Smoke test: ``distributed_match`` must not raise on degenerate inputs.

        No results are asserted; the test passes as long as every call completes.
        """
        # No ranks at all.
        DistributedAnalyzer({}, overflow_check=False).distributed_match()

        # A single rank whose collective group contains only itself.
        single_rank_node = self.base_node_factory("Distributed.broadcast.0.forward")
        single_rank_node.input_data = {
            f"{single_rank_node.id}{GraphConst.INPUT}group": self._make_group_kwarg([0], "g1"),
        }
        DistributedAnalyzer({0: {single_rank_node.id: single_rank_node}}, overflow_check=True).distributed_match()

        # A rank with no nodes.
        DistributedAnalyzer({0: {}}, overflow_check=False).distributed_match()

        # An API that is not present in the analyzer config.
        unknown_node = self.base_node_factory("Distributed.unknown_api.0.forward")
        DistributedAnalyzer({0: {unknown_node.id: unknown_node}}, overflow_check=False).distributed_match()


if __name__ == "__main__":
    unittest.main()