from argparse import ArgumentParser
from mindspeed.features_manager.feature import MindSpeedFeature
class MoETpExtendEpFeature(MindSpeedFeature):
    """Feature toggle for `--moe-tp-extend-ep`.

    Registers the `--moe-tp-extend-ep` and `--moe-permutation-async-comm`
    command-line flags, validates the argument combinations that
    `--moe-tp-extend-ep` requires, and, for the `alltoall_seq` token
    dispatcher, patches Megatron's MoE router and MoE layer with the
    MindSpeed tp-extend-ep implementations.
    """

    def __init__(self):
        # NOTE(review): the meaning of the second argument (2) is defined by
        # MindSpeedFeature, which is not visible here — confirm against it.
        super().__init__('moe-tp-extend-ep', 2)

    def register_args(self, parser: ArgumentParser):
        """Add this feature's flags to *parser* in their own argument group."""
        group = parser.add_argument_group(title=self.feature_name)
        # The two literals below are implicitly concatenated; the trailing
        # space keeps the rendered help from reading "parallelisminstead".
        group.add_argument("--moe-tp-extend-ep", action='store_true',
                           help="use tp group to extend experts parallelism "
                                "instead of sharding weight tensor of experts in tp group")
        group.add_argument("--moe-permutation-async-comm", action='store_true',
                           help="overlap moe permutation 3 all gather communications")

    def validate_args(self, args):
        """Check that `--moe-tp-extend-ep` is used with a compatible configuration.

        Only runs checks when `args.moe_tp_extend_ep` is set.

        Raises:
            AssertionError: if `num_experts` is unset or not divisible by
                `tensor_model_parallel_size * expert_model_parallel_size`,
                if `--moe-permutation-async-comm` or `--moe-grouped-gemm` is
                missing, or if `moe_expert_capacity_factor` is not None.
        """
        if not args.moe_tp_extend_ep:
            return
        # Fail with a clear message instead of a TypeError from `None % int`.
        if args.num_experts is None:
            raise AssertionError('`--moe-tp-extend-ep` requires `--num-experts` to be set.')
        if args.num_experts % (args.tensor_model_parallel_size * args.expert_model_parallel_size) != 0:
            raise AssertionError('`--moe-tp-extend-ep` only support when num_experts % ( tp * ep ) == 0')
        if not (args.moe_permutation_async_comm and args.moe_grouped_gemm):
            raise AssertionError(
                '`--moe-tp-extend-ep` needs `--moe-permutation-async-comm` and `--moe-grouped-gemm`.')
        if args.moe_expert_capacity_factor is not None:
            raise AssertionError('`--moe-tp-extend-ep` only support when moe_expert_capacity_factor is None.')

    def register_patches(self, patch_manager, args):
        """Register Megatron patches for tp-extend-ep with the `alltoall_seq` dispatcher.

        When `moe_token_dispatcher_type == 'alltoall_seq'` and
        `moe_tp_extend_ep` is set, replaces `TopKRouter.routing` with
        `routing_tp_extend_ep`. Unless `moe_alltoall_overlap_comm` is enabled,
        it also replaces `MoELayer` with `MindSpeedAlltoAllSEQTptoEpMoELayer`.
        """
        from mindspeed.core.transformer.moe.moe_feature.adaptor import MindSpeedAlltoAllSEQTptoEpMoELayer
        from mindspeed.core.transformer.moe.moe_feature.common import routing_tp_extend_ep
        # Flags registered by other features may be absent from `args`;
        # treat a missing attribute as "not enabled" rather than crashing.
        if getattr(args, 'moe_token_dispatcher_type', None) != 'alltoall_seq':
            return
        if not args.moe_tp_extend_ep:
            return
        patch_manager.register_patch('megatron.core.transformer.moe.router.TopKRouter.routing',
                                     routing_tp_extend_ep)
        if not getattr(args, 'moe_alltoall_overlap_comm', False):
            patch_manager.register_patch(
                'megatron.core.transformer.moe.moe_layer.MoELayer',
                MindSpeedAlltoAllSEQTptoEpMoELayer)