from argparse import ArgumentParser
from mindspeed.features_manager.feature import MindSpeedFeature
class MoEFixRouterFeature(MindSpeedFeature):
    """Feature plugin behind the ``--fix-router`` command-line flag.

    When the flag is set, Megatron's
    ``megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity``
    is replaced by the MindSpeed implementation of the same name.
    """

    def __init__(self):
        """Register the feature under the name ``'fix-router'``."""
        # NOTE(review): the second positional argument (2) is forwarded to
        # MindSpeedFeature as-is; presumably an optimization level -- confirm
        # against the base class.
        super().__init__('fix-router', 2)

    def register_args(self, parser: ArgumentParser):
        """Add the ``--fix-router`` boolean flag to ``parser``.

        Args:
            parser: Argument parser that receives a new argument group titled
                with this feature's name.
        """
        group = parser.add_argument_group(title=self.feature_name)
        group.add_argument(
            "--fix-router",
            action='store_true',
            default=False,
            help='Enable the fix-router MoE patch, which replaces '
                 'topk_softmax_with_capacity with the MindSpeed implementation '
                 '(requires expert_model_parallel_size > 1).',
        )

    def validate_args(self, args):
        """Reject configurations where fix-router is on without expert parallelism.

        Args:
            args: Parsed arguments; reads ``fix_router`` and
                ``expert_model_parallel_size``.

        Raises:
            AssertionError: If ``fix_router`` is set and
                ``expert_model_parallel_size`` is not greater than 1.
        """
        if args.fix_router and args.expert_model_parallel_size <= 1:
            raise AssertionError('when enable fix-router, expert_model_parallel_size must be greater than 1')

    def register_patches(self, patch_manager, args):
        """Register the ``topk_softmax_with_capacity`` patch when the flag is set.

        Args:
            patch_manager: Object exposing ``register_patch(target, replacement)``.
            args: Parsed arguments; reads ``fix_router``.
        """
        if not args.fix_router:
            return

        # Imported only when the feature is enabled, so a disabled feature
        # neither pays for the import nor fails on it.
        from mindspeed.core.transformer.moe.moe_utils import topk_softmax_with_capacity

        patch_manager.register_patch(
            'megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
            topk_softmax_with_capacity,
        )