import os
from argparse import ArgumentParser
from mindspeed.features_manager.feature import MindSpeedFeature

MODEL_TYPE_HF_CHOICES = [
    'qwen3',
    'qwen3-moe',
    'deepseek3',
    'deepseek4',
    'glm45-air',
    'glm45',
    'bailing_mini',
    'qwen3-next',
    'seed-oss',
    'deepseek32',
    'magistral',
    'deepseek2-lite',
    'phi3.5',
    'mamba2',
]


class CheckpointFeature(MindSpeedFeature):
    def __init__(self):
        super().__init__(feature_name="ckeckpoint", optimization_level=0)

    def register_args(self, parser: ArgumentParser):
        group = parser.add_argument_group(title=self.feature_name)

        group.add_argument(
            '--model-type-hf',
            type=str,
            default=None,
            choices=[
                'qwen3',
                'qwen3-moe',
                'deepseek3',
                'deepseek4',
                'glm45-air',
                'glm45',
                'bailing_mini',
                'qwen3-next',
                'seed-oss',
                'baichuan',
                'baichuan2',
                'llama2',
                'mixtral',
                'chatglm3',
                'gemma',
                'gemma2',
                'deepseek4_base',
                'bloom',
                'bloom_3b',
                'qwen',
                'internlm2',
                'deepseek2',
                'minicpm',
                'minicpm3',
                'minicpm-moe',
                'deepseek2-lite',
                'qwen2-moe',
                'phi3.5',
                'phi3.5-moe',
                'hunyuan',
                'glm4',
                'magistral',
                'deepseek32',
                'mamba2',
                'plm',
                'longcat',
                'glm5',
            ],
            help='model type of huggingface',
        )
        group.add_argument(
            '--enable-hf2mg-convert',
            action='store_true',
            help='Enable HuggingFace→Megatron weight conversion and patch. '
            'If set, weight conversion will run automatically during initialize_megatron().',
        )
        group.add_argument(
            '--enable-mg2hf-convert',
            action='store_true',
            help='Enable Megatron→HuggingFac weight after save megatron checkpoint every save iteration. '
            'If set, weight conversion will run automatically after save megatron checkpoint.',
        )
        group.add_argument(
            '--only-convert-last-checkpoint',
            action='store_true',
            help='If set, Megatron→HuggingFace weight conversion will only run automatically after train instead of every save iteration.',
        )
        group.add_argument('--mg-save-dir', type=str, default=None, help='Directory to save megatron checkpoint to')
        group.add_argument('--hf-save-dir', type=str, default=None, help='Directory to save huggingface checkpoint to')
        group.add_argument('--hf-cfg-dir', type=str, default=None, help='Directory to load huggingface config files')
        group.add_argument(
            '--save-layer-by-layer',
            action='store_true',
            default=False,
            help='Enable layer-by-layer saving to avoid OOM when the product of TP and EP is high',
        )
        group.add_argument('--save-lora-to-hf', action='store_true', help='only save lora ckpt to hf.')

    def register_patches(self, patch_manager, args):
        from mindspeed_llm.training.initialize import initialize_megatron_wrapper

        patch_manager.register_patch("megatron.training.initialize.initialize_megatron", initialize_megatron_wrapper)

    def validate_args(self, args):
        if hasattr(args, 'load_model_type'):
            return

        if getattr(args, 'ckpt_format', None) == 'torch_dist' and getattr(args, 'enable_hf2mg_convert', False):
            raise AssertionError('--ckpt-format torch_dist cannot be used together with --enable-hf2mg-convert')

        has_valid_lora_target = hasattr(args, 'lora_target_modules') and args.lora_target_modules

        def has_safetensor_weights(dir_path) -> bool:
            '''
            check if find any safetensor in load dir
            '''
            if not dir_path:
                return False
            if not os.path.isdir(dir_path):
                return False
            for name in os.listdir(dir_path):
                if name.endswith(".safetensors"):
                    return True
                if name.startswith("pytorch_model") and name.endswith(".bin"):
                    return True
            return False

        enable_hf_train = has_safetensor_weights(args.load)

        if not enable_hf_train and args.enable_hf2mg_convert:
            raise AssertionError('cannot find safetensor, please check load dir')

        if enable_hf_train and not args.enable_hf2mg_convert and not args.enable_mg2hf_convert:
            args.enable_hf2mg_convert = True
            args.enable_mg2hf_convert = True

        if not args.load and args.enable_hf2mg_convert:
            raise AssertionError('if enable_hf2mg_convert, please set load dir')

        if not args.save and args.enable_mg2hf_convert:
            raise AssertionError('if enable_mg2hf_convert, please set save dir')

        if has_valid_lora_target and args.enable_mg2hf_convert:
            raise AssertionError('Lora and QLora are not supported with enable_mg2hf_convert')

        if args.enable_hf2mg_convert and args.enable_mg2hf_convert and not args.hf_cfg_dir:
            args.hf_cfg_dir = args.load

        if args.enable_hf2mg_convert and not args.model_type_hf:
            from mindspeed_llm.training.utils import infer_model_type_from_hf_config

            config_path = os.path.join(args.load, 'config.json')
            args.model_type_hf = infer_model_type_from_hf_config(config_path, MODEL_TYPE_HF_CHOICES)