"""
Convert Telechat weight.
Support huggingface format.
"""
import os
import re
import argparse
from glob import glob
from safetensors.torch import load_file
import numpy as np
import torch
import mindspore as ms
from mindformers.tools.utils import str2bool
from mindformers.tools import logger
from mindformers.utils.convert_utils import pt2ms
dtype_map = {
'float32': ms.float32,
'bfloat16': ms.bfloat16,
'float16': ms.float16
}
def name_replace(name: str):
"""replace hf param name to ms."""
name = name.replace('transformer.word_embeddings.weight', 'model.tok_embeddings.embedding_weight')
name = name.replace('.input_layernorm', '.attention_norm')
name = name.replace('.self_attention.dense.', '.attention.wo.')
name = name.replace('.self_attention.dense.bias.', '.attention.wo.bias.')
name = name.replace('.self_attention.query.', '.attention.wq.')
name = name.replace('.self_attention.key_value.', '.attention.wk_v.')
name = name.replace('.mlp.gate_proj.', '.feed_forward.w1.')
name = name.replace('.mlp.down_proj.', '.feed_forward.w2.')
name = name.replace('.mlp.down_proj.bias.', '.feed_forward.w2.bias.')
name = name.replace('.mlp.up_proj.', '.feed_forward.w3.')
name = name.replace('.mlp.router.', '.feed_forward.router.dense.')
name = name.replace('.post_attention_layernorm.', '.ffn_norm.')
name = name.replace('lm_head.', 'lm_head.')
name = name.replace('transformer.ln_f.', 'model.norm_out.')
expert_id = extract_expert_id(name)
if expert_id is not None:
name = name.replace(f'.mlp.local_experts.{expert_id}.gate_proj.', f'.feed_forward.ffn.{expert_id}.w1.')
name = name.replace(f'.mlp.local_experts.{expert_id}.up_proj.', f'.feed_forward.ffn.{expert_id}.w3.')
name = name.replace(f'.mlp.local_experts.{expert_id}.down_proj.', f'.feed_forward.ffn.{expert_id}.w2.')
return name, expert_id
def extract_expert_id(layer_name):
expert_pattern = r'local_experts\.(\d+)\.'
expert_match = re.search(expert_pattern, layer_name)
expert_id = int(expert_match.group(1)) if expert_match else None
return expert_id
def sort_dict_by_indices(d):
pattern = r'model\.layers\.(\d+)\.feed_forward\.ffn\.(\d+)\.w'
sorted_keys = sorted(d.keys(), \
key=lambda x: tuple(int(i) for i in re.search(pattern, x).groups()))
return {k: d[k] for k in sorted_keys}
def remove_expert_id(layer_name):
return re.sub(r'ffn\.\d+\.', 'ffn.', layer_name)
def convert_pt_to_ms(input_path, output_path, dtype=None, **kwargs):
"""convert telechat hf weight to ms."""
files = list(glob(os.path.join(input_path, "pytorch_model*.bin")))
convert_safetensors = False
if not files:
files = list(glob(os.path.join(input_path, "model*.safetensors")))
if not files:
raise FileNotFoundError(f"No bin or safetensors found in the model path: {input_path}.")
convert_safetensors = True
files.sort()
pt_states_list = []
for per_file in files:
if convert_safetensors:
pt_states = load_file(per_file)
else:
pt_states = torch.load(per_file, map_location='cpu', weights_only=True)
pt_states_list.append(pt_states)
ckpt_list = []
expert_dict = {}
expert_ids = set()
for pt_states in pt_states_list:
for name, value in pt_states.items():
name, expert_id = name_replace(name)
if name.startswith('transformer.h.'):
name = name.replace('transformer.h.', 'model.layers.')
if expert_id is not None:
expert_dict[name] = value
expert_ids.add(expert_id)
else:
logger.info(f'\rprocessing parameter: {name} {value.shape}')
ckpt_list.append({'name': name, 'data': pt2ms(value, dtype)})
if expert_dict:
expert_dict = sort_dict_by_indices(expert_dict)
expert_name_list = list(expert_dict.keys())
expert_merged_dict = {}
for expert_name in expert_name_list:
name = remove_expert_id(expert_name)
value = expert_dict[expert_name]
if name in expert_merged_dict:
expert_merged_dict[name].append(value)
else:
expert_merged_dict[name] = [value]
del expert_dict[expert_name]
for name in expert_merged_dict:
value = torch.stack(expert_merged_dict[name])
if "bias" in name:
value = value.unsqueeze(0).unsqueeze(2)
logger.info(f'\rprocessing parameter: {name} {value.shape}')
ckpt_list.append({'name': name, 'data': pt2ms(value, dtype)})
ms.save_checkpoint(ckpt_list, output_path)
logger.info(f"\rConvert huggingface checkpoint finished, the mindspore checkpoint is saved in '{output_path}'.")
def transformer_new_kv_param(n_kv_heads, w_kv_weight, mp):
"""transform new kv param."""
head_dim = 128
size = n_kv_heads // mp * 2
split_size_or_sections = size * [head_dim]
print("w_kv_weight split_size_or_sections size {}, {}:".format(len(split_size_or_sections),
split_size_or_sections))
new_kv_weight = ms.ops.split(w_kv_weight, split_size_or_sections, 0)
new_k_weight = []
new_v_weight = []
for i in range(0, size):
if i % 2 == 0:
new_k_weight.append(new_kv_weight[i].asnumpy())
else:
new_v_weight.append(new_kv_weight[i].asnumpy())
new_k_weight = np.concatenate(new_k_weight, 0)
new_v_weight = np.concatenate(new_v_weight, 0)
return new_k_weight, new_v_weight
def convert_qkv_concat_weight(n_kv_heads, num_layers, param_dict, mp=1):
"""convert qkv concat weight"""
for i in range(num_layers):
wq_weight_name = f"model.layers.{i}.attention.wq.weight"
w_kv_weight_name = f"model.layers.{i}.attention.wk_v.weight"
qkv_concat_weight_name = f"model.layers.{i}.attention.w_qkv.weight"
if wq_weight_name not in param_dict:
break
wq_weight = param_dict[wq_weight_name]
w_kv_weight = param_dict[w_kv_weight_name]
new_k_weight, new_v_weight = transformer_new_kv_param(n_kv_heads, w_kv_weight, mp)
qkv_weight = np.concatenate((wq_weight.asnumpy(), new_k_weight, new_v_weight), 0)
param_dict[qkv_concat_weight_name] = ms.Parameter(qkv_weight, name=qkv_concat_weight_name)
ffn_gate_weight_name = f"model.layers.{i}.feed_forward.w1.weight"
ffn_hidden_weight_name = f"model.layers.{i}.feed_forward.w3.weight"
gate_hidden_concat_weight_name = f"model.layers.{i}.feed_forward.w_gate_hidden.weight"
ffn_gate_weight = param_dict[ffn_gate_weight_name]
ffn_hidden_weight = param_dict[ffn_hidden_weight_name]
gate_hidden_weight = np.concatenate((ffn_gate_weight.asnumpy(), ffn_hidden_weight.asnumpy()), 0)
param_dict[gate_hidden_concat_weight_name] = ms.Parameter(gate_hidden_weight,
name=gate_hidden_concat_weight_name)
param_dict.pop(wq_weight_name)
param_dict.pop(w_kv_weight_name)
param_dict.pop(ffn_gate_weight_name)
param_dict.pop(ffn_hidden_weight_name)
print("transform: {}".format(qkv_concat_weight_name))
print("transform: {}".format(gate_hidden_concat_weight_name))
return param_dict
def convert_to_qkv_concat(model_name, pre_ckpt_path, mindspore_ckpt_path):
"""convert previous ckpt to qkv concat ckpt"""
if model_name == "telechat_7B":
n_kv_heads = 32
num_layers = 30
elif model_name == "telechat_35B":
n_kv_heads = 48
num_layers = 64
elif model_name == "telechat_115B":
n_kv_heads = 8
num_layers = 96
else:
raise ValueError("model_name:{} is not supported.".format(model_name))
if os.path.isdir(pre_ckpt_path):
rank_dir_list = os.listdir(pre_ckpt_path)
rank_dir_list_new = []
for rank_dir in rank_dir_list:
if rank_dir.startswith('rank_'):
rank_dir_list_new.append(rank_dir)
rank_dir_list = rank_dir_list_new
mp = len(rank_dir_list)
for rank_dir in rank_dir_list:
rank_dir_name = str(rank_dir)
rank_id = rank_dir_name.split("_")[1]
checkpoint_path = os.path.join(pre_ckpt_path, rank_dir_name, "checkpoint_{}.ckpt".format(rank_id))
print("checkpoint_path: {}".format(checkpoint_path))
params = ms.load_checkpoint(checkpoint_path)
params = convert_qkv_concat_weight(n_kv_heads, num_layers, params, mp)
save_dir = os.path.join(mindspore_ckpt_path, rank_dir_name)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
save_path = os.path.join(mindspore_ckpt_path, rank_dir_name, "checkpoint_{}.ckpt".format(rank_id))
ms.save_checkpoint(params, save_path)
else:
params = ms.load_checkpoint(pre_ckpt_path)
params = convert_qkv_concat_weight(n_kv_heads, num_layers, params)
ms.save_checkpoint(params, mindspore_ckpt_path)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Telechat convert script")
parser.add_argument("--torch_path",
type=str,
default="",
help="The input torch checkpoint path.")
parser.add_argument("--mindspore_path",
type=str,
default="",
help="The output mindspore checkpoint path.")
parser.add_argument("--dtype", default='float32', choices=['float16', 'float32', 'bfloat16'],
help="Data type for output checkpoint file. Default: float16")
parser.add_argument('--qkv_concat', default=False, type=str2bool)
parser.add_argument('--mindspore_ckpt_path', default='transform.ckpt')
parser.add_argument('--pre_ckpt_path', default=None)
parser.add_argument('--model_name', default="telechat_7B", type=str)
args = parser.parse_args()
ms_dtype = dtype_map.get(args.dtype)
if args.qkv_concat:
convert_to_qkv_concat(args.model_name, args.pre_ckpt_path, args.mindspore_ckpt_path)
else:
convert_pt_to_ms(input_path=args.torch_path, output_path=args.mindspore_path, dtype=ms_dtype)