from argparse import ArgumentParser
from mindspeed.features_manager.feature import MindSpeedFeature


class ModuleFeature(MindSpeedFeature):
    def __init__(self):
        super(ModuleFeature, self).__init__(feature_name="module-feature", optimization_level=0)
    
    def register_args(self, parser: ArgumentParser):
        group = parser.add_argument_group(title=self.feature_name)

        group.add_argument('--embedding-multiplier-scale', type=float, default=1.0,
                           help='add scale for embedding.')
        group.add_argument('--input-jitter', action='store_false',
                           help='Add noise to the input tensor.')
        group.add_argument('--post-norm', action='store_true',
                           help='post norm after attention or mlp.')
        group.add_argument('--output-multiplier-scale', type=float, default=None,
                           help='Add scale for logits output.')
        group.add_argument('--scale-emb', type=float, default=None,
                           help='scale embed tokens')
        group.add_argument('--dim-model-base', type=float, default=None,
                           help='dim-model-base')
        group.add_argument('--gelu-tanh', action='store_true', default=False,
                           help='Tanh Geglu activate function.')
        group.add_argument('--output-logit-softcapping', type=float,
                           help='output logit softcapping.')
        group.add_argument('--attn-logit-softcapping', type=float,
                           help='attention logit softcapping.')
        group.add_argument('--query-pre-attn-scalar', type=int,
                           help='attention scalar.')
        group.add_argument('--add-rmsnorm-offset', action='store_true', default=False,
                           help='RMSNorm unit offset.')
        group.add_argument('--input-embeds-norm', action='store_true', default=False,
                           help='input normalization.')
        group.add_argument("--cla-share-factor", type=int, default=1,
                           help="Cross-Layer Attention share kv between cla-share-factor layers")
        group.add_argument('--share-kvstates', action='store_true',
                           help='CLA share kv states.')
        group.add_argument("--input-layernorm-in-fp32", action='store_true',
                           help="Convert input-layernorm to fp32")
        group.add_argument("--skip-bias-add", action="store_false", default=True,
                           help='Configuration for the skip bias.')
        group.add_argument('--output-layer-slice-num', type=int, default=1,
                       help='Set the number of slices for the weight of the output_layer')
        group.add_argument('--geglu', action='store_true', default=False,
                           help='Geglu activate function.')
        group.add_argument('--no-post-layer-norm', action='store_true', default=False,
                           help='Disable final layer norm.')
        group.add_argument('--rmsnorm-weight-in-fp32', action='store_true', default=False,
                           help='rmsnorm weight in fp32')
        group.add_argument('--no-enable-linear-qkv', action='store_true', default=False,
                           help='no enable linear_qkv')
        group.add_argument('--fc-type', type=str, default=None,
                           help='Specifies the internal structure of the MLP module.')

    def register_patches(self, patch_manager, args):
        from mindspeed_llm.core.models.common.rms_norm import rms_norm_init_wrapper, rms_norm_forward
        patch_manager.register_patch('megatron.legacy.model.rms_norm.RMSNorm.__init__', rms_norm_init_wrapper)
        patch_manager.register_patch('megatron.legacy.model.rms_norm.RMSNorm.forward', rms_norm_forward)