"""Define multi parameter feature of pipeline parallel training.
Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
"""
from argparse import ArgumentParser, Namespace
import torch
from mindspeed.features_manager.feature import MindSpeedFeature
from mindspeed.patch_utils import MindSpeedPatchesManager
class MultiParameterFeature(MindSpeedFeature):
    """Multi parameter feature of pipeline parallel training.

    Adds the ``--use-multiparameter-pipeline-model-parallel`` switch,
    fills in a default ``args.pipeline_tensor_shapes`` when none was
    supplied, and registers replacements for the Megatron pipeline
    schedule functions listed in :meth:`register_patches`.
    """

    def __init__(
        self,
        feature_name: str = "use-multiparameter-pipeline-model-parallel",
        optimization_level: int = 2,
    ):
        """Forward the feature name and optimization level to the base class.

        Args:
            feature_name: Name of the feature / command line switch.
            optimization_level: Level passed through to ``MindSpeedFeature``.
        """
        super().__init__(feature_name, optimization_level)

    def register_args(self, parser: ArgumentParser) -> None:
        """Register the command line switch that enables this feature.

        Args:
            parser: Parser that receives a new argument group named after
                ``self.feature_name``.
        """
        group = parser.add_argument_group(title=self.feature_name)
        group.add_argument(
            "--use-multiparameter-pipeline-model-parallel",
            action="store_true",
            default=False,
            help="can transfer multi parameters from "
            "stage to stage in pipeline model parallel",
        )

    def validate_args(self, args) -> None:
        """Check compatibility and default ``args.pipeline_tensor_shapes``.

        When the feature is enabled and ``args.pipeline_tensor_shapes`` is
        missing or ``None``, it is set to a single-entry list holding the
        shape ``(seq_length // context_parallel_size, micro_batch_size,
        hidden_size)`` and a dtype chosen from the ``bf16`` / ``fp16`` flags
        (``torch.float32`` when neither is set).

        Args:
            args: Parsed arguments; ``pipeline_tensor_shapes`` may be
                written on it.

        Raises:
            AssertionError: If the feature is enabled together with
                ``schedules_method == "dualpipev"``.
        """
        # NOTE(review): incompatible_check comes from the base class;
        # presumably it rejects enabling this feature together with
        # moe_fb_overlap -- confirm against MindSpeedFeature.
        self.incompatible_check(args, 'moe_fb_overlap')

        if not getattr(args, "use_multiparameter_pipeline_model_parallel", False):
            return

        # The incompatibility does not depend on whether the tensor shapes
        # were already provided, so check it before looking at
        # pipeline_tensor_shapes; otherwise a pre-populated value would let
        # the unsupported combination through silently.
        if getattr(args, "schedules_method", False) == "dualpipev":
            raise AssertionError(
                "The dualpipev and use_multiparameter_pipeline_model_parallel are incompatible."
            )

        # Respect shapes that were configured before validation ran.
        if getattr(args, "pipeline_tensor_shapes", None) is not None:
            return

        # Floor division keeps the sequence dimension an exact integer
        # instead of going through a float intermediate.
        tensor_shape = (
            args.seq_length // args.context_parallel_size,
            args.micro_batch_size,
            args.hidden_size,
        )

        # bf16 takes precedence over fp16; fall back to full precision.
        if getattr(args, "bf16", False):
            dtype = torch.bfloat16
        elif getattr(args, "fp16", False):
            dtype = torch.float16
        else:
            dtype = torch.float32

        args.pipeline_tensor_shapes = [{"shape": tensor_shape, "dtype": dtype}]

    def register_patches(
        self,
        patch_manager: MindSpeedPatchesManager,
        args: Namespace,
    ) -> None:
        """Register the multi-parameter pipeline patches when enabled.

        Args:
            patch_manager: Manager whose ``register_patch`` is called once
                per patched target, in the order listed below.
            args: Parsed arguments; nothing is registered unless the
                attribute named ``self.feature_name`` is truthy.
        """
        from mindspeed.core.pipeline_parallel.multi_parameter.adaptor import (
            get_tensor_shapes_wrapper,
            forward_step_wrapper,
            mindspeed_backward_step,
            mindspeed_recv_forward,
            mindspeed_recv_backward,
            mindspeed_send_forward,
            mindspeed_send_backward,
            mindspeed_send_forward_recv_backward,
            mindspeed_send_backward_recv_forward,
            get_forward_backward_func_wrapper,
            core_transformer_config_from_args_wrapper,
        )

        # NOTE(review): this lookup only matches the argparse attribute if
        # the base class stores feature_name with '-' replaced by '_' --
        # confirm against MindSpeedFeature.__init__.
        if not getattr(args, self.feature_name, None):
            return

        schedules = "megatron.core.pipeline_parallel.schedules"
        # (patched target, replacement) pairs, registered in this order.
        replacements = (
            (f"{schedules}.get_tensor_shapes", get_tensor_shapes_wrapper),
            (f"{schedules}.forward_step", forward_step_wrapper),
            (f"{schedules}.backward_step", mindspeed_backward_step),
            (f"{schedules}.recv_forward", mindspeed_recv_forward),
            (f"{schedules}.recv_backward", mindspeed_recv_backward),
            (f"{schedules}.send_forward", mindspeed_send_forward),
            (f"{schedules}.send_backward", mindspeed_send_backward),
            (
                f"{schedules}.send_forward_recv_backward",
                mindspeed_send_forward_recv_backward,
            ),
            (
                f"{schedules}.send_backward_recv_forward",
                mindspeed_send_backward_recv_forward,
            ),
            (
                f"{schedules}.get_forward_backward_func",
                get_forward_backward_func_wrapper,
            ),
            (
                "megatron.training.arguments.core_transformer_config_from_args",
                core_transformer_config_from_args_wrapper,
            ),
        )
        for target, replacement in replacements:
            patch_manager.register_patch(target, replacement)