"""Convert checkpoint from torch/huggingface"""
import argparse
import numpy as np
import torch
import mindspore as ms
from mindspore import save_checkpoint, Tensor
from transformers.models.gpt_bigcode import GPTBigCodeForCausalLM
def generate_params_dict(total_layers,
                         mindspore_params_per_layer,
                         torch_params_per_layer,
                         mindspore_additional_params,
                         torch_additional_params):
    """
    Build the complete parameter-name mapping between mindspore and pytorch.

    Args:
        total_layers(int): Number of layers; the per-layer templates are
            formatted with every index in range(total_layers).
        mindspore_params_per_layer(list): Per-layer mindspore name templates
            containing a `{}` placeholder for the layer index.
        torch_params_per_layer(list): Per-layer pytorch name templates, paired
            position by position with `mindspore_params_per_layer`.
        mindspore_additional_params(list): Mindspore names of the parameters
            that live outside the layers.
        torch_additional_params(list): Pytorch names of the parameters that
            live outside the layers, paired position by position with
            `mindspore_additional_params`.

    Returns:
        A list of tuple (mindspore_name, pytorch_name). The per-layer pairs
        come first, ordered by layer index, followed by the additional pairs.
    """
    # Materialize the pairing once; it is reused for every layer index.
    layer_templates = list(zip(mindspore_params_per_layer, torch_params_per_layer))
    name_pairs = [
        (ms_template.format(layer_id), torch_template.format(layer_id))
        for layer_id in range(total_layers)
        for ms_template, torch_template in layer_templates
    ]
    name_pairs.extend(zip(mindspore_additional_params, torch_additional_params))
    return name_pairs
def print_dict(input_dict):
    """
    Print one line per entry showing the key and the shape of its value.

    Args:
        input_dict(dict): Mapping from parameter name to an object that
            exposes a `shape` attribute.

    Returns:
        None
    """
    for param_name, param_value in input_dict.items():
        param_shape = param_value.shape
        print(f"Param: {param_name} with shape {param_shape}")
def get_converted_ckpt(mapped_params, weight_dict):
    """
    Build the mindspore checkpoint entries from the pytorch weights.

    Args:
        mapped_params(list): List of tuple (mindspore_name, pytorch_name).
        weight_dict(dict): The loaded pytorch checkpoint. The key is the
            pytorch parameter name and the value must provide `.numpy()`.

    Returns:
        A list of dict, one per mapping entry, with key "data" holding a
        Tensor created with dtype float16 and key "name" holding the
        mindspore parameter name.
    """
    # Weights of these pytorch layers are stored transposed on the mindspore side.
    transposed_layers = ('c_proj', 'c_fc')
    converted = []
    for ms_key, torch_key in mapped_params:
        array = weight_dict[torch_key].numpy()
        is_weight = torch_key.endswith('weight')
        if is_weight and any(tag in torch_key for tag in transposed_layers):
            print("----Transpose tgt:", torch_key)
            array = array.transpose(1, 0)
        print(f"Mapping table Mindspore:{ms_key:<30} \t Torch:{torch_key:<30} with shape {array.shape}")
        entry = {"data": Tensor(array, dtype=ms.float16), "name": ms_key}
        converted.append(entry)
    return converted
def split_torch_attention(state, hidden_size=6144, kv_dim=128):
    """
    Split every fused attention parameter into query, key and value entries.

    Each entry whose name ends with `attn.c_attn.weight` or `attn.c_attn.bias`
    is removed from `state` and replaced by three entries named `<name>.q`,
    `<name>.k` and `<name>.v`, cut along the first axis into chunks of
    `hidden_size`, `kv_dim` and `kv_dim` rows. The new entries are independent
    copies that keep the dtype of the fused tensor.

    Args:
        state(dict): The loaded state dict, modified in place. The key is the
            parameter name and the value is a torch tensor.
        hidden_size(int): Number of leading rows that form the query part.
            Default 6144.
        kv_dim(int): Number of rows of the key part and of the value part.
            Default 128.

    Returns:
        None

    Raises:
        ValueError: If the first dimension of a fused parameter is not
            `hidden_size + 2 * kv_dim`.
    """
    fused_suffixes = ('attn.c_attn.weight', 'attn.c_attn.bias')
    expected_rows = hidden_size + 2 * kv_dim
    # Iterate over a snapshot of the keys because the dict is edited in the loop.
    for name in list(state):
        if not name.endswith(fused_suffixes):
            continue
        value = state[name]
        if value.shape[0] != expected_rows:
            # Fail loudly instead of silently producing a wrong-sized value chunk.
            raise ValueError(
                f"Cannot split '{name}': first dimension is {value.shape[0]}, "
                f"expected hidden_size + 2 * kv_dim = {expected_rows}.")
        state.pop(name)
        print("The real value shape is:", value.shape)
        # torch.split returns views; clone so the fused tensor can be released.
        q, k, v = (part.clone() for part in torch.split(value, [hidden_size, kv_dim, kv_dim], dim=0))
        print("---q shape:", tuple(q.shape))
        print("---k shape:", tuple(k.shape))
        print("---v shape:", tuple(v.shape))
        state[name + '.q'] = q
        state[name + '.k'] = k
        state[name + '.v'] = v
if __name__ == '__main__':
    # Command-line entry: load a GPTBigCode checkpoint with transformers,
    # rename / split / transpose its parameters and save a mindspore checkpoint.
    parser = argparse.ArgumentParser(
        description="WizardCoder convert script. "
                    "Examples: "
                    "python research/wizardcoder/convert_weight.py --layers 40 "
                    "--torch_path /xxx/pytorch_model.bin --mindspore_path /xxx/ms.ckpt")
    parser.add_argument('--layers',
                        type=int,
                        default=40,
                        help="The number of layers of the model to be converted.")
    parser.add_argument("--torch_path",
                        type=str,
                        default='/home/wizardcoder/pytorch_models_60step/',
                        help="The torch checkpoint path.")
    parser.add_argument("--mindspore_path",
                        type=str,
                        default="/home/wizardcoder/mindspore_models_rank_60step/rank_0/wizardcoder.ckpt",
                        help="The output path of the converted mindspore checkpoint.")
    opt = parser.parse_args()

    # The conversion runs on CPU.
    device = 'cpu'
    small_model_path = opt.torch_path
    model = GPTBigCodeForCausalLM.from_pretrained(small_model_path).to(device)
    state_dict = model.state_dict()
    print_dict(state_dict)

    # Per-layer parameter names; `{}` is replaced by the layer index.
    # ms_name and torch_name are paired position by position.
    ms_name = [
        "backbone.blocks.{}.layernorm1.gamma",
        "backbone.blocks.{}.layernorm1.beta",
        "backbone.blocks.{}.layernorm2.gamma",
        "backbone.blocks.{}.layernorm2.beta",
        "backbone.blocks.{}.attention.projection.weight",
        "backbone.blocks.{}.attention.projection.bias",
        "backbone.blocks.{}.attention.dense1.weight",
        "backbone.blocks.{}.attention.dense1.bias",
        "backbone.blocks.{}.attention.dense2.weight",
        "backbone.blocks.{}.attention.dense2.bias",
        "backbone.blocks.{}.attention.dense3.weight",
        "backbone.blocks.{}.attention.dense3.bias",
        "backbone.blocks.{}.output.mapping.weight",
        "backbone.blocks.{}.output.mapping.bias",
        "backbone.blocks.{}.output.projection.weight",
        "backbone.blocks.{}.output.projection.bias",
    ]
    # The `.q` / `.k` / `.v` names do not exist in the loaded state dict; they
    # are created by split_torch_attention() below.
    torch_name = [
        "transformer.h.{}.ln_1.weight",
        "transformer.h.{}.ln_1.bias",
        "transformer.h.{}.ln_2.weight",
        "transformer.h.{}.ln_2.bias",
        "transformer.h.{}.attn.c_proj.weight",
        "transformer.h.{}.attn.c_proj.bias",
        "transformer.h.{}.attn.c_attn.weight.q",
        "transformer.h.{}.attn.c_attn.bias.q",
        "transformer.h.{}.attn.c_attn.weight.k",
        "transformer.h.{}.attn.c_attn.bias.k",
        "transformer.h.{}.attn.c_attn.weight.v",
        "transformer.h.{}.attn.c_attn.bias.v",
        "transformer.h.{}.mlp.c_fc.weight",
        "transformer.h.{}.mlp.c_fc.bias",
        "transformer.h.{}.mlp.c_proj.weight",
        "transformer.h.{}.mlp.c_proj.bias"
    ]

    # Parameters outside the repeated layers, again paired by position.
    addition_mindspore = [
        "backbone.layernorm.gamma",
        "backbone.layernorm.beta",
        "backbone.embedding.word_embedding.embedding_table",
        "backbone.embedding.position_embedding.embedding_table",
        "head.head_weight",
    ]
    addition_torch = [
        "transformer.ln_f.weight",
        "transformer.ln_f.bias",
        "transformer.wte.weight",
        "transformer.wpe.weight",
        "lm_head.weight",
    ]

    mapped_param = generate_params_dict(total_layers=opt.layers,
                                        mindspore_params_per_layer=ms_name,
                                        torch_params_per_layer=torch_name,
                                        mindspore_additional_params=addition_mindspore,
                                        torch_additional_params=addition_torch)
    # Must run before get_converted_ckpt(): it adds the `.q` / `.k` / `.v`
    # entries that the mapping looks up.
    split_torch_attention(state_dict)
    new_ckpt = get_converted_ckpt(mapped_param, state_dict)
    save_checkpoint(new_ckpt, opt.mindspore_path)
    print(f"Convert finished, the output is saved to {opt.mindspore_path}")