"""
Convert llama weight.
Support mindspore format and Meta format.
"""
import json
import os
from collections import defaultdict
from safetensors.torch import save_file
import torch
import mindspore as ms
from mindspore.ops.operations import Cast
from mindformers import MindFormerConfig
from mindformers.tools.logger import logger
from mindformers.tools.utils import set_safe_mode_for_file_or_dir
ms.set_device(device_target='CPU')
cpu_cast = Cast().set_device('CPU')
def get_torch_storage_size(tensor):
"""Get tensor's storage size, requires torch >= 2.1"""
return tensor.untyped_storage().nbytes()
def get_dtype(config):
"""Get compute dtype."""
if config.model.model_config.compute_dtype == "bfloat16":
dtype = torch.bfloat16
elif config.model.model_config.compute_dtype == "float16":
dtype = torch.float16
else:
dtype = torch.float32
return dtype
def name_replace(weight_name: str):
"""replace ms param weight_name to hf."""
weight_name = weight_name.replace('tok_embeddings.embedding_weight', 'embed_tokens.weight')
weight_name = weight_name.replace('model.norm_out.weight', 'model.norm.weight')
weight_name = weight_name.replace(".attention.wq.", ".self_attn.q_proj.")
weight_name = weight_name.replace(".attention.wk.", ".self_attn.k_proj.")
weight_name = weight_name.replace(".attention.wv.", ".self_attn.v_proj.")
weight_name = weight_name.replace(".attention.wo.", ".self_attn.o_proj.")
weight_name = weight_name.replace(".attention_norm.", ".input_layernorm.")
weight_name = weight_name.replace(".ffn_norm.", ".post_attention_layernorm.")
weight_name = weight_name.replace('feed_forward.w1.', 'mlp.gate_proj.')
weight_name = weight_name.replace('feed_forward.w2.', 'mlp.down_proj.')
weight_name = weight_name.replace('feed_forward.w3.', 'mlp.up_proj.')
return weight_name
def _model_preprocess_ms_to_pt(ms_layer_weights, config):
"""Processing weights in prepross module"""
dtype = get_dtype(config)
emb_weight_key = "model.tok_embeddings.embedding_weight"
emb_weight = cpu_cast(ms_layer_weights.get(emb_weight_key), ms.float32).numpy()
emb_weight_key = name_replace(emb_weight_key)
plain_weight_dict = defaultdict()
plain_weight_dict[emb_weight_key] = torch.from_numpy(emb_weight).to(dtype).clone()
return plain_weight_dict
def _model_postprocess_ms_to_pt(ms_layer_weights, config):
"""Processing weights in postpross module"""
dtype = get_dtype(config)
final_norm_key = "model.norm_out.weight"
lm_head_key = "lm_head.weight"
final_norm = cpu_cast(ms_layer_weights.get(final_norm_key), ms.float32).numpy()
if ms_layer_weights.get(lm_head_key, None) is not None:
lm_head = cpu_cast(ms_layer_weights.get(lm_head_key), ms.float32).numpy()
final_norm_key = name_replace(final_norm_key)
lm_head_key = name_replace(lm_head_key)
plain_weight_dict = defaultdict()
plain_weight_dict[final_norm_key] = torch.from_numpy(final_norm).to(dtype).clone()
if ms_layer_weights.get(lm_head_key, None) is not None:
plain_weight_dict[lm_head_key] = torch.from_numpy(lm_head).to(dtype).clone()
return plain_weight_dict
def attention_ms_to_pt(layer_id, ms_layer_weights, config):
"""Processing weights in attention module"""
dtype = get_dtype(config)
attention_weight_dict = defaultdict()
attention_weight_list = [
f"model.layers.{layer_id}.attention.wq.weight",
f"model.layers.{layer_id}.attention.wq.bias",
f"model.layers.{layer_id}.attention.wk.weight",
f"model.layers.{layer_id}.attention.wk.bias",
f"model.layers.{layer_id}.attention.wv.weight",
f"model.layers.{layer_id}.attention.wv.bias",
f"model.layers.{layer_id}.attention.wo.weight",
f"model.layers.{layer_id}.attention_norm.weight",
f"model.layers.{layer_id}.ffn_norm.weight"
]
for name in attention_weight_list:
value = cpu_cast(ms_layer_weights.pop(name), ms.float32).numpy()
name = name_replace(name)
attention_weight_dict[name] = torch.from_numpy(value).to(dtype).clone()
return attention_weight_dict
def feed_forward_ms_to_pt(layer_id, ms_layer_weights, config):
"""Processing weights in feed_forward module"""
dtype = get_dtype(config)
feed_forward_weight_dict = defaultdict()
gate_proj_key = f"model.layers.{layer_id}.feed_forward.w1.weight"
down_proj_key = f"model.layers.{layer_id}.feed_forward.w2.weight"
up_proj_key = f"model.layers.{layer_id}.feed_forward.w3.weight"
gate_proj = cpu_cast(ms_layer_weights.pop(gate_proj_key), ms.float32).numpy()
up_proj = cpu_cast(ms_layer_weights.pop(up_proj_key), ms.float32).numpy()
down_proj = cpu_cast(ms_layer_weights.pop(down_proj_key), ms.float32).numpy()
gate_proj_key = name_replace(gate_proj_key)
up_proj_key = name_replace(up_proj_key)
down_proj_key = name_replace(down_proj_key)
feed_forward_weight_dict[gate_proj_key] = torch.from_numpy(gate_proj).to(dtype).clone()
feed_forward_weight_dict[up_proj_key] = torch.from_numpy(up_proj).to(dtype).clone()
feed_forward_weight_dict[down_proj_key] = torch.from_numpy(down_proj).to(dtype).clone()
return feed_forward_weight_dict
def load_data_ms(file_name):
"""Load safetensors"""
return ms.load_checkpoint(file_name, format="safetensors")
def read_matched_file(layer_st_map, layer_list, is_first, is_last):
"""Load weights into dict for specified layers"""
st_file_list = []
for layer in layer_list:
st_file_list.extend(list(layer_st_map[layer]))
if is_first:
st_file_list.extend(list(layer_st_map["model.tok_embeddings.embedding_weight"]))
if is_last:
st_file_list.extend(list(layer_st_map["model.norm_out.weight"]))
st_file_list.extend(list(layer_st_map["lm_head.weight"]))
st_file_list = list(set(st_file_list))
weights = {}
for st_file in st_file_list:
current_weight = load_data_ms(st_file)
weights.update(current_weight)
return weights
def has_only_one_safetensors_file(file_path):
"""Check whether there is only one safetensors file."""
files = os.listdir(file_path)
safetensors_files = [file for file in files if file.endswith('.safetensors')]
return len(safetensors_files) == 1
def layers_model_file_map(file_path):
"""Get weight-file map"""
layer_st_map = defaultdict(set)
weight_map_file = os.path.join(file_path, "param_name_map.json")
if not os.path.exists(weight_map_file):
if has_only_one_safetensors_file(file_path):
weight_map_file = os.path.join(file_path, "param_name_map.json")
safetensors_file = next((f for f in os.listdir(file_path) if f.endswith('.safetensors')), None)
weight = load_data_ms(os.path.join(file_path, safetensors_file))
param_name_map = {key: "model.safetensors" for key in weight.keys()}
with open(weight_map_file, 'w') as f:
json.dump(param_name_map, f, indent=4)
set_safe_mode_for_file_or_dir(weight_map_file)
else:
raise ValueError(f"Cannot find weight map file in path {file_path}")
with open(weight_map_file) as f:
weights_map = json.load(f)
try:
weights_map = weights_map["weight_map"]
except KeyError:
pass
for weight_key, value in weights_map.items():
if weight_key.startswith("model.layers."):
layer_name = int(weight_key.split('model.layers.')[1].split('.')[0])
layer_st_map[layer_name].add(os.path.join(file_path, value))
else:
layer_st_map[weight_key].add(os.path.join(file_path, value))
return layer_st_map
def ms_ckpt_convertor(input_path, output_path, config):
"""convert ms weight to hf."""
logger.info(">>>>>>>>>>>>>>>>>>>>>>>>>>>> Start convert checkpoint from ckpt format.")
num_layers = config.model.model_config.num_layers
logger.info(f"Trying to convert mindspore checkpoint in {input_path}.", flush=True)
ms_weights = ms.load_checkpoint(input_path)
converted_st_map = defaultdict()
converted_st_map["weight_map"] = defaultdict()
converted_st_map["metadata"] = defaultdict()
total_size = 0
for layer_id in range(num_layers):
pt_layer_weights = defaultdict()
if layer_id == 0:
pt_layer_weights.update(_model_preprocess_ms_to_pt(ms_weights, config))
pt_layer_weights.update(attention_ms_to_pt(layer_id, ms_weights, config))
pt_layer_weights.update(feed_forward_ms_to_pt(layer_id, ms_weights, config))
if layer_id == num_layers - 1:
pt_layer_weights.update(_model_postprocess_ms_to_pt(ms_weights, config))
saving_file_name = f"model-{layer_id + 1:05d}-of-{num_layers:05d}.safetensors"
for name in list(pt_layer_weights.keys()):
converted_st_map["weight_map"][name] = saving_file_name
total_size += get_torch_storage_size(pt_layer_weights.get(name))
save_file(pt_layer_weights, os.path.join(output_path, saving_file_name))
logger.info(f"saving weights in layer-{layer_id + 1} to file {saving_file_name}")
converted_st_map["metadata"]["total_size"] = total_size
converted_model_index_file = os.path.join(output_path, "model.safetensors.index.json")
with open(converted_model_index_file, "w") as f:
json_string = json.dumps(converted_st_map, default=lambda x: x.__dict__, sort_keys=False, indent=2)
f.write(json_string)
set_safe_mode_for_file_or_dir(converted_model_index_file)
def ms_safetensors_convertor(input_path, output_path, config):
"""Convert safetensors format checkpoint"""
logger.info(">>>>>>>>>>>>>>>>>>>>>>> Start convert checkpoint from safetensors format.")
layer_st_map = layers_model_file_map(input_path)
num_layers = config.model.model_config.num_layers
converted_st_map = defaultdict()
converted_st_map["weight_map"] = defaultdict()
converted_st_map["metadata"] = defaultdict()
total_size = 0
for layer_id in range(num_layers):
if layer_id == 0:
ms_layer_weights = read_matched_file(layer_st_map, [layer_id], is_first=True, is_last=False)
elif 0 < layer_id < num_layers - 1:
ms_layer_weights = read_matched_file(layer_st_map, [layer_id], is_first=False, is_last=False)
else:
ms_layer_weights = read_matched_file(layer_st_map, [layer_id], is_first=False, is_last=True)
pt_layer_weights = defaultdict()
if layer_id == 0:
pt_layer_weights.update(_model_preprocess_ms_to_pt(ms_layer_weights, config))
pt_layer_weights.update(attention_ms_to_pt(layer_id, ms_layer_weights, config))
pt_layer_weights.update(feed_forward_ms_to_pt(layer_id, ms_layer_weights, config))
if layer_id == num_layers - 1:
pt_layer_weights.update(_model_postprocess_ms_to_pt(ms_layer_weights, config))
saving_file_name = f"model-{layer_id + 1:05d}-of-{num_layers:05d}.safetensors"
for name in list(pt_layer_weights.keys()):
converted_st_map["weight_map"][name] = saving_file_name
total_size += get_torch_storage_size(pt_layer_weights.get(name))
save_file(pt_layer_weights, os.path.join(output_path, saving_file_name))
logger.info(f"saving weights in layer-{layer_id + 1} to file {saving_file_name}")
converted_st_map["metadata"]["total_size"] = total_size
converted_model_index_file = os.path.join(output_path, "model.safetensors.index.json")
with open(converted_model_index_file, "w") as f:
json_string = json.dumps(converted_st_map, default=lambda x: x.__dict__, sort_keys=False, indent=2)
f.write(json_string)
set_safe_mode_for_file_or_dir(converted_model_index_file)
def convert_ms_to_pt(input_path, output_path, config_path):
"""convert ms weight to huggingface."""
config = MindFormerConfig(config_path)
os.makedirs(output_path, exist_ok=True)
load_format = config.get('load_ckpt_format', "ckpt")
logger.info(f"load_format---------------->: {load_format}")
logger.info(f"Loading mindspore checkpoint in {input_path} ...")
if load_format == "safetensors":
ms_safetensors_convertor(input_path, output_path, config)
else:
ms_ckpt_convertor(input_path, output_path, config)
logger.info("Finish converting mindspore checkpoints into Huggingface checkpoints!")