#  Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.

import argparse

import logging as logger

import time

from mindspeed_llm.tasks.checkpoint.convert_hf2mg import Hf2MgConvert

from mindspeed_llm.tasks.checkpoint.convert_mg2hf import Mg2HfConvert

from mindspeed_llm.tasks.checkpoint.convert_ckpt_mamba2 import MambaConverter

from mindspeed_llm.tasks.checkpoint.convert_ckpt_longcat import LongCatConverter

from mindspeed_llm.tasks.checkpoint.convert_ckpt_deepseek4 import DeepSeek4Converter

from mindspeed_llm.training.utils import auto_coverage





def get_args():

    parser = argparse.ArgumentParser()

    parser.add_argument(

        '--load-model-type',

        type=str,

        nargs='?',

        default='hf',

        const=None,

        choices=['hf', 'mg'],

        help='Type of the converter',

    )

    parser.add_argument('--save-model-type', type=str, default='mg', choices=['mg', 'hf'], help='Save model type')

    parser.add_argument('--load-dir', type=str, required=True, help='Directory to load model checkpoint from')

    parser.add_argument('--save-dir', type=str, required=True, help='Directory to save model checkpoint to')

    parser.add_argument(

        '--model-type-hf',

        type=str,

        default="qwen3",

        choices=[

            'qwen3',

            'qwen3-moe',

            'deepseek3',

            'deepseek4',

            'glm45-air',

            'glm45',

            'bailing_mini',

            'qwen3-next',

            'seed-oss',

            'deepseek32',

            'magistral',

            'deepseek2-lite',

            'phi3.5',

            'mamba2',

            'longcat',

            'glm5',

            'deepseek4_base',

        ],

        help='model type of huggingface',

    )

    parser.add_argument(

        '--target-tensor-parallel-size', type=int, default=1, help='Target tensor model parallel size, defaults to 1.'

    )

    parser.add_argument(

        '--target-pipeline-parallel-size',

        type=int,

        default=1,

        help='Target pipeline model parallel size, defaults to 1.',

    )

    parser.add_argument(

        '--target-expert-parallel-size', type=int, default=1, help='Target expert model parallel size, defaults to 1.'

    )

    parser.add_argument(

        '--expert-tensor-parallel-size',

        type=int,

        default=None,

        help='Degree of expert model parallelism, Currentley it is support to be set to 1 or None. Default is None, which will be set to the value of --target-tensor-parallel-size',

    )

    parser.add_argument(

        '--num-layers-per-virtual-pipeline-stage',

        type=int,

        default=None,

        help='Number of layers per virtual pipeline stage',

    )

    parser.add_argument('--moe-grouped-gemm', action='store_true', help='Use moe grouped gemm.')

    parser.add_argument("--noop-layers", type=str, default=None, help='Specity the noop layers.')

    parser.add_argument('--mtp-num-layers', type=int, default=0, help='Multi-Token prediction layer num')

    parser.add_argument(

        '--num-layer-list', type=str, help='a list of number of layers, separated by comma; e.g., 4,4,4,4'

    )

    parser.add_argument(

        "--moe-tp-extend-ep",

        action='store_true',

        help="use tp group to extend experts parallism instead of sharding weight tensor of experts in tp group",

    )

    parser.add_argument(

        '--mla-mm-split', action='store_true', default=False, help='Split 2 up-proj matmul into 4 in MLA'

    )

    parser.add_argument(

        '--schedules-method',

        type=str,

        default=None,

        choices=['dualpipev'],

        help='An innovative bidirectional pipeline parallelism algorithm.',

    )

    parser.add_argument(

        '--first-k-dense-replace', type=int, default=None, help='Customizing the number of dense layers.'

    )

    parser.add_argument('--num-layers', type=int, default=None, help='Specify the number of transformer layers to use.')

    parser.add_argument(

        '--transformer-impl',

        default='local',

        choices=['local', 'transformer_engine'],

        help='Which Transformer implementation to use.',

    )

    parser.add_argument('--hf-cfg-dir', type=str, default=None, help='Directory to load hugging face config files')

    parser.add_argument('--input-tp-rank', type=int, help='Tensor Parallel rank of the input model shard')

    parser.add_argument('--input-pp-rank', type=int, help='Pipeline Parallel rank of the input model shard')

    parser.add_argument('--hidden-size', type=int, default=4096, help='Model dimension (hidden size)')

    parser.add_argument('--mamba-state-dim', type=int, default=128, help='State dimension used in the Mamba model')

    parser.add_argument('--mamba-num-groups', type=int, default=8, help='Number of groups in Mamba v2 model')

    parser.add_argument('--mamba-head-dim', type=int, default=64, help='Head dimension in Mamba v2 model')

    parser.add_argument(

        '--qlora-nf4', action='store_true', default=False, help='use bitsandbytes nf4 to quantize model.'

    )

    parser.add_argument(

        '--save-layer-by-layer',

        action='store_true',

        default=False,

        help='Enable layer-by-layer saving to avoid OOM when the product of TP and EP is high',

    )

    parser.add_argument('--lora-load', type=str, default=None, help='Directory containing the lora model checkpoint.')

    parser.add_argument('--lora-r', type=int, default=None, help='Lora r.')

    parser.add_argument('--lora-alpha', type=int, default=None, help='Lora alpha.')

    parser.add_argument('--lora-target-modules', nargs='+', type=str, default=[], help='Lora target modules.')

    parser.add_argument('--save-lora-to-hf', action='store_true', help='only save lora ckpt to hf.')

    args, unknown = parser.parse_known_args()

    if unknown:

        raise ValueError(f"Unrecognized argument : {' '.join(unknown)}")

    return args





@auto_coverage

def main():

    args = get_args()

    logger.info(f"Arguments: {args}")

    if args.model_type_hf == 'mamba2':

        converter = MambaConverter(args)

    elif args.model_type_hf == 'longcat':

        converter = LongCatConverter(args)

    elif args.model_type_hf in ('deepseek4', 'deepseek4_base'):

        converter = DeepSeek4Converter(args)

    elif args.load_model_type == 'hf' and args.save_model_type == 'mg':

        converter = Hf2MgConvert(args)

    elif args.load_model_type == 'mg' and args.save_model_type == 'hf':

        converter = Mg2HfConvert(args)

    else:

        raise Exception("This conversion scheme is not supported")



    start_time = time.time()

    converter.run()

    end_time = time.time()

    logger.info("time-consuming: {:.2f}s".format(end_time - start_time))





if __name__ == '__main__':

    main()