from argparse import ArgumentParser, Namespace
from mindspeed.features_manager.feature import MindSpeedFeature
from mindspeed.patch_utils import MindSpeedPatchesManager
class PipelineModelParallelLayoutFeature(MindSpeedFeature):
    """Support --pipeline-model-parallel-layout from Megatron dev on core_r0.12.1.

    The layout string determines the number of pipeline stages and therefore
    the virtual pipeline size (``num_stages // pipeline_model_parallel_size``).
    Argument handling is split across three hooks:

    * ``pre_validate_args`` snapshots the arguments it overwrites (under
      ``_STATE_ATTR``) and rewrites them into a form derived from the layout;
    * ``post_validate_args`` restores the decoder first/last stage layer counts
      from that snapshot;
    * ``validate_args`` checks layout-specific constraints, restores the
      remaining snapshot entries and sets
      ``virtual_pipeline_model_parallel_size``.

    NOTE(review): the hooks presumably wrap Megatron's own argument validation
    (pre -> core validation -> post -> validate); confirm against the features
    manager that invokes them.
    """

    # Namespace attribute holding the user's original argument values between
    # the pre_validate_args / post_validate_args / validate_args hooks.
    _STATE_ATTR = "_mindspeed_pipeline_model_parallel_layout_state"

    # argparse destination of --pipeline-model-parallel-layout. The guards read
    # this directly (not ``self.feature_name``) so they always inspect the same
    # attribute that the method bodies use, whatever feature_name was passed.
    _LAYOUT_DEST = "pipeline_model_parallel_layout"

    # Arguments snapshotted by pre_validate_args before it overwrites any of them.
    _SNAPSHOT_KEYS = (
        "num_layers_per_virtual_pipeline_stage",
        "num_virtual_stages_per_pipeline_rank",
        "decoder_first_pipeline_num_layers",
        "decoder_last_pipeline_num_layers",
        "overlap_p2p_comm",
        "align_param_gather",
    )

    # Snapshot entries that post_validate_args must leave alone because
    # validate_args restores or re-derives them itself.
    _RESTORED_IN_VALIDATE = frozenset(
        (
            "num_layers_per_virtual_pipeline_stage",
            "num_virtual_stages_per_pipeline_rank",
            "overlap_p2p_comm",
            "align_param_gather",
        )
    )

    def __init__(
        self,
        feature_name: str = "pipeline-model-parallel-layout",
        optimization_level: int = 0,
    ):
        super().__init__(feature_name, optimization_level)

    def _is_enabled(self, args) -> bool:
        """Return True when a layout was given and ``args`` carries a pipeline size."""
        return (
            isinstance(args, Namespace)
            and getattr(args, self._LAYOUT_DEST, None) is not None
            and hasattr(args, "pipeline_model_parallel_size")
        )

    @staticmethod
    def _num_stages(layout: str) -> int:
        """Return the number of pipeline stages described by the layout string."""
        # Imported lazily (as before) so the layout module is only loaded once
        # the feature is actually in use.
        from mindspeed.core.pipeline_parallel.pipeline_model_parallel_layout.layout import (
            PipelineParallelLayerLayout,
        )

        return PipelineParallelLayerLayout.get_num_stages_from_str(layout)

    def register_args(self, parser: ArgumentParser) -> None:
        """Add ``--pipeline-model-parallel-layout`` (string, default None) to ``parser``."""
        group = parser.add_argument_group(title="pipeline-model-parallel-layout")
        group.add_argument(
            "--pipeline-model-parallel-layout",
            type=str,
            default=None,
            help=(
                "A string that describes a custom pipeline model parallel layout. "
                'e.g., "E|(t|)*3,m|m||L". E, L, t, m denotes embedding, loss, transformer '
                'decoder layer, and mtp layer, respectively. Stages are split by "|". '
                "Replicated stages or layers can be described with multiplication. "
                "Commas can be used cosmetically. "
                "Default None is not using this argument to set the layout."
            ),
        )

    def pre_validate_args(self, args: Namespace) -> None:
        """Rewrite layout-related arguments before the core validation runs.

        Snapshots ``_SNAPSHOT_KEYS`` under ``_STATE_ATTR`` and then:

        * clears ``num_layers_per_virtual_pipeline_stage``;
        * sets ``num_virtual_stages_per_pipeline_rank`` to
          ``num_stages // pipeline_model_parallel_size`` when that is > 1,
          otherwise to None;
        * sets ``decoder_first_pipeline_num_layers`` to 1 when neither decoder
          first/last stage layer count was given.

        Does nothing unless the layout is set on a Namespace that has
        ``pipeline_model_parallel_size``.
        """
        if not self._is_enabled(args):
            return
        setattr(args, self._STATE_ATTR, {key: getattr(args, key, None) for key in self._SNAPSHOT_KEYS})

        args.num_layers_per_virtual_pipeline_stage = None
        # Floor division: divisibility is only enforced later, in validate_args.
        detected_vpp_size = self._num_stages(args.pipeline_model_parallel_layout) // args.pipeline_model_parallel_size
        args.num_virtual_stages_per_pipeline_rank = detected_vpp_size if detected_vpp_size > 1 else None

        if (
            getattr(args, "decoder_first_pipeline_num_layers", None) is None
            and getattr(args, "decoder_last_pipeline_num_layers", None) is None
        ):
            # NOTE(review): presumably marks the pipeline as uneven so that the
            # core "num_layers divisible by pipeline size" check is skipped; the
            # original value is restored by post_validate_args.
            args.decoder_first_pipeline_num_layers = 1

    def post_validate_args(self, args: Namespace) -> None:
        """Restore the snapshot entries that ``validate_args`` does not handle.

        With the current key sets these are ``decoder_first_pipeline_num_layers``
        and ``decoder_last_pipeline_num_layers``. Does nothing when no snapshot
        was taken.
        """
        state = getattr(args, self._STATE_ATTR, None)
        if state is None or not isinstance(args, Namespace) or not hasattr(args, "pipeline_model_parallel_size"):
            return
        for key, value in state.items():
            if key not in self._RESTORED_IN_VALIDATE:
                setattr(args, key, value)

    @staticmethod
    def _check_incompatible_options(args: Namespace) -> None:
        """Raise AssertionError if the layout is combined with an unsupported option."""
        if getattr(args, "schedules_method", None) == "dualpipev":
            raise AssertionError("--pipeline-model-parallel-layout is incompatible with --schedules-method dualpipev.")
        if getattr(args, "pipeline_num_transformer_layers", None) is not None:
            raise AssertionError(
                "--pipeline-model-parallel-layout is incompatible with --pipeline-num-transformer-layers."
            )
        if getattr(args, "noop_layers", None) is not None:
            raise AssertionError("--pipeline-model-parallel-layout is incompatible with --noop-layers.")
        if getattr(args, "recompute_in_bubble", False) or getattr(args, "recompute_in_advance", False):
            raise AssertionError(
                "--pipeline-model-parallel-layout is not supported with "
                "--recompute-in-bubble or --recompute-in-advance now."
            )

    def validate_args(self, args: Namespace) -> None:
        """Validate the layout and derive the virtual pipeline configuration.

        Restores the user's original virtual-pipeline and p2p arguments from the
        snapshot taken by ``pre_validate_args`` (falling back to the current
        values on ``args`` when there is none). It then sets
        ``virtual_pipeline_model_parallel_size`` from the layout, using None when
        each rank holds a single stage.

        Checks are raised explicitly rather than with ``assert`` so they still
        run under ``python -O``.

        Raises:
            AssertionError: if the layout is combined with an incompatible
                option, if more than one virtual-pipeline specification is
                given, if the stage count is not divisible by
                ``pipeline_model_parallel_size``, or if the pipeline size is too
                small for the interleaved schedule.
        """
        if not self._is_enabled(args):
            return
        self._check_incompatible_options(args)

        state = getattr(args, self._STATE_ATTR, {})
        num_layers_per_virtual_pipeline_stage = state.get(
            "num_layers_per_virtual_pipeline_stage",
            getattr(args, "num_layers_per_virtual_pipeline_stage", None),
        )
        num_virtual_stages_per_pipeline_rank = state.get(
            "num_virtual_stages_per_pipeline_rank",
            getattr(args, "num_virtual_stages_per_pipeline_rank", None),
        )
        specified = (
            num_layers_per_virtual_pipeline_stage,
            num_virtual_stages_per_pipeline_rank,
            args.pipeline_model_parallel_layout,
        )
        if sum(value is not None for value in specified) > 1:
            raise AssertionError(
                "No more than one of the following arguments can be set at the same time: "
                "--num-layers-per-virtual-pipeline-stage, --num-virtual-stages-per-pipeline-rank, "
                "--pipeline-model-parallel-layout. "
                f"{num_layers_per_virtual_pipeline_stage=}, "
                f"{num_virtual_stages_per_pipeline_rank=}, "
                f"{args.pipeline_model_parallel_layout=}."
            )
        if state:
            # Undo the temporary values written by pre_validate_args.
            args.num_layers_per_virtual_pipeline_stage = num_layers_per_virtual_pipeline_stage
            args.num_virtual_stages_per_pipeline_rank = num_virtual_stages_per_pipeline_rank

        num_stages = self._num_stages(args.pipeline_model_parallel_layout)
        if num_stages % args.pipeline_model_parallel_size != 0:
            raise AssertionError(
                f"The length of pipeline_model_parallel_layout must be divisible"
                f" by pipeline_model_parallel_size ({num_stages=},"
                f" {args.pipeline_model_parallel_size=})"
            )
        args.virtual_pipeline_model_parallel_size = num_stages // args.pipeline_model_parallel_size
        if args.virtual_pipeline_model_parallel_size == 1:
            # One stage per rank: plain (non-interleaved) pipeline parallelism.
            args.virtual_pipeline_model_parallel_size = None
        elif getattr(args, "optimize_send_recv_comm", False):
            raise AssertionError(
                "--pipeline-model-parallel-layout with virtual pipeline is incompatible with --optimize-send-recv-comm."
            )

        original_overlap_p2p_comm = state.get("overlap_p2p_comm", getattr(args, "overlap_p2p_comm", None))
        original_align_param_gather = state.get("align_param_gather", getattr(args, "align_param_gather", None))
        if original_overlap_p2p_comm is not None:
            args.overlap_p2p_comm = original_overlap_p2p_comm
        if original_align_param_gather is not None:
            args.align_param_gather = original_align_param_gather

        if args.virtual_pipeline_model_parallel_size is not None:
            if args.overlap_p2p_comm:
                if args.pipeline_model_parallel_size <= 1:
                    raise AssertionError(
                        "When interleaved schedule is used, pipeline-model-parallel size should be greater than 1"
                    )
            elif args.pipeline_model_parallel_size <= 2:
                raise AssertionError(
                    "When interleaved schedule is used and p2p communication overlap is disabled, "
                    "pipeline-model-parallel size should be greater than 2 to avoid having multiple "
                    "p2p sends and recvs between same 2 ranks per communication batch"
                )
        else:
            # Without virtual pipelining, p2p overlap and param-gather alignment are forced off.
            args.overlap_p2p_comm = False
            args.align_param_gather = False

    def register_patches(
        self,
        patch_manager: MindSpeedPatchesManager,
        args: Namespace,
    ) -> None:
        """Register the adaptor patches when a non-empty layout was given.

        Replaces ``megatron.core.transformer.enums.LayerType`` with the adaptor
        version. Also replaces ``TransformerConfig.__post_init__``,
        ``get_num_layers_to_build`` and ``get_transformer_layer_offset`` with
        the adaptor's wrappers.
        """
        if not getattr(args, self._LAYOUT_DEST, None):
            return
        from mindspeed.core.pipeline_parallel.pipeline_model_parallel_layout.adaptor import (
            LayerType,
            get_num_layers_to_build_wrapper,
            get_transformer_layer_offset_wrapper,
            transformer_config_post_init_wrapper,
        )

        patch_manager.register_patch("megatron.core.transformer.enums.LayerType", LayerType)
        patch_manager.register_patch(
            "megatron.core.transformer.transformer_config.TransformerConfig.__post_init__",
            transformer_config_post_init_wrapper,
        )
        patch_manager.register_patch(
            "megatron.core.transformer.transformer_block.get_num_layers_to_build",
            get_num_layers_to_build_wrapper,
        )
        patch_manager.register_patch(
            "megatron.core.transformer.transformer_layer.get_transformer_layer_offset",
            get_transformer_layer_offset_wrapper,
        )