from argparse import ArgumentParser
from mindspeed.features_manager.feature import MindSpeedFeature
class MoESharedExpertsFeature(MindSpeedFeature):
def __init__(self):
super().__init__('n-shared-experts', 0)
def register_args(self, parser: ArgumentParser):
group = parser.add_argument_group(title=self.feature_name)
group.add_argument('--n-shared-experts', type=int, default=None,
help='DEPRECATED. use --moe-shared-expert-intermediate-size replace')
def validate_args(self, args):
if args.n_shared_experts and args.moe_shared_expert_intermediate_size and (args.moe_shared_expert_intermediate_size != args.n_shared_experts * (
args.moe_ffn_hidden_size if args.moe_ffn_hidden_size is not None else args.ffn_hidden_size)):
raise AssertionError('`n_shared_experts` cannot be used with `moe_shared_expert_intermediate_size` together. Please use one of them.')
if args.n_shared_experts and args.moe_shared_expert_intermediate_size is None:
args.moe_shared_expert_intermediate_size = args.n_shared_experts * (
args.moe_ffn_hidden_size if args.moe_ffn_hidden_size is not None else args.ffn_hidden_size)