import argparse
import importlib
import sys
import logging as logger
import torch.multiprocessing as mp
import pretrain_gpt
from mindspeed_llm.training.utils import auto_coverage
# Package prefix under which load_plugin() looks up loader/saver plugin modules.
MODULE_ROOT = "mindspeed_llm.tasks.checkpoint"
def load_plugin(plugin_type, name):
    """Import and return a checkpoint plugin module located under ``MODULE_ROOT``.

    Candidate module names are tried in order:

    * empty *name*: ``<MODULE_ROOT>.<plugin_type>`` only;
    * otherwise: ``<MODULE_ROOT>.<plugin_type>_<name>``, then
      ``<MODULE_ROOT>.<name>``.

    Args:
        plugin_type: Role of the plugin (``main`` passes ``'loader'`` or
            ``'saver'``); used as the module-name prefix and in messages.
        name: Plugin name, or ``''`` to select the module named after
            *plugin_type* alone.

    Returns:
        The imported module; it is guaranteed to have an ``add_arguments``
        attribute.

    Raises:
        SystemExit: If no candidate module can be imported, or the imported
            module has no ``add_arguments`` attribute.
    """
    if name == '':
        # There is no meaningful "<MODULE_ROOT>.<name>" fallback for an empty name.
        candidates = [f"{MODULE_ROOT}.{plugin_type}"]
    else:
        candidates = [f"{MODULE_ROOT}.{plugin_type}_{name}", f"{MODULE_ROOT}.{name}"]

    failures = []
    for module_name in candidates:
        try:
            plugin = importlib.import_module(module_name)
        except ModuleNotFoundError as err:
            # Keep the reason: the missing module may be a dependency imported
            # by the plugin rather than the plugin module itself.
            failures.append(f"{module_name}: {err}")
        else:
            break
    else:
        details = "; ".join(failures)
        sys.exit(f"Unable to load {plugin_type} plugin {name}. Exiting. Import errors: {details}")

    if not hasattr(plugin, 'add_arguments'):
        sys.exit(f"{module_name} module is not a plugin. Exiting.")

    logger.info("Loaded %s as the %s.", module_name, plugin_type)
    return plugin
def _build_base_parser():
    """Create the argument parser holding the options defined by this script itself."""
    parser = argparse.ArgumentParser(
        description="Megatron Checkpoint Utility Arguments", allow_abbrev=False, conflict_handler='resolve'
    )

    parser.add_argument('--model-type', type=str, required=True, choices=['GPT', 'BERT'], help='Type of the model')
    parser.add_argument(
        '--loader', type=str, default='megatron', help='Module name to load checkpoint, should be on python path'
    )
    parser.add_argument(
        '--load-model-type',
        type=str,
        nargs='?',
        default=None,
        const=None,
        choices=['hf', 'mg'],
        help='Module name to load checkpoint, should be on python path',
    )
    parser.add_argument(
        '--saver', type=str, default='megatron', help='Module name to save checkpoint, should be on python path'
    )
    parser.add_argument('--load-dir', type=str, required=True, help='Directory to load model checkpoint from')
    parser.add_argument('--save-dir', type=str, required=True, help='Directory to save model checkpoint to')
    parser.add_argument('--max-queue-size', type=int, default=50, help='Maximum number of tensors in the queue')
    parser.add_argument(
        '--no-checking',
        action='store_false',
        help='Do not perform checking on the name and ordering of weights',
        dest='checking',
    )
    parser.add_argument(
        '--spec',
        type=str,
        default=None,
        nargs='*',
        help='Specify the <module_location function_name> pair '
        'that returns a spec to customize transformer layer, depending on the use case.',
    )
    parser.add_argument(
        '--model-type-hf',
        type=str,
        default="llama2",
        choices=[
            'baichuan',
            'baichuan2',
            'llama2',
            'mixtral',
            'chatglm3',
            'gemma',
            'gemma2',
            'qwen3',
            'bloom',
            'bloom_3b',
            'qwen',
            'internlm2',
            'deepseek2',
            'minicpm',
            'minicpm3',
            'minicpm-moe',
            'deepseek2-lite',
            'qwen2-moe',
            'qwen3-moe',
            'phi3.5',
            'phi3.5-moe',
            'hunyuan',
            'glm4',
            'seed-oss',
            'magistral',
            'plm',
        ],
        help='model type of huggingface',
    )
    parser.add_argument(
        '--ckpt-cfg-path',
        type=str,
        default="configs/checkpoint/model_cfg.json",
        help="Path to the config directory. If not specified, the default path in the repository will be used.",
    )
    parser.add_argument('--qlora-nf4', action='store_true', help='use bitsandbytes nf4 to quantize model.')
    parser.add_argument(
        '--save-lora-to-hf', action='store_true', default=False, help='Enable only save lora-checkpoint to hf'
    )
    parser.add_argument(
        '--load-checkpoint-loosely', action='store_true', default=False, help='Enable loading checkpoint not strictly.'
    )
    parser.add_argument(
        '--ckpt-format', default='torch', choices=['torch', 'torch_dist', 'zarr'], help='Checkpoint format to use.'
    )
    parser.add_argument('--lora-target-modules', nargs='+', type=str, default=[], help='Lora target modules.')
    return parser


@auto_coverage
def main():
    """Command-line entry point: run a loader plugin and a saver plugin linked by a queue.

    Steps:
      1. Parse the script's own options, ignoring unknown ones, to learn which
         plugins to use.
      2. Load the loader and saver plugins and let each add its options to the
         parser, then parse the full command line.
      3. Start the saver's ``save_model_checkpoint`` in a child process and run
         the loader's ``load_checkpoint`` in this process; both receive the
         same queue.
      4. Wait for the saver and propagate a non-zero saver exit code.

    Raises:
        SystemExit: On argument errors, plugin loading failures, or when the
            saver process ends with a non-zero exit code.
    """
    logger.warning(
        "This version of weight conversion tool is approaching end-of-life. "
        "It will be officially deprecated in the Q4 new release. "
        "Please look forward to the Weight Conversion V2 version!"
    )

    parser = _build_base_parser()

    # First pass: plugin-specific options are not registered yet, so unknown
    # arguments must be tolerated here.
    known_args, _ = parser.parse_known_args()

    if known_args.load_model_type is None:
        # No --load-model-type given: use the modules named by --loader / --saver.
        loader = load_plugin('loader', known_args.loader)
        saver = load_plugin('saver', known_args.saver)
    else:
        # --load-model-type selects the loader; the saver is the plain 'saver' module.
        loader = load_plugin('loader', known_args.load_model_type)
        saver = load_plugin('saver', '')

    loader.add_arguments(parser)
    saver.add_arguments(parser)

    # Second pass: every option is now known, so parse strictly.
    args = parser.parse_args()

    queue = mp.Queue(maxsize=args.max_queue_size)
    model_provider = pretrain_gpt.model_provider

    logger.info("Starting saver...")
    saver_proc = mp.Process(target=saver.save_model_checkpoint, args=(model_provider, queue, args))
    saver_proc.start()

    logger.info("Starting loader...")
    try:
        loader.load_checkpoint(model_provider, queue, args)
    except BaseException:
        # Non-daemonic children are joined at interpreter exit, so a saver that
        # is still waiting on the queue would keep this process from exiting.
        # Stop it explicitly before propagating the loader failure.
        logger.error("Loader aborted; terminating saver process.")
        saver_proc.terminate()
        saver_proc.join()
        raise

    logger.info("Waiting for saver to complete...")
    saver_proc.join()

    exitcode = saver_proc.exitcode
    if exitcode is not None and exitcode != 0:
        logger.error("saver process exited with error code %s", exitcode)
        sys.exit(exitcode)
# Run the conversion only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()