from argparse import ArgumentParser
from mindspeed.features_manager.feature import MindSpeedFeature
from mindspeed.features_manager.fusions.fused_bias_swiglu import FusedSwigluFeature
class SwigluLimitFeature(FusedSwigluFeature):
    """Feature that optionally swaps in a clamped ("limited") fused SwiGLU.

    Builds on ``FusedSwigluFeature``. It registers a ``--swiglu-limit``
    command-line option. When that option is a positive number, it patches
    ``mindspeed.core.fusions.fused_bias_swiglu.fused_swiglu`` with
    ``mindspeed_llm.core.fusions.fused_bias_swiglu.fused_swiglu_with_limit``.
    """

    def __init__(self):
        super().__init__()
        # Replace the name set by the parent constructor so this feature
        # gets its own identity and argument-group title.
        self.feature_name = 'swiglu-limit'

    def register_args(self, parser: ArgumentParser):
        """Add the ``--swiglu-limit`` option to *parser*.

        The option is placed in an argument group titled with
        ``self.feature_name``. It is parsed as a float, and its default of 0
        means "no limit".
        """
        group = parser.add_argument_group(title=self.feature_name)
        group.add_argument('--swiglu-limit', type=float, default=0,
                           help='Apply swiglu limit to clamp gate and up values. '
                                'When > 0, gate is clamped to max=limit and up is clamped to [-limit, limit]. '
                                'Default is 0 (no limit).')

    def register_patches(self, patch_manager, args):
        """Register the parent's patches, plus the limited SwiGLU when enabled.

        Args:
            patch_manager: Object whose ``register_patch(target, replacement)``
                is called to record the patch.
            args: Parsed arguments; ``args.swiglu_limit`` is read if present.

        The limited kernel is patched in only when ``swiglu_limit`` is a
        positive number, as the option's help text states ("When > 0").
        A value of 0, ``None``, a negative value, or a missing attribute all
        leave the parent's SwiGLU registration untouched.
        """
        super().register_patches(patch_manager, args)
        swiglu_limit = getattr(args, 'swiglu_limit', None)
        # Explicit ``> 0`` check rather than truthiness: a negative limit
        # would otherwise enable the clamp with a nonsensical bound.
        # The ``is not None`` guard keeps this as None-safe as the original.
        if swiglu_limit is not None and swiglu_limit > 0:
            # Deferred import: mindspeed_llm's fused kernel is only loaded
            # when the feature is actually enabled.
            from mindspeed_llm.core.fusions.fused_bias_swiglu import fused_swiglu_with_limit
            patch_manager.register_patch('mindspeed.core.fusions.fused_bias_swiglu.fused_swiglu', fused_swiglu_with_limit)