import torch
import torch.distributed as dist
from megatron.core import mpu
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
import slime.backends.megatron_utils as megatron_utils
from slime.backends.megatron_utils import update_weight_utils
from slime.utils.arguments import parse_args
def add_checkpoint_args(parser):
parser.add_argument(
"--output-dir",
type=str,
default=None,
help="Directory to save the converted HF model.",
)
parser.add_argument(
"--check-same",
action="store_true",
default=False,
help="Check if the converted model is the same as the original model.",
)
return parser
def main(args):
megatron_utils.init(args)
pp_size = mpu.get_pipeline_model_parallel_world_size()
ep_size = mpu.get_expert_model_parallel_world_size()
is_save_rank = (
mpu.get_data_parallel_rank(with_context_parallel=True) == 0 and mpu.get_tensor_model_parallel_rank() == 0
)
args.no_load_optim = True
args.no_load_rng = True
model, _, _, _ = megatron_utils.initialize_model_and_optimizer(args)
hf_config = AutoConfig.from_pretrained(args.hf_checkpoint, trust_remote_code=True)
model_name = type(hf_config).__name__.lower()
tokenizer = AutoTokenizer.from_pretrained(args.hf_checkpoint, trust_remote_code=True)
vocab_size = tokenizer.vocab_size if args.vocab_size is None else args.vocab_size
param_infos = update_weight_utils.get_param_infos(args, model)
state_dict = {}
rank = dist.get_rank()
for info in param_infos:
if dist.get_rank() == info.src_rank:
for name_, param_ in update_weight_utils.named_parameters(args, model):
if name_ == info.name:
param = param_
break
else:
param = torch.empty(info.shape, dtype=info.dtype, device=torch.cuda.current_device())
if pp_size > 1:
if info.src_rank in dist.get_process_group_ranks(mpu.get_pipeline_model_parallel_group()):
torch.distributed.broadcast(param, src=info.src_rank, group=mpu.get_pipeline_model_parallel_group())
if ep_size > 1:
if ".experts." in info.name:
src_rank = (
info.src_rank
if info.src_rank in dist.get_process_group_ranks(mpu.get_expert_model_parallel_group())
else rank
)
torch.distributed.broadcast(param, src=src_rank, group=mpu.get_expert_model_parallel_group())
for key, value in info.attrs.items():
setattr(param, key, value)
param = update_weight_utils.all_gather_param(info.name, param)
param = update_weight_utils.remove_padding(info.name, param, vocab_size)
if is_save_rank:
converted_named_tensors = update_weight_utils.convert_to_hf(args, model_name, info.name, param)
for name, param in converted_named_tensors:
state_dict[name] = param.cpu()
del param
if is_save_rank:
hf_model = AutoModelForCausalLM.from_pretrained(
args.hf_checkpoint, torch_dtype="auto", device_map="cpu", trust_remote_code=True
)
if args.check_same:
for name, param in hf_model.named_parameters():
if name in state_dict:
assert (
param.shape == state_dict[name].shape
), f"Shape mismatch for {name}: {param.shape} vs {state_dict[name].shape}"
assert torch.all(param == state_dict[name]), f"Value mismatch for {name}"
else:
print(f"Warning: {name} not found in state_dict")
if args.output_dir:
tokenizer.save_pretrained(args.output_dir)
print(hf_model.load_state_dict(state_dict, strict=False))
hf_model.save_pretrained(args.output_dir)
dist.barrier()
if __name__ == "__main__":
args = parse_args(add_custom_arguments=add_checkpoint_args)
main(args)