from argparse import ArgumentParser
from mindspeed.features_manager.feature import MindSpeedFeature
class MoEZeroMemoryFeature(MindSpeedFeature):
    """Zero-Memory settings spec for MoE layers.

    This spec supports the "alltoall" and "alltoallseq" dispatchers.

    The feature contributes two command-line options
    (``--moe-zero-memory`` and ``--moe-zero-memory-num-layers``), validates
    their combination, and registers a replacement for
    ``SharedExpertMLP.forward`` when the feature is active.
    """

    def __init__(self):
        # The second positional argument is forwarded as-is to
        # MindSpeedFeature.__init__.
        # NOTE(review): presumably the optimization level -- confirm against
        # the base class.
        super().__init__('moe-zero-memory', 2)

    def register_args(self, parser: ArgumentParser):
        """Add the moe-zero-memory options to ``parser``.

        Args:
            parser: Argument parser that receives a new argument group titled
                with this feature's name.
        """
        group = parser.add_argument_group(title=self.feature_name)
        group.add_argument("--moe-zero-memory", type=str, default='disable',
                           choices=['disable', 'level0', 'level1'],
                           help="Set level for saving activation memory in moe layer.")
        # Adjacent string literals are concatenated verbatim, so the separating
        # space must be written explicitly inside one of them.
        group.add_argument('--moe-zero-memory-num-layers', type=int, default=None,
                           help='the number of layers using moe-zero-memory level1 '
                                'in each pp stage.')

    def pre_validate_args(self, args):
        """Check that the moe-zero-memory options are mutually consistent.

        Args:
            args: Parsed argument namespace. Reads ``moe_zero_memory``,
                ``moe_zero_memory_num_layers``, ``num_layers``,
                ``pipeline_model_parallel_size``,
                ``moe_alltoall_overlap_comm`` and ``moe_fb_overlap``.

        Raises:
            AssertionError: If ``moe_zero_memory_num_layers`` is outside
                ``[0, num_layers // pipeline_model_parallel_size]``, if it is
                set while ``moe_zero_memory`` is ``"disable"``, or if
                ``moe_zero_memory`` is enabled without either
                ``moe_alltoall_overlap_comm`` or ``moe_fb_overlap``.
        """
        if args.moe_zero_memory_num_layers is not None:
            num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
            if args.moe_zero_memory_num_layers < 0 or args.moe_zero_memory_num_layers > num_layers_per_pipeline_stage:
                raise AssertionError('`--moe-zero-memory-num-layers` must be between 0 and num layers per pipeline stage')
            # A per-stage layer count is only meaningful when the feature is on.
            if args.moe_zero_memory == "disable":
                raise AssertionError('`--moe-zero-memory` must be enabled when using `--moe-zero-memory-num-layers`')
        if args.moe_zero_memory != "disable" and not (args.moe_alltoall_overlap_comm or args.moe_fb_overlap):
            raise AssertionError('`--moe-zero-memory` only support `--moe-alltoall-overlap-comm` or `--moe-fb-overlap` for now.')

    def register_patches(self, patch_manager, args):
        """Register the zero-memory shared-expert forward patch when enabled.

        The patch is registered only when ``args.moe_zero_memory`` is not
        ``'disable'`` and ``args.moe_alltoall_overlap_comm`` is truthy.

        Args:
            patch_manager: Object exposing ``register_patch(target, replacement)``.
            args: Parsed argument namespace.
        """
        if args.moe_zero_memory != 'disable' and args.moe_alltoall_overlap_comm:
            # Imported lazily so the overlap module is loaded only when the
            # patch is actually going to be applied.
            from mindspeed.core.transformer.moe.moe_feature.overlap.experts import zero_memory_shared_expert_mlp_forward
            patch_manager.register_patch(
                'megatron.core.transformer.moe.shared_experts.SharedExpertMLP.forward',
                zero_memory_shared_expert_mlp_forward)