from mindspeed.features_manager.feature import MindSpeedFeature
class MoERouter(MindSpeedFeature):
    """Feature plugin that registers MoE router CLI options, checks and patches.

    The feature is registered under the name ``moe_router`` with
    ``optimization_level=0``. It contributes command-line arguments, validates
    the parsed arguments in three phases (pre / main / post) and registers the
    router-related patches on Megatron targets through the patch manager.
    """

    def __init__(self):
        super().__init__(feature_name="moe_router", optimization_level=0)
        # Created up front so post_validate_args() never hits a missing
        # attribute when pre_validate_args() was not called on this instance.
        self.origin_spec = None

    def register_args(self, parser):
        """Add the MoE router command-line options to ``parser``.

        Args:
            parser: The argument parser being populated; a dedicated argument
                group titled with this feature's name is created on it.
        """
        group = parser.add_argument_group(title=self.feature_name)

        # NOTE(review): add_parser_argument_choices_value is defined outside
        # this file; presumably it appends a value to the ``choices`` of an
        # already-registered option -- confirm against the base class.
        self.add_parser_argument_choices_value(parser, "--moe-router-score-function", 'sqrtsoftplus')
        extra_balancing_types = (
            'group_limited_greedy',
            'softmax_topk',
            'pai_megatron_aux_loss',
            'sparsemixer_topk',
        )
        for balancing_type in extra_balancing_types:
            self.add_parser_argument_choices_value(
                parser, "--moe-router-load-balancing-type", balancing_type)

        group.add_argument('--norm-topk-prob', action='store_true', default=False,
                           help='Normalize the topk weight')
        group.add_argument('--seq-aux', action='store_true', default=False,
                           help='Compute aux loss in seq_aux')
        group.add_argument('--moe-device-level-aux-loss-coeff', type=float, default=0.,
                           help='set the coeff for device-level balance loss in deepseek moe')
        group.add_argument('--moe-comm-aux-loss-coeff', type=float, default=0.,
                           help='set the coeff for communication balance loss in deepseek moe')
        group.add_argument('--router-gating-in-fp32', action='store_true', default=False,
                           help='Compute router gating in float32.')
        group.add_argument("--moe-revert-type-after-topk", action='store_true',
                           help="revert the type of logits after the topk has been computed")
        group.add_argument("--fix-router", action='store_true',
                           help="fix router for load balancing.")
        group.add_argument("--topk-softmax-in-fp32", action='store_true',
                           help="topk softmax in fp32.")
        # A truthy value switches register_patches() to the zero-experts
        # MoELayer patches; the previous help text was copied from --num-experts.
        group.add_argument('--num-zero-experts', type=int, default=None,
                           help='Number of zero experts in MoE (None means zero experts are not used)')
        group.add_argument('--n-hash-layers', type=int, default=0, help='expert hash layer num')

    def pre_validate_args(self, args):
        """Stash ``args.spec`` on the instance and clear it on ``args``.

        The stashed value is put back by ``post_validate_args``.
        NOTE(review): presumably this hides a custom spec from the validation
        that runs in between -- confirm with the feature manager.
        """
        self.origin_spec = args.spec
        args.spec = None

    def validate_args(self, args):
        """Run every MoE-router argument check in order.

        Raises:
            ValueError: On invalid capacity-factor or expert-bias settings.
            AssertionError: On invalid shared-expert gate or
                ``group_limited_greedy`` settings.
        """
        self._validate_moe_args(args)
        self._validate_group_limited_greedy(args)
        self._validate_aux_loss_free(args)

    def post_validate_args(self, args):
        """Restore ``args.spec`` if a truthy value was stashed beforehand."""
        if self.origin_spec:
            args.spec = self.origin_spec

    def _validate_moe_args(self, args):
        """Check capacity-factor and shared-expert-gate related arguments.

        A negative ``moe_expert_capacity_factor`` is reset to ``None`` (and a
        message is printed via ``print_rank0_by_args``) before the remaining
        checks run.

        Raises:
            ValueError: If the capacity factor is combined with the
                ``allgather`` dispatcher, with a load-balancing type other
                than ``aux_loss``/``none``, or if padding to capacity is
                requested after the capacity factor was reset to ``None``.
            AssertionError: If ``shared_expert_gate_output_dimension`` is
                neither ``1`` nor ``hidden_size``.
        """
        from mindspeed_llm.training.utils import print_rank0_by_args

        # NOTE(review): the nesting below was reconstructed from a source whose
        # indentation was lost; all capacity-factor checks are assumed to live
        # under this guard -- confirm against the upstream file.
        if args.moe_expert_capacity_factor is not None:
            if args.moe_token_dispatcher_type == "allgather":
                raise ValueError('moe_expert_capacity_factor not works with allgather token dispatcher')
            if args.moe_expert_capacity_factor < 0:
                args.moe_expert_capacity_factor = None
                print_rank0_by_args(
                    'When moe_expert_capacity_factor < 0, no token would be drop, '
                    'so moe_expert_capacity_factor should be set to false.')
            if args.moe_router_load_balancing_type not in ["aux_loss", "none"]:
                raise ValueError('moe_expert_capacity_factor only works with aux_loss or none load balancing')
            if args.moe_expert_capacity_factor is None and args.moe_pad_expert_input_to_capacity:
                raise ValueError('moe_expert_capacity_factor must be set to use moe_pad_expert_input_to_capacity')

        gate_dim = args.shared_expert_gate_output_dimension
        if gate_dim != 1 and gate_dim != args.hidden_size:
            raise AssertionError('shared expert gate output dimension can only be configured with 1 or hidden_size')

    def _validate_group_limited_greedy(self, args):
        """Check the arguments required by ``group_limited_greedy`` balancing.

        Does nothing for any other load-balancing type.

        Raises:
            AssertionError: If ``moe_router_group_topk`` or
                ``moe_router_topk_scaling_factor`` is unset, or if the group
                top-k is not strictly less than ``expert_model_parallel_size``.
        """
        if args.moe_router_load_balancing_type != "group_limited_greedy":
            return

        if args.moe_router_group_topk is None:
            raise AssertionError('The parameter topk-group should be set when use group_limited_greedy.')
        if args.moe_router_topk_scaling_factor is None:
            # The message used to name multi_latent_attention, which is not
            # what this branch is guarding.
            raise AssertionError(
                'The parameter --moe-router-topk-scaling-factor should be set when use group_limited_greedy.')
        if args.moe_router_group_topk >= args.expert_model_parallel_size:
            raise AssertionError(
                f'The topk group ({args.moe_router_group_topk}) should be less than '
                f'n-group(EP)({args.expert_model_parallel_size}).')

    def _validate_aux_loss_free(self, args):
        """Reject expert bias combined with an unsupported score function.

        Raises:
            ValueError: If ``moe_router_enable_expert_bias`` is truthy and
                ``moe_router_score_function`` is not one of ``sigmoid``,
                ``sqrtsoftplus`` or ``softmax``.
        """
        supported_score_functions = {"sigmoid", "sqrtsoftplus", "softmax"}
        if not args.moe_router_enable_expert_bias:
            return
        if args.moe_router_score_function in supported_score_functions:
            return
        # Sorted so the message text is identical from run to run (set
        # iteration order of strings is not stable across interpreter runs).
        raise ValueError(
            f"Expert bias for aux-loss-free routing only supports {sorted(supported_score_functions)}. "
            f"Got: '{args.moe_router_score_function}'. "
            f"Please set --moe-router-score-function to one of: sigmoid, sqrtsoftplus, softmax."
        )

    def register_patches(self, patch_manager, args):
        """Register the router / MoE-layer patches on their Megatron targets.

        Args:
            patch_manager: Object exposing ``register_patch(target, replacement)``.
            args: Parsed arguments; ``num_zero_experts``,
                ``moe_allgather_overlap_comm``, ``moe_alltoall_overlap_comm``,
                ``topk_softmax_in_fp32`` and ``n_hash_layers`` select the
                conditional patches.
        """
        from mindspeed_llm.core.transformer.moe.router import (topk_router_routing, topk_router_init_wrapper,
                                                               topk_router_gating_func, topk_router_forward_patch,
                                                               transformer_config_post_init_wrapper)
        from mindspeed_llm.core.transformer.moe.moe_utils import z_loss_func, topk_softmax_with_capacity
        from mindspeed_llm.core.transformer.moe.moe_layer import moe_layer_forward
        # Not referenced below; kept because the original imported it here.
        # NOTE(review): confirm whether anything relies on this import.
        from megatron.core.transformer.transformer_config import TransformerConfig
        from mindspeed_llm.core.distributed.finalize_model_grads import _update_router_expert_bias_for_patch
        from mindspeed_llm.core.transformer.transformer_block import _checkpointed_forward_patch_input_ids

        # Patches that are always applied.
        patch_manager.register_patch('megatron.core.transformer.transformer_config.TransformerConfig.__post_init__',
                                     transformer_config_post_init_wrapper)
        patch_manager.register_patch('megatron.core.transformer.moe.router.TopKRouter.__init__',
                                     topk_router_init_wrapper)
        patch_manager.register_patch('megatron.core.transformer.moe.router.TopKRouter.routing',
                                     topk_router_routing)
        patch_manager.register_patch('megatron.core.transformer.moe.router.TopKRouter.forward',
                                     topk_router_forward_patch)
        patch_manager.register_patch('megatron.core.transformer.moe.router.TopKRouter.gating',
                                     topk_router_gating_func)
        patch_manager.register_patch('megatron.core.transformer.moe.router.z_loss_func',
                                     z_loss_func)

        # MoELayer patches: the zero-experts variant wins; otherwise the plain
        # forward patch is used only when neither overlap-comm mode is enabled.
        if args.num_zero_experts:
            from mindspeed_llm.tasks.models.transformer.zero_experts_moe_layer import (
                zero_experts_moe_layer_init_wrapper, zero_experts_moe_forward)
            patch_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer.__init__',
                                         zero_experts_moe_layer_init_wrapper)
            patch_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer.forward',
                                         zero_experts_moe_forward)
        elif not (args.moe_allgather_overlap_comm or args.moe_alltoall_overlap_comm):
            patch_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer.forward',
                                         moe_layer_forward)

        # NOTE(review): assumed to be unconditional (function level) -- the
        # source indentation was lost; confirm against the upstream file.
        patch_manager.register_patch('megatron.core.distributed.finalize_model_grads._update_router_expert_bias',
                                     _update_router_expert_bias_for_patch)

        if args.topk_softmax_in_fp32:
            patch_manager.register_patch('megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity',
                                         topk_softmax_with_capacity)

        if args.n_hash_layers >= 1:
            patch_manager.register_patch(
                'megatron.core.transformer.transformer_block.TransformerBlock._checkpointed_forward',
                _checkpointed_forward_patch_input_ids)