from argparse import ArgumentParser
from mindspeed.features_manager.feature import MindSpeedFeature


class MambaModel(MindSpeedFeature):
    """Feature plugin that contributes Mamba CLI options and patches.

    The feature is registered under the name ``"mamba"`` with
    optimization level 0.
    """

    def __init__(self):
        super().__init__(feature_name="mamba", optimization_level=0)

    def register_args(self, parser: ArgumentParser):
        """Add the ``--mamba-*`` integer options to ``parser``.

        The options are placed in an argument group titled with this
        feature's name.

        Args:
            parser: Argument parser that receives the new option group.
        """
        mamba_group = parser.add_argument_group(title=self.feature_name)

        # Every option is an int; only flag, default and help text vary.
        int_options = (
            ('--mamba-d-ssm', None,
             'If not None, only apply SSM on this many dimensions, the rest uses gated MLP'),
            ('--mamba-chunk-size', 256,
             'Split the chunk size of tensor in mamba'),
            ('--mamba-d-conv', 4,
             'conv channel dim for mamba'),
            ('--mamba-expand', 1,
             'expand scale for mamba'),
        )
        for flag, default_value, help_text in int_options:
            mamba_group.add_argument(flag, type=int, default=default_value, help=help_text)

    def register_patches(self, patch_manager, args):
        """Register the Mamba replacements with ``patch_manager``.

        Args:
            patch_manager: Object exposing ``register_patch``; every patch
                below is registered through it.
            args: Not used by this method.
        """
        # Imported lazily so the mindspeed_llm SSM modules are only loaded
        # when patches are actually being registered.
        from mindspeed_llm.core.ssm.mamba_mixer import (
            mamba_mixer_init_wrapper,
            mamba_mixer_forward,
            Mamba2RMSNorm,
        )
        from mindspeed_llm.core.ssm.mamba_block import mamba_block_forward

        register = patch_manager.register_patch

        # mamba_ssm targets are registered with create_dummy=True.
        # NOTE(review): presumably this lets the target be stubbed when
        # mamba_ssm is not importable -- confirm against the patch manager.
        register('mamba_ssm.ops.triton.layernorm_gated.RMSNorm',
                 Mamba2RMSNorm, create_dummy=True)
        dummy_only_targets = (
            'mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined',
            'mamba_ssm.ops.triton.ssd_combined.mamba_split_conv1d_scan_combined',
        )
        for target in dummy_only_targets:
            register(target, create_dummy=True)

        # Megatron mixer/stack entry points and their replacements.
        megatron_patches = (
            ('megatron.core.ssm.mamba_mixer.MambaMixer.__init__', mamba_mixer_init_wrapper),
            ('megatron.core.ssm.mamba_mixer.MambaMixer.forward', mamba_mixer_forward),
            ('megatron.core.ssm.mamba_block.MambaStack.forward', mamba_block_forward),
        )
        for target, replacement in megatron_patches:
            register(target, replacement)