import re
from argparse import ArgumentParser
from mindspeed.features_manager.feature import MindSpeedFeature
class MoEFwdBwdOverlapFeature(MindSpeedFeature):
    """MindSpeed feature wiring for ``--moe-fb-overlap`` (MoE forward/backward overlap).

    Registers the feature's command-line flags, validates that the rest of the
    configuration is compatible with it, and registers the Megatron patches that
    implement it when the flag is enabled.
    """

    # Repetition syntax of --pipeline-model-parallel-layout, expanded in this order:
    # "(group)*n" first, then single-character "x*n".
    _LAYOUT_REPEAT_PATTERNS = (
        re.compile(r"\(([^)]+)\)\*(\d+)"),
        re.compile(r"(.)\*(\d+)"),
    )

    def __init__(self):
        super().__init__('moe-fb-overlap')

    @staticmethod
    def _has_pipeline_model_parallel_layout(args):
        """Return True if ``args.pipeline_model_parallel_layout`` is set (not None)."""
        return getattr(args, 'pipeline_model_parallel_layout', None) is not None

    @staticmethod
    def _get_pipeline_model_parallel_layout_stage_count(layout):
        """Return the number of pipeline stages described by a layout string.

        Commas are dropped, ``(group)*n`` and ``x*n`` repetitions are expanded, and
        the stages are the ``|``-separated segments of the expanded string.

        Args:
            layout: the ``--pipeline-model-parallel-layout`` string.

        Returns:
            The number of ``|``-separated stages after expansion.
        """
        layout = layout.replace(',', '')
        for pattern in MoEFwdBwdOverlapFeature._LAYOUT_REPEAT_PATTERNS:
            layout = pattern.sub(lambda match: match.group(1) * int(match.group(2)), layout)
        return len(layout.split('|'))

    @staticmethod
    def _has_virtual_pipeline(args):
        """Return True if the configuration uses a virtual pipeline (VPP).

        With a pipeline layout, VPP is in use when an explicit VPP size is set, or
        when the layout's stage count is a multiple of the PP size with more than
        one stage per pipeline rank. Without a layout, VPP is detected from
        ``num_layers_per_virtual_pipeline_stage`` being set.
        """
        if MoEFwdBwdOverlapFeature._has_pipeline_model_parallel_layout(args):
            if getattr(args, 'virtual_pipeline_model_parallel_size', None) is not None:
                return True
            pp_size = int(getattr(args, 'pipeline_model_parallel_size', 1))
            num_stages = MoEFwdBwdOverlapFeature._get_pipeline_model_parallel_layout_stage_count(
                args.pipeline_model_parallel_layout
            )
            return num_stages % pp_size == 0 and num_stages // pp_size > 1
        return getattr(args, 'num_layers_per_virtual_pipeline_stage', None) is not None

    @staticmethod
    def _validate_pipeline_model_parallel_layout_for_fb_overlap(args):
        """Reject pipeline layouts that fb overlap cannot handle.

        Does nothing unless both ``--moe-fb-overlap`` and a pipeline layout are set.

        Raises:
            AssertionError: if ``noop_layers`` is also set, or if any
                (pp_rank, vpp_rank) chunk of the layout contains no decoder layer.
        """
        if not args.moe_fb_overlap or not MoEFwdBwdOverlapFeature._has_pipeline_model_parallel_layout(args):
            return
        if getattr(args, 'noop_layers', None) is not None:
            raise AssertionError(
                '--noop-layers is not supported with --pipeline-model-parallel-layout and --moe-fb-overlap now.'
            )
        # Imported lazily, only on this path (presumably to avoid import cost or
        # cycles at feature-registration time).
        from mindspeed.core.pipeline_parallel.pipeline_model_parallel_layout.adaptor import LayerType
        from mindspeed.core.pipeline_parallel.pipeline_model_parallel_layout.layout import (
            PipelineParallelLayerLayout,
        )
        layout = PipelineParallelLayerLayout(
            args.pipeline_model_parallel_layout,
            args.pipeline_model_parallel_size,
        )
        # layout.layout is indexed as [pp_rank][vpp_rank] -> sequence of layer types.
        empty_decoder_chunks = []
        for pp_rank, pp_layout in enumerate(layout.layout):
            for vpp_rank, chunk in enumerate(pp_layout):
                if chunk.count(LayerType.decoder) == 0:
                    empty_decoder_chunks.append(f'pp_rank={pp_rank}, vpp_rank={vpp_rank}')
        if empty_decoder_chunks:
            raise AssertionError(
                '--moe-fb-overlap does not support --pipeline-model-parallel-layout '
                'with empty decoder chunks now. Empty chunks: ' + ', '.join(empty_decoder_chunks)
            )

    def register_args(self, parser: ArgumentParser):
        """Add ``--moe-fb-overlap`` and ``--moe-unperm2-mem-optim-swap`` to ``parser``."""
        group = parser.add_argument_group(title=self.feature_name)
        group.add_argument('--moe-fb-overlap', action='store_true')
        group.add_argument('--moe-unperm2-mem-optim-swap', action='store_true')

    def validate_args(self, args):
        """Validate that the parsed arguments are compatible with ``--moe-fb-overlap``.

        Raises:
            AssertionError: on an unsupported token dispatcher, expert-parallel
                configuration, pipeline layout or pipeline schedule, or when
                ``--moe-unperm2-mem-optim-swap`` is used without fb overlap.
            ValueError: in a VPP schedule with fb overlap, when
                ``global_batch_size // (micro_batch_size * pp_size * dp_size)``
                is not greater than 1.
        """
        # Conflict / dependency checks are implemented by the MindSpeedFeature base
        # class (not visible here).
        for incompatible_option in (
            'moe_alltoall_overlap_comm',
            'overlap_grad_reduce',
            'moe_hierarchical_alltoallv',
            'moe_zero_memory_num_layers',
            'use_nanopipe',
            'automated_pipeline',
            'recompute_in_bubble',
            'recompute_in_advance',
            'use_legacy_models',
            'moe_tp_extend_ep',
            'swap_attention',
        ):
            self.incompatible_check(args, incompatible_option)
        self.dependency_check(args, 'moe_grouped_gemm')

        if args.moe_fb_overlap and args.moe_token_dispatcher_type in ['allgather', 'alltoall_seq']:
            raise AssertionError('The fb overlap feature do not support allgather and alltoall_seq dispatcher.')
        if args.moe_fb_overlap and (args.expert_tensor_parallel_size != 1 or args.expert_model_parallel_size == 1):
            raise AssertionError(
                'fb overlap only support expert-tensor-parallel-size=1 and expert-model-parallel-size > 1'
            )
        if args.moe_unperm2_mem_optim_swap and not args.moe_fb_overlap:
            raise AssertionError('--moe-unperm2-mem-optim-swap currently only can be used with --moe-fb-overlap')

        self._validate_pipeline_model_parallel_layout_for_fb_overlap(args)

        # Only inspect the schedule (which may parse the layout string) when the
        # feature is actually enabled.
        if args.moe_fb_overlap:
            incorrect_schedule = (
                getattr(args, 'schedules_method', None) != 'dualpipev'
                and not self._has_virtual_pipeline(args)
                and int(getattr(args, 'pipeline_model_parallel_size', 1)) != 1
            )
            if incorrect_schedule:
                raise AssertionError('The fb overlap needs no pipeline, virtual pipeline or dualpipeV schedules.')

        if getattr(args, 'virtual_pipeline_model_parallel_size', None) is not None and args.moe_fb_overlap:
            if (
                args.global_batch_size
                // (args.micro_batch_size * args.pipeline_model_parallel_size * args.data_parallel_size)
                <= 1
            ):
                raise ValueError(f"""In VPP schedule,
fb_overlap needs global_batch_size // (micro_batch_size * pipeline_model_parallel_size * data_parallel_size) > 1.
The global_batch_size is {args.global_batch_size},
but the micro_batch_size is {args.micro_batch_size}, PP size is {args.pipeline_model_parallel_size},DP size is {args.data_parallel_size}.
""")

    def post_validate_args(self, args):
        """Check ``--noop-layers`` placement for fb overlap outside the dualpipeV schedule.

        With a virtual pipeline, every noop layer must lie in the last
        ``num_layers_per_virtual_pipeline_stage`` layers of the model.

        Raises:
            AssertionError: if noop layers are combined with a pipeline layout,
                if a noop layer lies before the last VPP stage, or if VPP is in
                use but ``num_layers_per_virtual_pipeline_stage`` is unset (the
                placement cannot be checked).
        """
        noop_layers = getattr(args, 'noop_layers', None)
        if (
            noop_layers is None
            or not args.moe_fb_overlap
            or getattr(args, 'schedules_method', None) == 'dualpipev'
        ):
            return
        if self._has_pipeline_model_parallel_layout(args):
            raise AssertionError(
                '--noop-layers is not supported with --pipeline-model-parallel-layout and --moe-fb-overlap now.'
            )
        layers_per_vpp_stage = getattr(args, 'num_layers_per_virtual_pipeline_stage', None)
        if layers_per_vpp_stage is None:
            if not self._has_virtual_pipeline(args):
                # No VPP (e.g. PP=1): the whole model is a single chunk, so noop
                # layers trivially sit in the last one.
                return
            raise AssertionError(
                'In VPP schedule with fb_overlap, --noop-layers requires '
                '--num-layers-per-virtual-pipeline-stage to locate the last VPP stage.'
            )
        # noop_layers may be an unordered collection (e.g. a set) or an unsorted
        # list, so take the smallest index explicitly rather than the first element.
        if min(noop_layers) < (args.num_layers - layers_per_vpp_stage):
            raise AssertionError('In VPP schedule with fb_overlap, the noop-layers must in last VPP stage.')

    def register_patches(self, patch_manager, args):
        """Register the fb-overlap Megatron patches when the feature flag is enabled.

        NOTE(review): ``getattr(args, self.feature_name)`` assumes the base class
        normalizes 'moe-fb-overlap' to the ``moe_fb_overlap`` attribute name —
        confirm in MindSpeedFeature.
        """
        if getattr(args, self.feature_name, None):
            from mindspeed.core.transformer.moe.moe_feature.fb_overlap import (
                linear_backward_wgrad_detach,
                transformer_block_fb_overlap_init_wrapper,
                mtp_block_fb_overlap_forward_wrapper,
                dualpipev_fb_overlap_mtp_layer_forward,
            )
            from mindspeed.core.transformer.moe.moe_feature.fb_overlap.adaptor import (
                _make_backward_post_hook,
                get_moe_module_spec_wrapper,
                get_forward_backward_func_vpp_overlap_wrapper,
            )
            patch_manager.register_patch(
                'megatron.core.models.gpt.moe_module_specs.get_moe_module_spec', get_moe_module_spec_wrapper
            )
            patch_manager.register_patch(
                'megatron.core.transformer.transformer_block.TransformerBlock.__init__',
                transformer_block_fb_overlap_init_wrapper,
            )
            patch_manager.register_patch(
                'megatron.core.tensor_parallel.layers.LinearWithGradAccumulationAndAsyncCommunication.backward',
                linear_backward_wgrad_detach,
            )
            patch_manager.register_patch(
                'megatron.core.distributed.distributed_data_parallel.DistributedDataParallel._make_backward_post_hook',
                _make_backward_post_hook,
            )
            # The overlapped forward/backward schedule replaces Megatron's schedule
            # selection only for the VPP and no-pipeline cases.
            if self._has_virtual_pipeline(args) or int(getattr(args, 'pipeline_model_parallel_size', 1)) == 1:
                patch_manager.register_patch(
                    'megatron.core.pipeline_parallel.schedules.get_forward_backward_func',
                    get_forward_backward_func_vpp_overlap_wrapper,
                )
            if getattr(args, 'mtp_num_layers', None):
                patch_manager.register_patch(
                    'megatron.core.transformer.multi_token_prediction.MultiTokenPredictionBlock.forward',
                    mtp_block_fb_overlap_forward_wrapper,
                )
                patch_manager.register_patch(
                    'megatron.core.transformer.multi_token_prediction.MultiTokenPredictionLayer.forward',
                    dualpipev_fb_overlap_mtp_layer_forward,
                )