"""Convert checkpoint from mindspore"""
import argparse
import collections
import torch
import mindspore as ms
from mindformers.utils.convert_utils import ms2pt
ms_name = [
"backbone.blocks.{}.layernorm1.gamma",
"backbone.blocks.{}.layernorm1.beta",
"backbone.blocks.{}.layernorm2.gamma",
"backbone.blocks.{}.layernorm2.beta",
"backbone.blocks.{}.attention.projection.weight",
"backbone.blocks.{}.attention.projection.bias",
"backbone.blocks.{}.attention.dense1.weight",
"backbone.blocks.{}.attention.dense1.bias",
"backbone.blocks.{}.attention.dense2.weight",
"backbone.blocks.{}.attention.dense2.bias",
"backbone.blocks.{}.attention.dense3.weight",
"backbone.blocks.{}.attention.dense3.bias",
"backbone.blocks.{}.output.mapping.weight",
"backbone.blocks.{}.output.mapping.bias",
"backbone.blocks.{}.output.projection.weight",
"backbone.blocks.{}.output.projection.bias",
]
torch_name = [
"transformer.h.{}.ln_1.weight",
"transformer.h.{}.ln_1.bias",
"transformer.h.{}.ln_2.weight",
"transformer.h.{}.ln_2.bias",
"transformer.h.{}.attn.c_proj.weight",
"transformer.h.{}.attn.c_proj.bias",
"transformer.h.{}.attn.c_attn.weight.q",
"transformer.h.{}.attn.c_attn.bias.q",
"transformer.h.{}.attn.c_attn.weight.k",
"transformer.h.{}.attn.c_attn.bias.k",
"transformer.h.{}.attn.c_attn.weight.v",
"transformer.h.{}.attn.c_attn.bias.v",
"transformer.h.{}.mlp.c_fc.weight",
"transformer.h.{}.mlp.c_fc.bias",
"transformer.h.{}.mlp.c_proj.weight",
"transformer.h.{}.mlp.c_proj.bias"
]
addition_mindspore = [
"backbone.layernorm.gamma",
"backbone.layernorm.beta",
"backbone.embedding.word_embedding.embedding_table",
"backbone.embedding.position_embedding.embedding_table",
"head.head_weight",
]
addition_torch = [
"transformer.ln_f.weight",
"transformer.ln_f.bias",
"transformer.wte.weight",
"transformer.wpe.weight",
"lm_head.weight",
]
def generate_weight_map(total_layers,
mindspore_params_per_layer,
torch_params_per_layer,
mindspore_additional_params,
torch_additional_params):
"""
generate weight map
"""
map_dict = dict(zip(mindspore_additional_params, torch_additional_params))
for i in range(total_layers):
for idx, ms_para in enumerate(mindspore_params_per_layer):
map_dict[ms_para.format(i)] = torch_params_per_layer[idx].format(i)
return map_dict
def convert_ms_to_pt(input_path, output_path, dtype=None, **kwargs):
"""
convert ms to pt
"""
state_dict = {}
print(f"Trying to convert mindspore checkpoint in {input_path}.")
model_ms = ms.load_checkpoint(input_path)
assert len(ms_name) == len(torch_name)
assert len(addition_mindspore) == len(addition_torch)
total_layers, flag = divmod(len(model_ms) - len(addition_mindspore), len(ms_name))
if flag:
raise Exception("The weight names don't match.")
weight_map = generate_weight_map(total_layers, ms_name, torch_name, addition_mindspore, addition_torch)
attention_dict = collections.defaultdict(lambda: {})
for name, value in model_ms.items():
value = ms2pt(value, dtype)
if name.endswith('weight') and ('mapping' in name or 'projection' in name):
value = value.transpose(0, 1)
name = weight_map[name]
if name.endswith('.q'):
name = name.rstrip('.q')
attention_dict[name]['q'] = value
continue
if name.endswith('.k'):
name = name.rstrip('.k')
attention_dict[name]['k'] = value
continue
if name.endswith('.v'):
name = name.rstrip('.v')
attention_dict[name]['v'] = value
continue
state_dict[name] = value
for name, value_dict in attention_dict.items():
state_dict[name] = torch.cat((value_dict['q'], value_dict['k'], value_dict['v']), 0)
torch.save(state_dict, output_path)
print(f"Convert finished, the output is saved to {output_path}.")
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="WizardCoder convert script")
parser.add_argument("--mindspore_path",
type=str,
default="wizardcoder.ckpt",
help="The input mindspore checkpoint path.")
parser.add_argument("--torch_path",
type=str,
default='wizardcoder.bin',
help="The output torch checkpoint path.")
opt = parser.parse_args()
convert_ms_to_pt(opt.mindspore_path, opt.torch_path)