from unittest.mock import patch, MagicMock
import torch
import torch.distributed as torch_dist
from mindie_llm.runtime.utils.distributed.communication_op import all_gather, gather_tensor, allgather_and_reorder
def test_gather_tensor_none_index():
"""测试gather_tensor边界场景:入参index=None"""
input_tensor = torch.randn(3, 5)
output = gather_tensor(input_tensor, index=None)
assert torch.equal(output, input_tensor)
def test_gather_tensor_normal():
"""测试gather_tensor正常场景:索引重排逻辑正确性"""
input_tensor = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
gather_index = torch.tensor([3, 0, 1, 2])
output = gather_tensor(input_tensor, gather_index)
expected = torch.tensor([[7, 8], [1, 2], [3, 4], [5, 6]])
assert torch.equal(output, expected)
def test_all_gather_basic():
"""测试all_gather核心逻辑"""
with patch.dict('sys.modules', {'torch_npu': MagicMock()}):
mock_process_group = MagicMock()
mock_process_group.size.return_value = 2
with patch.object(torch_dist, "all_gather_into_tensor"):
input_tensor = torch.tensor([[10], [20]])
output = all_gather(input_tensor, mock_process_group)
assert output.shape == (4, 1)
torch_dist.all_gather_into_tensor.assert_called_once()
mock_process_group.size.assert_called_once()
def test_allgather_and_reorder_full():
"""完整测试allgather_and_reorder组合逻辑"""
with patch.dict('sys.modules', {'torch_npu': MagicMock()}):
mock_process_group = MagicMock()
mock_process_group.size.return_value = 4
with patch.object(torch_dist, "all_gather_into_tensor"):
input_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
gather_index = torch.tensor([3, 0, 1, 2])
output = allgather_and_reorder(input_tensor, mock_process_group, gather_index)
assert output.shape == (4, 3)
torch_dist.all_gather_into_tensor.assert_called_once()
mock_process_group.size.assert_called_once()
if __name__ == "__main__":
test_gather_tensor_none_index()
test_gather_tensor_normal()
test_all_gather_basic()
test_allgather_and_reorder_full()