"""
msmodelslim.utils.distributed 模块的单元测试
"""
import os
import socket
import unittest
from unittest.mock import patch, MagicMock
import torch
import torch.nn as nn
import torch.distributed as dist
from msmodelslim.utils.distributed import DistHelper
from msmodelslim.utils.distributed.dist_setup import find_free_port, setup_distributed
from msmodelslim.utils.distributed.dist_ops import (
ReduceOperation,
sync_base_operation,
sync_gather_tensors
)
from msmodelslim.utils.exception import SchemaValidateError, EnvError, UnsupportedError
class TestDistHelper(unittest.TestCase):
def setUp(self):
class TestModel(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = nn.Linear(10, 5)
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.1)
self.test_model = TestModel()
self.test_model.child_module = nn.Linear(5, 2)
@patch('torch.distributed.get_world_size')
@patch('torch.distributed.all_gather_object')
def test_init_with_mocked_distributed(self, mock_all_gather_object, mock_get_world_size):
"""测试初始化方法"""
mock_get_world_size.return_value = 2
mock_all_gather_object.return_value = None
def side_effect(gathered_modules, local_modules):
gathered_modules[0] = local_modules
gathered_modules[1] = local_modules
return None
mock_all_gather_object.side_effect = side_effect
helper = DistHelper(self.test_model)
self.assertEqual(helper._model, self.test_model)
expected_local_modules = {
'', 'linear1', 'relu', 'dropout', 'child_module'
}
self.assertEqual(helper._local_modules, expected_local_modules)
self.assertEqual(helper._shared_modules, expected_local_modules)
self.assertEqual(helper._all_modules, expected_local_modules)
self.assertEqual(helper._local_only_modules, set())
@patch('torch.distributed.get_world_size')
@patch('torch.distributed.all_gather_object')
def test_local_modules_generator(self, mock_all_gather_object, mock_get_world_size):
"""测试本地模块生成器"""
mock_get_world_size.return_value = 1
mock_all_gather_object.return_value = None
def side_effect(gathered_modules, local_modules):
gathered_modules[0] = local_modules
return None
mock_all_gather_object.side_effect = side_effect
helper = DistHelper(self.test_model)
local_modules = list(helper.local_modules())
self.assertEqual(len(local_modules), 5)
self.assertIn(self.test_model, local_modules)
self.assertIn(self.test_model.linear1, local_modules)
self.assertIn(self.test_model.relu, local_modules)
self.assertIn(self.test_model.dropout, local_modules)
self.assertIn(self.test_model.child_module, local_modules)
@patch('torch.distributed.get_world_size')
@patch('torch.distributed.all_gather_object')
def test_shared_modules_generator(self, mock_all_gather_object, mock_get_world_size):
"""测试共享模块生成器"""
mock_get_world_size.return_value = 1
mock_all_gather_object.return_value = None
def side_effect(gathered_modules, local_modules):
gathered_modules[0] = local_modules
return None
mock_all_gather_object.side_effect = side_effect
helper = DistHelper(self.test_model)
shared_modules = list(helper.shared_modules())
self.assertEqual(len(shared_modules), 5)
self.assertIn(self.test_model, shared_modules)
self.assertIn(self.test_model.linear1, shared_modules)
self.assertIn(self.test_model.relu, shared_modules)
self.assertIn(self.test_model.dropout, shared_modules)
self.assertIn(self.test_model.child_module, shared_modules)
@patch('torch.distributed.get_world_size')
@patch('torch.distributed.all_gather_object')
def test_all_modules_generator(self, mock_all_gather_object, mock_get_world_size):
"""测试所有模块生成器"""
mock_get_world_size.return_value = 1
mock_all_gather_object.return_value = None
def side_effect(gathered_modules, local_modules):
gathered_modules[0] = local_modules
return None
mock_all_gather_object.side_effect = side_effect
helper = DistHelper(self.test_model)
all_modules = list(helper.all_modules())
self.assertEqual(len(all_modules), 5)
self.assertIn(self.test_model, all_modules)
self.assertIn(self.test_model.linear1, all_modules)
self.assertIn(self.test_model.relu, all_modules)
self.assertIn(self.test_model.dropout, all_modules)
self.assertIn(self.test_model.child_module, all_modules)
@patch('torch.distributed.get_world_size')
@patch('torch.distributed.all_gather_object')
def test_local_only_modules_generator(self, mock_all_gather_object, mock_get_world_size):
"""测试仅本地模块生成器"""
mock_get_world_size.return_value = 1
mock_all_gather_object.return_value = None
def side_effect(gathered_modules, local_modules):
gathered_modules[0] = local_modules
return None
mock_all_gather_object.side_effect = side_effect
helper = DistHelper(self.test_model)
local_only_modules = list(helper.local_only_modules())
self.assertEqual(len(local_only_modules), 0)
@patch('torch.distributed.get_world_size')
@patch('torch.distributed.all_gather_object')
@patch('torch.distributed.get_rank')
def test_local_only_modules_generator_with_different_modules(self, mock_get_rank, mock_all_gather_object,
mock_get_world_size):
"""测试仅本地模块生成器在不同模块配置下的行为"""
mock_get_world_size.return_value = 2
mock_get_rank.return_value = 0
mock_all_gather_object.return_value = None
def side_effect(gathered_modules, local_modules):
gathered_modules[0] = {'', 'module_a', 'module_b'}
gathered_modules[1] = {'', 'module_a', 'module_c'}
return None
mock_all_gather_object.side_effect = side_effect
class LocalOnlyTestModel(nn.Module):
def __init__(self):
super().__init__()
self.module_a = nn.Linear(10, 5)
self.module_b = nn.Linear(5, 2)
local_only_model = LocalOnlyTestModel()
helper = DistHelper(local_only_model)
local_only_modules = list(helper.local_only_modules())
self.assertEqual(len(local_only_modules), 1)
self.assertIn(local_only_model.module_b, local_only_modules)
@patch('torch.distributed.get_world_size')
@patch('torch.distributed.all_gather_object')
def test_is_local_method(self, mock_all_gather_object, mock_get_world_size):
"""测试is_local方法"""
mock_get_world_size.return_value = 1
mock_all_gather_object.return_value = None
def side_effect(gathered_modules, local_modules):
gathered_modules[0] = local_modules
return None
mock_all_gather_object.side_effect = side_effect
helper = DistHelper(self.test_model)
self.assertTrue(helper.is_local('linear1'))
self.assertTrue(helper.is_local(''))
self.assertFalse(helper.is_local('non_existent_module'))
@patch('torch.distributed.get_world_size')
@patch('torch.distributed.all_gather_object')
def test_is_local_only_method(self, mock_all_gather_object, mock_get_world_size):
"""测试is_local_only方法"""
mock_get_world_size.return_value = 1
mock_all_gather_object.return_value = None
def side_effect(gathered_modules, local_modules):
gathered_modules[0] = local_modules
return None
mock_all_gather_object.side_effect = side_effect
helper = DistHelper(self.test_model)
self.assertFalse(helper.is_local_only('linear1'))
self.assertFalse(helper.is_local_only(''))
@patch('torch.distributed.get_world_size')
@patch('torch.distributed.all_gather_object')
def test_is_shared_method(self, mock_all_gather_object, mock_get_world_size):
"""测试is_shared方法"""
mock_get_world_size.return_value = 1
mock_all_gather_object.return_value = None
def side_effect(gathered_modules, local_modules):
gathered_modules[0] = local_modules
return None
mock_all_gather_object.side_effect = side_effect
helper = DistHelper(self.test_model)
self.assertTrue(helper.is_shared('linear1'))
self.assertTrue(helper.is_shared(''))
self.assertFalse(helper.is_shared('non_existent_module'))
@patch('torch.distributed.get_world_size')
@patch('torch.distributed.all_gather_object')
def test_is_all_method(self, mock_all_gather_object, mock_get_world_size):
"""测试is_all方法"""
mock_get_world_size.return_value = 1
mock_all_gather_object.return_value = None
def side_effect(gathered_modules, local_modules):
gathered_modules[0] = local_modules
return None
mock_all_gather_object.side_effect = side_effect
helper = DistHelper(self.test_model)
self.assertTrue(helper.is_all('linear1'))
self.assertTrue(helper.is_all(''))
self.assertFalse(helper.is_all('non_existent_module'))
@patch('torch.distributed.get_world_size')
@patch('torch.distributed.all_gather_object')
@patch('torch.distributed.get_rank')
def test_get_shared_modules_slice(self, mock_get_rank, mock_all_gather_object, mock_get_world_size):
"""测试get_shared_modules_slice方法"""
mock_get_world_size.return_value = 2
mock_get_rank.return_value = 0
mock_all_gather_object.return_value = None
def side_effect(gathered_modules, local_modules):
gathered_modules[0] = local_modules
gathered_modules[1] = local_modules
return None
mock_all_gather_object.side_effect = side_effect
helper = DistHelper(self.test_model)
result = helper.get_shared_modules_slice()
expected = sorted(['', 'child_module', 'dropout', 'linear1', 'relu'])[0::2]
self.assertEqual(result, expected)
result_with_prefix = helper.get_shared_modules_slice(prefix="model")
expected_with_prefix = sorted([
f"model.{name}" for name in ['', 'child_module', 'dropout', 'linear1', 'relu']
])[0::2]
self.assertEqual(result_with_prefix, expected_with_prefix)
@patch('torch.distributed.get_world_size')
@patch('torch.distributed.all_gather_object')
@patch('torch.distributed.get_rank')
def test_get_shared_modules_slice_different_rank(self, mock_get_rank, mock_all_gather_object, mock_get_world_size):
"""测试不同rank下的get_shared_modules_slice方法"""
mock_get_world_size.return_value = 2
mock_get_rank.return_value = 1
mock_all_gather_object.return_value = None
def side_effect(gathered_modules, local_modules):
gathered_modules[0] = local_modules
gathered_modules[1] = local_modules
return None
mock_all_gather_object.side_effect = side_effect
helper = DistHelper(self.test_model)
result = helper.get_shared_modules_slice()
expected = sorted(['', 'child_module', 'dropout', 'linear1', 'relu'])[1::2]
self.assertEqual(result, expected)
@patch('torch.distributed.get_world_size')
@patch('torch.distributed.all_gather_object')
@patch('torch.distributed.get_rank')
def test_get_rank_method(self, mock_get_rank, mock_all_gather_object, mock_get_world_size):
"""测试get_rank方法"""
mock_get_world_size.return_value = 1
mock_get_rank.return_value = 42
mock_all_gather_object.return_value = None
def side_effect(gathered_modules, local_modules):
gathered_modules[0] = local_modules
return None
mock_all_gather_object.side_effect = side_effect
helper = DistHelper(self.test_model)
self.assertEqual(helper.get_rank(), 42)
@patch('torch.distributed.get_world_size')
@patch('torch.distributed.all_gather')
def test_gather_variable_shapes_static_method(self, mock_all_gather, mock_get_world_size):
"""测试gather_variable_shapes静态方法"""
mock_get_world_size.return_value = 2
mock_all_gather.return_value = None
local_tensor = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
def all_gather_side_effect(tensor_list, tensor):
tensor_list[:] = [x.clone() for x in [tensor] * len(tensor_list)]
return None
mock_all_gather.side_effect = all_gather_side_effect
result = DistHelper.gather_variable_shapes(local_tensor)
self.assertEqual(len(result), 2)
for tensor in result:
self.assertTrue(torch.equal(tensor, local_tensor))
self.assertEqual(tensor.dtype, local_tensor.dtype)
self.assertEqual(tensor.shape, local_tensor.shape)
@patch('torch.distributed.get_world_size')
@patch('torch.distributed.all_gather')
def test_gather_variable_shapes_with_different_shapes(self, mock_all_gather, mock_get_world_size):
"""测试gather_variable_shapes静态方法处理不同形状的张量"""
mock_get_world_size.return_value = 2
mock_all_gather.return_value = None
local_tensor = torch.tensor([1, 2, 3], dtype=torch.float32)
def all_gather_side_effect(tensor_list, tensor):
tensor_list[:] = [x.clone() for x in [tensor] * len(tensor_list)]
return None
mock_all_gather.side_effect = all_gather_side_effect
result = DistHelper.gather_variable_shapes(local_tensor)
self.assertEqual(len(result), 2)
for tensor in result:
self.assertTrue(torch.equal(tensor, local_tensor))
self.assertEqual(tensor.dtype, local_tensor.dtype)
self.assertEqual(tensor.shape, local_tensor.shape)
@patch('torch.distributed.get_world_size')
@patch('torch.distributed.all_gather_object')
def test_init_with_empty_model(self, mock_all_gather_object, mock_get_world_size):
"""测试使用空模型初始化"""
mock_get_world_size.return_value = 1
mock_all_gather_object.return_value = None
def side_effect(gathered_modules, local_modules):
gathered_modules[0] = local_modules
return None
mock_all_gather_object.side_effect = side_effect
empty_model = nn.Module()
helper = DistHelper(empty_model)
self.assertEqual(helper._local_modules, {''})
self.assertEqual(helper._shared_modules, {''})
self.assertEqual(helper._all_modules, {''})
self.assertEqual(helper._local_only_modules, set())
@patch('torch.distributed.get_world_size')
@patch('torch.distributed.all_gather_object')
def test_init_with_nested_modules(self, mock_all_gather_object, mock_get_world_size):
"""测试带有嵌套模块的模型"""
mock_get_world_size.return_value = 1
mock_all_gather_object.return_value = None
def side_effect(gathered_modules, local_modules):
gathered_modules[0] = local_modules
return None
mock_all_gather_object.side_effect = side_effect
class NestedModel(nn.Module):
def __init__(self):
super().__init__()
self.layer1 = nn.Sequential(
nn.Linear(10, 5),
nn.ReLU()
)
self.layer2 = nn.Linear(5, 1)
nested_model = NestedModel()
helper = DistHelper(nested_model)
expected_modules = {
'', 'layer1', 'layer1.0', 'layer1.1', 'layer2'
}
self.assertEqual(helper._local_modules, expected_modules)
class TestFindFreePort(unittest.TestCase):
"""测试 find_free_port 函数"""
@patch('socket.socket')
def test_find_free_port_retries_on_oserror(self, mock_socket_class):
"""测试端口被占用时会尝试下一个端口"""
mock_socket = MagicMock()
mock_socket.__enter__ = MagicMock(return_value=mock_socket)
mock_socket.__exit__ = MagicMock(return_value=False)
mock_socket.bind.side_effect = [
OSError("Port 29500 in use"),
OSError("Port 29501 in use"),
None
]
mock_socket_class.return_value = mock_socket
port = find_free_port(start_port=29500, max_attempts=10)
self.assertEqual(port, 29502)
def test_find_free_port_success(self):
"""测试成功找到可用端口"""
port = find_free_port(start_port=29500, max_attempts=100)
self.assertIsInstance(port, int)
self.assertGreaterEqual(port, 29500)
self.assertLessEqual(port, 29600)
def test_find_free_port_with_custom_start_port(self):
"""测试使用自定义起始端口"""
port = find_free_port(start_port=30000, max_attempts=50)
self.assertIsInstance(port, int)
self.assertGreaterEqual(port, 30000)
self.assertLessEqual(port, 30050)
def test_find_free_port_start_port_too_low(self):
"""测试起始端口小于1024时抛出异常"""
with self.assertRaises(SchemaValidateError) as context:
find_free_port(start_port=1023)
self.assertIn("start_port must be >= 1024", str(context.exception))
def test_find_free_port_start_port_too_high(self):
"""测试起始端口大于65535时抛出异常"""
with self.assertRaises(SchemaValidateError) as context:
find_free_port(start_port=65536)
self.assertIn("start_port must be <= 65535", str(context.exception))
@patch('socket.socket')
def test_find_free_port_all_ports_in_use(self, mock_socket_class):
"""测试所有端口都被占用时抛出异常"""
mock_socket = MagicMock()
mock_socket.__enter__ = MagicMock(return_value=mock_socket)
mock_socket.__exit__ = MagicMock(return_value=False)
mock_socket.bind.side_effect = OSError("Port in use")
mock_socket_class.return_value = mock_socket
with self.assertRaises(EnvError) as context:
find_free_port(start_port=29500, max_attempts=5)
self.assertIn("Cannot find a free port", str(context.exception))
class TestSetupDistributed(unittest.TestCase):
"""测试 setup_distributed 函数(CPU环境,通过mock模拟设备操作)"""
def test_setup_distributed_hccl_backend(self):
"""测试使用 hccl 后端的分布式设置"""
mock_npu = MagicMock()
mock_init_process_group = MagicMock()
with patch.object(torch, 'npu', mock_npu, create=True), \
patch('msmodelslim.utils.distributed.dist_setup.dist.init_process_group', mock_init_process_group):
setup_distributed(rank=0, world_size=4, backend='hccl', master_port=29500, device_index=0)
self.assertEqual(os.environ['MASTER_ADDR'], '127.0.0.1')
self.assertEqual(os.environ['MASTER_PORT'], '29500')
self.assertEqual(os.environ['RANK'], '0')
self.assertEqual(os.environ['WORLD_SIZE'], '4')
mock_npu.set_device.assert_called_once_with("npu:0")
mock_init_process_group.assert_called_once_with(
backend='hccl',
world_size=4,
rank=0
)
def test_setup_distributed_device_index_none(self):
"""测试 device_index 为 None 时使用 rank 作为设备索引"""
mock_npu = MagicMock()
mock_init_process_group = MagicMock()
with patch.object(torch, 'npu', mock_npu, create=True), \
patch('msmodelslim.utils.distributed.dist_setup.dist.init_process_group', mock_init_process_group):
setup_distributed(rank=2, world_size=4, backend='hccl', master_port=29502, device_index=None)
mock_npu.set_device.assert_called_once_with("npu:2")
mock_init_process_group.assert_called_once()
def test_setup_distributed_device_index_different_from_rank(self):
"""测试 device_index 与 rank 不同的情况"""
mock_npu = MagicMock()
mock_init_process_group = MagicMock()
with patch.object(torch, 'npu', mock_npu, create=True), \
patch('msmodelslim.utils.distributed.dist_setup.dist.init_process_group', mock_init_process_group):
setup_distributed(rank=0, world_size=4, backend='hccl', master_port=29503, device_index=3)
mock_npu.set_device.assert_called_once_with("npu:3")
mock_init_process_group.assert_called_once_with(
backend='hccl',
world_size=4,
rank=0
)
class TestReduceOperation(unittest.TestCase):
"""测试 ReduceOperation 枚举"""
def test_reduce_operation_from_string(self):
"""测试从字符串创建 ReduceOperation"""
self.assertEqual(ReduceOperation("min"), ReduceOperation.MIN)
self.assertEqual(ReduceOperation("max"), ReduceOperation.MAX)
self.assertEqual(ReduceOperation("sum"), ReduceOperation.SUM)
self.assertEqual(ReduceOperation("mean"), ReduceOperation.MEAN)
self.assertEqual(ReduceOperation("prod"), ReduceOperation.PROD)
def test_reduce_operation_invalid_value(self):
"""测试无效值抛出异常"""
with self.assertRaises(ValueError):
ReduceOperation("invalid")
class TestSyncBaseOperation(unittest.TestCase):
"""测试 sync_base_operation 函数(CPU环境,通过mock模拟分布式操作)"""
@patch('msmodelslim.utils.distributed.dist_ops.dist.all_reduce')
def test_sync_base_operation_min_with_string(self, mock_all_reduce):
"""测试使用字符串的 min 操作"""
tensor = torch.tensor([1.0, 2.0, 3.0])
result = sync_base_operation(tensor, "min")
mock_all_reduce.assert_called_once_with(tensor, op=dist.ReduceOp.MIN, group=None)
self.assertIs(result, tensor)
@patch('msmodelslim.utils.distributed.dist_ops.dist.all_reduce')
def test_sync_base_operation_max_with_string(self, mock_all_reduce):
"""测试使用字符串的 max 操作"""
tensor = torch.tensor([1.0, 2.0, 3.0])
result = sync_base_operation(tensor, "max")
mock_all_reduce.assert_called_once_with(tensor, op=dist.ReduceOp.MAX, group=None)
self.assertIs(result, tensor)
@patch('msmodelslim.utils.distributed.dist_ops.dist.all_reduce')
def test_sync_base_operation_sum_with_string(self, mock_all_reduce):
"""测试使用字符串的 sum 操作"""
tensor = torch.tensor([1.0, 2.0, 3.0])
result = sync_base_operation(tensor, "sum")
mock_all_reduce.assert_called_once_with(tensor, op=dist.ReduceOp.SUM, group=None)
self.assertIs(result, tensor)
@patch('msmodelslim.utils.distributed.dist_ops.dist.all_reduce')
def test_sync_base_operation_prod_with_string(self, mock_all_reduce):
"""测试使用字符串的 prod 操作"""
tensor = torch.tensor([1.0, 2.0, 3.0])
result = sync_base_operation(tensor, "prod")
mock_all_reduce.assert_called_once_with(tensor, op=dist.ReduceOp.PRODUCT, group=None)
self.assertIs(result, tensor)
@patch('msmodelslim.utils.distributed.dist_ops.dist.get_world_size')
@patch('msmodelslim.utils.distributed.dist_ops.dist.all_reduce')
def test_sync_base_operation_mean_with_string(self, mock_all_reduce, mock_get_world_size):
"""测试使用字符串的 mean 操作"""
mock_get_world_size.return_value = 2
tensor = torch.tensor([2.0, 4.0, 6.0])
result = sync_base_operation(tensor, "mean")
mock_all_reduce.assert_called_once_with(tensor, op=dist.ReduceOp.SUM, group=None)
mock_get_world_size.assert_called_once_with(None)
self.assertIs(result, tensor)
def test_sync_base_operation_invalid_string(self):
"""测试无效操作字符串抛出异常"""
tensor = torch.tensor([1.0, 2.0, 3.0])
with self.assertRaises(UnsupportedError) as context:
sync_base_operation(tensor, "invalid_op")
self.assertIn("Unsupported operation", str(context.exception))
@patch('msmodelslim.utils.distributed.dist_ops.dist.all_reduce')
def test_sync_base_operation_case_insensitive(self, mock_all_reduce):
"""测试字符串操作不区分大小写"""
tensor = torch.tensor([1.0, 2.0, 3.0])
sync_base_operation(tensor, "MIN")
sync_base_operation(tensor, "Min")
sync_base_operation(tensor, "mIn")
self.assertEqual(mock_all_reduce.call_count, 3)
class TestSyncGatherTensors(unittest.TestCase):
"""测试 sync_gather_tensors 函数(CPU环境,通过mock模拟分布式操作)"""
@patch('msmodelslim.utils.distributed.dist_ops.dist.get_rank')
@patch('msmodelslim.utils.distributed.dist_ops.dist.get_world_size')
@patch('msmodelslim.utils.distributed.dist_ops.dist.all_gather')
def test_sync_gather_tensors_same_shape_on_device(self, mock_all_gather, mock_get_world_size, mock_get_rank):
"""测试在设备上聚合相同形状的张量"""
mock_get_world_size.return_value = 2
mock_get_rank.return_value = 0
tensor = torch.tensor([1.0, 2.0, 3.0])
def all_gather_side_effect(tensor_list, tensor, group=None):
for _, t in enumerate(tensor_list):
t.copy_(tensor)
return None
mock_all_gather.side_effect = all_gather_side_effect
result = sync_gather_tensors(tensor, variable_shapes=False, on_cpu=False)
self.assertEqual(len(result), 2)
mock_all_gather.assert_called_once()
@patch('msmodelslim.utils.distributed.dist_ops.dist.get_rank')
@patch('msmodelslim.utils.distributed.dist_ops.dist.get_world_size')
@patch('msmodelslim.utils.distributed.dist_ops.dist.all_gather_object')
def test_sync_gather_tensors_on_cpu(self, mock_all_gather_object, mock_get_world_size, mock_get_rank):
"""测试在 CPU 上聚合张量"""
mock_get_world_size.return_value = 2
mock_get_rank.return_value = 0
tensor = torch.tensor([1.0, 2.0, 3.0])
def all_gather_object_side_effect(tensor_list, tensor_cpu, group=None):
tensor_list[0] = tensor_cpu
tensor_list[1] = tensor_cpu
return None
mock_all_gather_object.side_effect = all_gather_object_side_effect
result = sync_gather_tensors(tensor, on_cpu=True)
self.assertEqual(len(result), 2)
mock_all_gather_object.assert_called_once()
@patch('msmodelslim.utils.distributed.dist_ops.torch.device')
@patch('msmodelslim.utils.distributed.dist_ops.dist.get_rank')
@patch('msmodelslim.utils.distributed.dist_ops.dist.get_world_size')
@patch('msmodelslim.utils.distributed.dist_ops.dist.all_gather')
def test_sync_gather_tensors_variable_shapes(
self, mock_all_gather, mock_get_world_size, mock_get_rank, mock_device):
"""测试聚合不同形状的张量(CPU环境)"""
mock_get_world_size.return_value = 2
mock_get_rank.return_value = 0
mock_device.return_value.__enter__ = MagicMock(return_value=None)
mock_device.return_value.__exit__ = MagicMock(return_value=False)
tensor = torch.tensor([1.0, 2.0, 3.0])
call_count = [0]
def all_gather_side_effect(tensor_list, tensor, group=None):
call_count[0] += 1
if call_count[0] == 1:
for t in tensor_list:
t.copy_(torch.tensor([3], dtype=torch.long))
else:
for t in tensor_list:
t.copy_(tensor)
return None
mock_all_gather.side_effect = all_gather_side_effect
result = sync_gather_tensors(tensor, variable_shapes=True, on_cpu=False)
self.assertEqual(len(result), 2)
self.assertEqual(mock_all_gather.call_count, 2)
@patch('msmodelslim.utils.distributed.dist_ops.dist.get_rank')
@patch('msmodelslim.utils.distributed.dist_ops.dist.get_world_size')
@patch('msmodelslim.utils.distributed.dist_ops.dist.all_gather')
def test_sync_gather_tensors_2d_tensor(self, mock_all_gather, mock_get_world_size, mock_get_rank):
"""测试聚合二维张量"""
mock_get_world_size.return_value = 2
mock_get_rank.return_value = 0
tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
def all_gather_side_effect(tensor_list, tensor, group=None):
for t in tensor_list:
t.copy_(tensor)
return None
mock_all_gather.side_effect = all_gather_side_effect
result = sync_gather_tensors(tensor, variable_shapes=False, on_cpu=False)
self.assertEqual(len(result), 2)
for t in result:
self.assertEqual(t.shape, torch.Size([2, 2]))
if __name__ == '__main__':
unittest.main()