"""Convert checkpoint from huggingface"""
import os
import re
import argparse
import torch
import mindspore
from mindspore import Tensor, Parameter
def layer_name_mapping(model_name, key):
"""Convert huggingface PP weights mapping in MindSpore.
return: new_name
"""
prefix = ''
layer_rename_map = {
"word_embeddings.weight": "model.tok_embeddings.embedding_weight",
"input_layernorm.weight": "attention_norm.weight",
"self_attention.dense.weight": "attention.wo.weight",
"self_attention.dense.bias": "attention.wo.bias",
"self_attention.query.weight": "attention.wq.weight",
"self_attention.key_value.weight": "attention.wk_v.weight",
"mlp.gate_proj.weight": "feed_forward.w1.weight",
"mlp.down_proj.weight": "feed_forward.w2.weight",
"mlp.down_proj.bias": "feed_forward.w2.bias",
"mlp.up_proj.weight": "feed_forward.w3.weight",
"post_attention_layernorm.weight": "ffn_norm.weight",
"ln_f.weight": "model.norm_out.weight"
}
if model_name == "telechat_12b":
del layer_rename_map["word_embeddings.weight"]
del layer_rename_map["ln_f.weight"]
layer_rename_map["lm_head.weight"] = "lm_head.weight"
layer_rename_map["transformer.word_embeddings.weight"] = "model.tok_embeddings.embedding_weight"
layer_rename_map["transformer.ln_f.weight"] = "model.norm_out.weight"
if key in layer_rename_map:
return prefix + layer_rename_map[key]
match = re.match(r'^\w+\.\w*\.(\d+)\.(\w+\.\w+\.\w+|\w+\.\w+)$', key)
layer_number = int(match.group(1))
text = match.group(2)
return f"{prefix}model.layers.{layer_number}." + layer_rename_map[text]
def hf_to_ms(hf_weights, model_name, ms_dtype=mindspore.float16, for_save=False):
"""Convert hf layers to ms."""
ms_params = {}
for k, v in hf_weights.items():
new_name = layer_name_mapping(model_name, k)
new_tensor = Tensor(v.float().detach().numpy(), ms_dtype)
ms_params[new_name] = Parameter(new_tensor, name=new_name)
if for_save:
return [{'name': k, 'data': v} for k, v in ms_params.items()]
return ms_params
def process_shard_files(files, config, ms_dtype=mindspore.float16):
''' torch ckpt files loop'''
if config.mindspore_path and not os.path.exists(args.mindspore_path):
os.makedirs(config.mindspore_path, exist_ok=True)
ms_file_name = "mindspore"
combine_params = []
for per_file in files:
pt_states = torch.load(per_file, map_location='cpu')
ms_params = hf_to_ms(pt_states, config.model_name, ms_dtype, True)
combine_params.extend(ms_params)
del ms_params
save_file = config.mindspore_path + '/' + ms_file_name + '.ckpt'
mindspore.save_checkpoint(combine_params, save_file)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Telechat convert script")
parser.add_argument("--torch_path",
type=str,
default="",
help="The input torch checkpoint path.")
parser.add_argument("--mindspore_path",
type=str,
default="",
help="The output mindspore checkpoint path.")
parser.add_argument("--model_name",
type=str,
default="telechat_12b",
help="The name of model, supports name in {'telechat_7b', 'telechat_12b'}")
args = parser.parse_args()
files_list = []
for file_name in os.listdir(args.torch_path):
if file_name.startswith("pytorch_model") and file_name.endswith(".bin"):
files_list.append(os.path.join(args.torch_path, file_name))
process_shard_files(files=files_list, config=args)
current_path = os.getcwd()
mindspore_ckpt_path = os.path.join(current_path, args.mindspore_path)
print("*** finish torch convert ms model, ms_ckpt save in {} ***".format(mindspore_ckpt_path))