from argparse import Namespace
from mindspeed.features_manager.feature import MindSpeedFeature
from mindspeed.patch_utils import MindSpeedPatchesManager
class VTPFeature(MindSpeedFeature):
    """Virtual Tensor Parallelism (VTP) feature.

    Lets different pipeline-parallel stages run with different
    tensor-parallel sizes in LDT SFT mode. The per-stage VTP sizes are
    auto-detected from the per-node GPU topology once distributed init has
    completed.

    The feature is only active when --layerwise-disaggregated-training is
    enabled, and --tensor-model-parallel-size should be set to the largest
    TP size found across nodes.
    """

    def __init__(self):
        super().__init__(feature_name="virtual-tp", optimization_level=0)

    def pre_validate_args(self, args: Namespace):
        """Temporarily override ``args.world_size`` so validation can pass.

        Does nothing unless LDT is enabled and ``args.world_size`` is set.
        When the world size is not a multiple of TP*PP*CP, the real value is
        stashed on ``args._vtp_orig_world_size`` and ``args.world_size`` is
        replaced by TP*PP*CP (i.e. DP=1), which is the smallest value the
        subsequent Megatron validation accepts. The actual VTP sizes are
        detected later, after distributed init, via all_gather.
        """
        if not getattr(args, 'layerwise_disaggregated_training', False):
            return

        real_world_size = getattr(args, 'world_size', None)
        if real_world_size is None:
            return

        # A missing or falsy context-parallel size is treated as 1.
        model_parallel_size = (
            args.tensor_model_parallel_size
            * args.pipeline_model_parallel_size
            * (getattr(args, 'context_parallel_size', 1) or 1)
        )

        if real_world_size % model_parallel_size != 0:
            # Keep the real value so post_validate_args can put it back.
            args._vtp_orig_world_size = real_world_size
            args.world_size = model_parallel_size

    def post_validate_args(self, args: Namespace):
        """Put back the world size stashed by ``pre_validate_args``, if any."""
        saved_world_size = getattr(args, '_vtp_orig_world_size', None)
        if saved_world_size is None:
            return

        args.world_size = saved_world_size
        # Drop the temporary marker so it does not linger on the namespace.
        delattr(args, '_vtp_orig_world_size')

    def register_patches(
        self,
        patch_manager: MindSpeedPatchesManager,
        args: Namespace,
    ):
        """Register the VTP all-gather wrapper when LDT is enabled.

        Replaces ``torch.distributed.all_gather_into_tensor`` with
        ``vtp_all_gather_into_tensor_wrapper`` through ``patch_manager``.
        No patch is registered when LDT is disabled.
        """
        if not getattr(args, 'layerwise_disaggregated_training', False):
            return

        # Imported lazily so the LDT utilities are only loaded when needed.
        from mindspeed_llm.core.layerwise_disaggregated_training.utils import (
            vtp_all_gather_into_tensor_wrapper,
        )

        patch_manager.register_patch(
            'torch.distributed.all_gather_into_tensor',
            vtp_all_gather_into_tensor_wrapper,
        )