"""Unit tests for VTP-specific functions in parallel_state_patch."""
import os
from unittest.mock import patch, MagicMock
import mindspeed.megatron_adaptor
import torch
from mindspeed_mm.patchs.layerwise_disaggregated_training import parallel_state_patch as mod
from tests.ut.utils import judge_expression
# Shorthand for the ``mpu`` object exposed by the module under test; the tests
# below patch its tensor/pipeline-model-parallel getters.
mpu = mod.mpu
class TestVtpAllreduce:
    """Count the collectives ``vtp_allreduce`` issues under different layouts."""

    @staticmethod
    def _invoke(mpu_fakes, mod_fakes):
        """Run ``vtp_allreduce`` on a one-element tensor with fakes installed.

        ``mpu_fakes`` / ``mod_fakes`` map attribute names on ``mpu`` / ``mod``
        to replacement objects. Returns the ``(all_reduce, broadcast)`` mocks
        that replaced the ``torch.distributed`` functions during the call.
        """
        fake_all_reduce = MagicMock()
        fake_broadcast = MagicMock()
        with patch.multiple(mpu, **mpu_fakes), \
                patch.multiple(mod.torch.distributed,
                               all_reduce=fake_all_reduce,
                               broadcast=fake_broadcast), \
                patch.multiple(mod, **mod_fakes):
            mod.vtp_allreduce(torch.tensor([1.0]))
        return fake_all_reduce, fake_broadcast

    def test_full_3_step(self):
        # TP=2 on a stage-rank-0 process with two VTP stages.
        all_reduce, broadcast = self._invoke(
            mpu_fakes={
                "get_tensor_model_parallel_world_size": MagicMock(return_value=2),
                "get_tensor_model_parallel_group": MagicMock(return_value=MagicMock()),
                "get_pipeline_model_parallel_group": MagicMock(return_value=MagicMock()),
            },
            mod_fakes={
                "is_vtp_stage_rank0": MagicMock(return_value=True),
                "get_vtp_stage_ranks": MagicMock(return_value=[[0, 1], [2, 3]]),
                "get_vtp_my_stage_idx": MagicMock(return_value=0),
            },
        )
        judge_expression(all_reduce.call_count == 2)
        broadcast.assert_called_once()

    def test_tp1_no_intra_group(self):
        # TP=1 and no TP group: one all_reduce, no broadcast.
        all_reduce, broadcast = self._invoke(
            mpu_fakes={
                "get_tensor_model_parallel_world_size": MagicMock(return_value=1),
                "get_pipeline_model_parallel_group": MagicMock(return_value=MagicMock()),
                "get_tensor_model_parallel_group": MagicMock(return_value=None),
            },
            mod_fakes={
                "is_vtp_stage_rank0": MagicMock(return_value=True),
            },
        )
        all_reduce.assert_called_once()
        broadcast.assert_not_called()

    def test_non_rank0_skips_pp(self):
        # A non-stage-rank-0 process with TP=1 issues no collectives at all.
        all_reduce, broadcast = self._invoke(
            mpu_fakes={
                "get_tensor_model_parallel_world_size": MagicMock(return_value=1),
                "get_tensor_model_parallel_group": MagicMock(return_value=None),
            },
            mod_fakes={
                "is_vtp_stage_rank0": MagicMock(return_value=False),
            },
        )
        all_reduce.assert_not_called()
        broadcast.assert_not_called()
class TestVtpHierarchicalBarrier:
    """Count the barriers ``vtp_hierarchical_barrier`` issues."""

    @staticmethod
    def _barrier_mock_after_call(mpu_fakes, mod_fakes):
        """Run the barrier with fakes installed and return the ``barrier`` mock."""
        fake_barrier = MagicMock()
        with patch.multiple(mpu, **mpu_fakes), \
                patch.object(mod.torch.distributed, "barrier", fake_barrier), \
                patch.multiple(mod, **mod_fakes):
            mod.vtp_hierarchical_barrier()
        return fake_barrier

    def test_full_3_step_barrier(self):
        # TP=2, stage rank 0, intra-stage group present: three barriers.
        barrier = self._barrier_mock_after_call(
            mpu_fakes={
                "get_tensor_model_parallel_world_size": MagicMock(return_value=2),
                "get_tensor_model_parallel_group": MagicMock(return_value=MagicMock()),
                "get_pipeline_model_parallel_group": MagicMock(return_value=MagicMock()),
            },
            mod_fakes={
                "is_vtp_stage_rank0": MagicMock(return_value=True),
                "get_vtp_intra_stage_group": MagicMock(return_value=MagicMock()),
            },
        )
        judge_expression(barrier.call_count == 3)

    def test_tp1_non_rank0_no_intra(self):
        # TP=1, not stage rank 0, no intra-stage group: no barrier at all.
        barrier = self._barrier_mock_after_call(
            mpu_fakes={
                "get_tensor_model_parallel_world_size": MagicMock(return_value=1),
            },
            mod_fakes={
                "is_vtp_stage_rank0": MagicMock(return_value=False),
                "get_vtp_intra_stage_group": MagicMock(return_value=None),
            },
        )
        barrier.assert_not_called()
class TestAutoDetectVtpSizes:
    """Tests for ``_auto_detect_vtp_sizes``.

    Each test fakes ``torch.distributed.all_gather`` so every gathered tensor
    is filled with a value chosen by the test, fakes ``get_world_size`` to
    match ``args.world_size``, and sets ``LOCAL_WORLD_SIZE`` in the environment.
    """

    def _make_args(self, world_size, tp, pp, cp=1, ep=1, dp=1):
        """Build a MagicMock args object carrying the given parallel sizes."""
        args = MagicMock()
        args.world_size = world_size
        args.tensor_model_parallel_size = tp
        args.pipeline_model_parallel_size = pp
        args.context_parallel_size = cp
        args.expert_model_parallel_size = ep
        args.data_parallel_size = dp
        return args

    @staticmethod
    def _make_fill(values):
        """Build an ``all_gather`` side effect.

        ``values`` is either an int written into every gathered tensor, or a
        sequence whose i-th entry is written into the i-th gathered tensor.
        """
        per_slot = not isinstance(values, int)

        def fill(gathered, local_tensor):
            for i, t in enumerate(gathered):
                t.fill_(values[i] if per_slot else values)

        return fill

    def _detect(self, args, local_world_size, gathered_values):
        """Run ``_auto_detect_vtp_sizes(args)`` with distributed/CUDA calls faked."""
        with patch.object(mod.torch.distributed, "all_gather") as mock_ag, \
                patch.object(mod.torch.distributed, "get_world_size",
                             return_value=args.world_size), \
                patch.object(mod.torch.cuda, "current_device", return_value=0), \
                patch.dict(os.environ, {"LOCAL_WORLD_SIZE": str(local_world_size)}):
            mock_ag.side_effect = self._make_fill(gathered_values)
            return mod._auto_detect_vtp_sizes(args)

    def test_heterogeneous_edge_cloud(self):
        # First gathered slot is 1, the other eight are 8 (presumably per-rank
        # local device counts -- one small edge node plus a larger cloud node).
        args = self._make_args(9, 4, 3)
        result = self._detect(args, local_world_size=0,
                              gathered_values=[1] + [8] * 8)
        judge_expression(result == [1, 4, 4])

    def test_homogeneous_returns_none(self):
        # NOTE(review): despite the name, the expected result is the uniform
        # list [4, 4] (presumably one TP size per pipeline stage), not None.
        args = self._make_args(8, 4, 2)
        result = self._detect(args, local_world_size=8, gathered_values=8)
        judge_expression(result == [4, 4])

    def test_not_enough_cards(self):
        # Two ranks with one device each cannot host TP=4 x PP=4; the test
        # expects a list whose first entry is negative (looks like an error
        # signal -- confirm against the implementation).
        args = self._make_args(2, 4, 4)
        result = self._detect(args, local_world_size=1, gathered_values=1)
        # Guard the index below so an empty list fails here, not with IndexError.
        judge_expression(isinstance(result, list) and len(result) > 0)
        judge_expression(result[0] < 0)

    def test_max_tp_mismatch_returns_none(self):
        # NOTE(review): despite the name, the test expects a list whose first
        # entry is negative, not None. TP=8 exceeds the 2 devices per node.
        args = self._make_args(4, 8, 2)
        result = self._detect(args, local_world_size=2, gathered_values=2)
        judge_expression(isinstance(result, list) and len(result) > 0)
        judge_expression(result[0] < 0)
class TestPreValidateArgsForVtp:
    """Tests for the world-size adjustment done by ``pre_validate_args_for_vtp``."""

    @staticmethod
    def _make_args(world_size, tp, pp, cp):
        """Build a MagicMock args object with the given parallel sizes."""
        args = MagicMock()
        args.world_size = world_size
        args.tensor_model_parallel_size = tp
        args.pipeline_model_parallel_size = pp
        args.context_parallel_size = cp
        return args

    def test_divisible_noop(self):
        # 8 is a multiple of TP*PP*CP = 2*2*1, so world_size must stay 8.
        args = self._make_args(8, 2, 2, 1)
        mod.pre_validate_args_for_vtp(args)
        judge_expression(args.world_size == 8)

    def test_inflates_world_size(self):
        # NOTE(review): despite the name, world_size goes *down* here. 9 is not
        # a multiple of TP*PP*CP = 4*2*1 = 8; the test expects world_size to
        # become 8, with the original 9 kept in ``_vtp_orig_world_size``.
        args = self._make_args(9, 4, 2, 1)
        mod.pre_validate_args_for_vtp(args)
        judge_expression(args._vtp_orig_world_size == 9)
        judge_expression(args.world_size == 8)

    def test_no_world_size_noop(self):
        # spec=[] makes every attribute not set below raise AttributeError, so
        # the function must not touch the parallel sizes.
        args = MagicMock(spec=[])
        args.world_size = None
        mod.pre_validate_args_for_vtp(args)
        # Pin the "noop": world_size is left as None.
        judge_expression(args.world_size is None)

    def test_cp_none_defaults_to_1(self):
        # With context_parallel_size=None, CP is expected to count as 1, so 9
        # becomes 8 as in test_inflates_world_size.
        args = self._make_args(9, 4, 2, None)
        mod.pre_validate_args_for_vtp(args)
        judge_expression(args.world_size == 8)
class TestPostValidateArgsForVtp:
    """``post_validate_args_for_vtp`` must restore a saved world size, if any."""

    _ENV = {
        "GROUP_RANK": "0",
        "RANK": "0",
        "WORLD_SIZE": "8",
        "LOCAL_WORLD_SIZE": "2",
        "GROUP_WORLD_SIZE": "2",
    }

    @staticmethod
    def _with_parallel_sizes(args):
        """Give ``args`` world_size 8 with TP=2, PP=2, CP=1; return it."""
        args.world_size = 8
        args.tensor_model_parallel_size = 2
        args.pipeline_model_parallel_size = 2
        args.context_parallel_size = 1
        return args

    def _post_validate(self, args):
        """Call the function under test with ``_ENV`` applied to os.environ."""
        with patch.dict(os.environ, self._ENV):
            mod.post_validate_args_for_vtp(args)

    def test_restores_world_size(self):
        saved = MagicMock()
        saved._vtp_orig_world_size = 9
        args = self._with_parallel_sizes(saved)
        self._post_validate(args)
        judge_expression(args.world_size == 9)

    def test_no_attr_noop(self):
        # spec=[] keeps ``_vtp_orig_world_size`` absent, so nothing is restored.
        args = self._with_parallel_sizes(MagicMock(spec=[]))
        self._post_validate(args)
        judge_expression(args.world_size == 8)
class TestInitializeModelParallelWrapper:
    """Tests for ``initialize_model_parallel_wrapper`` in non-LDT, uniform-LDT
    and non-uniform (VTP) configurations."""

    _ENV = {
        "GROUP_RANK": "0",
        "RANK": "0",
        "WORLD_SIZE": "8",
        "LOCAL_WORLD_SIZE": "8",
        "GROUP_WORLD_SIZE": "1",
    }

    @staticmethod
    def _reset_ldt_globals():
        """Reset LDT/VTP/VDP module-level globals that persist between tests."""
        mod._PIPELINE_MODEL_PARALLEL_GROUP_FOR_VDP_CROSS_CLOUD_TP = None
        mod._PIPELINE_MODEL_PARALLEL_GROUP_FOR_VDP_CROSS_EDGE_CLOUD = None
        mod._LAYERWISE_DISAGGREGATED_TRAINING = False
        mod._VDP_SIZE = 1
        mod._VDP_ENABLED = False
        mod._VTP_ENABLED = False
        mod._VTP_SIZE_LIST = None
        mod._VTP_STAGE_RANKS = None
        mod._VTP_INTRA_STAGE_GROUP = None
        mod._VTP_MY_STAGE_IDX = None
        mod._EDGE_TP_SIZE = 1
        mod._PIPELINE_MODEL_PARALLEL_GROUP_ALTERNATE = None
        mod._PIPELINE_MODEL_PARALLEL_GROUP_FOR_LAST_TO_FIRST = None
        mod._PIPELINE_MODEL_PARALLEL_GROUP_FOR_FIRST_TO_LAST = None

    def setup_method(self):
        self._reset_ldt_globals()

    def teardown_method(self):
        # Reset again afterwards so globals set by the wrapper during a test
        # do not leak into tests outside this class.
        self._reset_ldt_globals()

    def _run_uniform_path(self, args):
        """Call the wrapper with ``(4, 2)`` while auto-detection returns None.

        Patches ``mod.get_args`` to return ``args`` and stubs the
        torch.distributed / CUDA calls and ``_init_vdp_state``.
        Returns the wrapped ``fn`` mock for the caller to inspect.
        """
        fn = MagicMock()
        with patch.object(mod, "get_args", return_value=args), \
                patch.object(mod, "_auto_detect_vtp_sizes", return_value=None), \
                patch.object(mod.torch.distributed, "is_initialized", return_value=True), \
                patch.object(mod.torch.distributed, "get_world_size", return_value=8), \
                patch.object(mod.torch.distributed, "get_rank", return_value=0), \
                patch.object(mod.torch.cuda, "current_device", return_value=0), \
                patch.object(mod.torch.distributed, "new_group"), \
                patch.object(mod.torch.distributed, "broadcast_object_list"), \
                patch.object(mod, "_init_vdp_state"), \
                patch.dict(os.environ, self._ENV):
            wrapper = mod.initialize_model_parallel_wrapper(fn)
            wrapper(4, 2)
        return fn

    def test_non_ldt_calls_fn_only(self):
        # A single get_args patch; the original patched get_args twice and
        # the outer patch (and its args object) was shadowed and unused.
        args = MagicMock(
            layerwise_disaggregated_training=False, data_parallel_size=1, vdp=0,
            vtp_size_list=None, vtp_stage_ranks=None, vdp_size=1,
            context_parallel_size=1,
        )
        fn = self._run_uniform_path(args)
        fn.assert_called_once()

    def test_ldt_uniform_calls_fn_and_group_init(self):
        # NOTE(review): the name mentions group init, but only the call to
        # ``fn`` is asserted here.
        args = MagicMock(layerwise_disaggregated_training=True, data_parallel_size=1)
        fn = self._run_uniform_path(args)
        fn.assert_called_once()

    def test_ldt_non_uniform_calls_vtp_static(self):
        # Auto-detection returns [1, 4, 4], so the wrapper is expected to use
        # the static VTP initializer instead of calling ``fn``.
        args = MagicMock(layerwise_disaggregated_training=True)
        with patch.object(mod, "get_args", return_value=args), \
                patch.object(mod, "_auto_detect_vtp_sizes", return_value=[1, 4, 4]), \
                patch.object(mod, "is_vdp_enabled", return_value=False), \
                patch.object(mod, "_initialize_vtp_static_only_vtp") as mock_static, \
                patch.object(mod.torch.distributed, "get_world_size", return_value=9), \
                patch.dict(os.environ, self._ENV):
            fn = MagicMock()
            wrapper = mod.initialize_model_parallel_wrapper(fn)
            wrapper(4, 3)
            mock_static.assert_called_once()
            fn.assert_not_called()