"""
transform huggingface model to mindspore ckpt.
"""
import argparse
import json
import os
from collections import defaultdict
import math
import multiprocessing
from glob import glob
import warnings
import numpy as np
import torch
from safetensors.torch import load_file
import mindspore as ms
from mindformers.tools.utils import set_safe_mode_for_file_or_dir
# Maps the value of the --dtype command-line option to the MindSpore dtype
# that converted weights are stored in.
dtype_map = {
    'fp32': ms.float32,
    'bf16': ms.bfloat16,
    'fp16': ms.float16
}
# Defaults for the non-inference conversion path (convert_pt_to_ms).  When the
# file is run as a script, every key here is overwritten by the command-line
# argument of the same name, if one exists.
default_config = {
    'num_routed_experts': 256,
    'n_head': 128,
    'qk_nope_head_dim': 128,
    'qk_rope_head_dim': 64,
    'v_head_dim': 128,
    'num_layers': 61,
    # Extra layers after `num_layers` that are handled as MTP layers.
    'num_nextn_predict_layers': 1,
    # Layers with an index below this value are converted as dense MLP layers,
    # the remaining ones as MoE layers.
    'first_k_dense_replace': 3,
    'dtype': ms.bfloat16,
    'use_grouped_gemm': True,
    'save_format': "safetensors"
}
# Fixed constants read by the infer_* functions; unlike default_config these
# are not overridden from the command line.
infer_config = {
    'num_head': 128,
    'qk_rope_head_dim': 64,
    'qk_nope_head_dim': 128,
    'kv_lora_rank': 512,
    'v_head_dim': 128,
    # With the values above this equals qk_nope_head_dim + qk_rope_head_dim.
    'rope_dim': 192,
    # With the values above this equals kv_lora_rank + qk_rope_head_dim.
    'kv_head_dim': 576,
    'total_layer_num': 62,
    'num_routed_experts': 256
}
def str2bool(b: str):
    """Convert the string "true" or "false" (case-insensitive) to a bool.

    Intended for use as an ``argparse`` ``type=`` callable.

    Args:
        b (str): The text to interpret.

    Returns:
        bool: True for "true", False for "false".

    Raises:
        ValueError: If `b` is neither "true" nor "false".  ``ValueError`` is
            one of the exception types argparse converts into a normal usage
            error, so an invalid flag value no longer produces a traceback.
    """
    lowered = b.lower()
    if lowered == "false":
        return False
    if lowered == "true":
        return True
    raise ValueError(f"Invalid Bool Value: {b!r}")
def weight_dequant(weight: torch.Tensor, scale: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """
    Dequantize a block-quantized 2-D weight by multiplying each block with its scale.

    Every `block_size` x `block_size` tile of `weight` shares one entry of `scale`.
    The scale is expanded by repetition and cropped, so `weight` does not need to be
    a multiple of `block_size` in either dimension.

    Args:
        weight (torch.Tensor): The quantized weight tensor of shape (M, N).
        scale (torch.Tensor): The scale tensor of shape
            (ceil(M / block_size), ceil(N / block_size)).
        block_size (int, optional): Edge length of one quantization block. Defaults to 128.

    Returns:
        torch.Tensor: The dequantized weight of shape (M, N), cast to torch's current default dtype.

    Raises:
        ValueError: If the shape of `scale` does not match the number of blocks in `weight`.
    """
    rows, cols = weight.shape
    scale_rows, scale_cols = scale.shape
    expected_rows = (rows + block_size - 1) // block_size
    expected_cols = (cols + block_size - 1) // block_size
    if scale_rows != expected_rows:
        raise ValueError("Mismatch in scale rows and weight rows.")
    if scale_cols != expected_cols:
        raise ValueError("Mismatch in scale columns and weight columns.")

    # Repeat each scale entry over its block, then crop the overhang of partial blocks.
    per_element_scale = torch.repeat_interleave(scale, block_size, dim=0)
    per_element_scale = torch.repeat_interleave(per_element_scale, block_size, dim=1)
    per_element_scale = per_element_scale[:rows, :cols]

    result = weight.to(torch.float32) * per_element_scale
    return result.to(torch.get_default_dtype())
def dequant_layer_weights(layer_id, pt_layer_weights):
    """Dequantize the 1-byte weights that belong to one layer.

    A tensor is treated as quantized when its element size is one byte and its
    name contains ``model.layers.{layer_id}.``.  Its scale is looked up under
    ``<weight_name>_scale_inv``.  Entries whose name ends in ``_scale_inv`` are
    dropped from the result; every other tensor is passed through unchanged.

    Args:
        layer_id (int): Index of the layer whose weights should be dequantized.
        pt_layer_weights (dict): Mapping of weight name to torch tensor.

    Returns:
        dict: Mapping of weight name to (possibly dequantized) tensor, without
        the ``_scale_inv`` entries.
    """
    dequanted_weights = {}
    layer_prefix = f"model.layers.{layer_id}."
    for weight_name, weight in pt_layer_weights.items():
        if weight_name.endswith("_scale_inv"):
            continue
        is_quantized = weight.element_size() == 1 and layer_prefix in weight_name
        if not is_quantized:
            dequanted_weights[weight_name] = weight
            continue
        # dict.get never raises KeyError, so a missing scale has to be detected
        # through the None result; otherwise the fallback below is unreachable
        # and weight_dequant fails on `None.shape`.
        scale_inv = pt_layer_weights.get(f"{weight_name}_scale_inv")
        if scale_inv is None:
            print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping dequanting")
            dequanted_weights[weight_name] = weight
        else:
            dequanted_weights[weight_name] = weight_dequant(weight, scale_inv)
    return dequanted_weights
def plain_name_replace(weight_name: str):
    """Rename the embedding and final-norm weights to their MindSpore names."""
    renames = (
        ('embed_tokens.weight', 'tok_embeddings.embedding_weight'),
        ('model.norm.weight', 'model.norm_out.weight'),
    )
    for old, new in renames:
        weight_name = weight_name.replace(old, new)
    return weight_name
def mla_name_replace(weight_name: str):
    """Rename attention (MLA) and layer-norm weights to their MindSpore names."""
    # Applied in order; each pattern is delimited by dots so it only matches a
    # whole path component.
    renames = (
        ('.self_attn.q_proj.', '.attention.q_proj.'),
        ('.self_attn.q_a_proj.', '.attention.q2l_proj.'),
        ('.self_attn.q_a_layernorm.', '.attention.lq_norm.'),
        ('.self_attn.q_b_proj.', '.attention.l2q_proj.'),
        ('.self_attn.kv_a_proj_with_mqa.', '.attention.kv2l.'),
        ('.self_attn.kv_a_layernorm.', '.attention.lkv_norm.'),
        ('.self_attn.kv_b_proj.', '.attention.lkv2kv.'),
        ('.self_attn.o_proj.', '.attention.wo.'),
        ('.input_layernorm.', '.attention_norm.'),
        ('.post_attention_layernorm.', '.ffn_norm.'),
    )
    for old, new in renames:
        weight_name = weight_name.replace(old, new)
    return weight_name
def mlp_name_replace(weight_name: str, use_grouped_gemm: bool = True):
    """Rename dense-MLP, shared-expert and router weights to their MindSpore names.

    Args:
        weight_name (str): HuggingFace weight name.
        use_grouped_gemm (bool): Selects which of the two router naming schemes
            is used for ``mlp.gate.weight`` and ``mlp.gate.e_score_correction_bias``.

    Returns:
        str: The renamed weight name.
    """
    if use_grouped_gemm:
        router_weight_name = 'feed_forward.routed_experts.router_dense.weight'
        router_bias_name = 'feed_forward.routed_experts.topk_bias'
    else:
        router_weight_name = 'feed_forward.routed_experts.router.dense.weight'
        router_bias_name = 'feed_forward.routed_experts.router.router.topk_bias'

    renames = (
        ('mlp.gate_proj.', 'feed_forward.w1.'),
        ('mlp.down_proj.', 'feed_forward.w2.'),
        ('mlp.up_proj.', 'feed_forward.w3.'),
        ('mlp.shared_experts.gate_proj.', 'feed_forward.shared_experts.w1.'),
        ('mlp.shared_experts.down_proj.', 'feed_forward.shared_experts.w2.'),
        ('mlp.shared_experts.up_proj.', 'feed_forward.shared_experts.w3.'),
        ('mlp.gate.weight', router_weight_name),
        ('mlp.gate.e_score_correction_bias', router_bias_name),
    )
    for old, new in renames:
        weight_name = weight_name.replace(old, new)
    return weight_name
def mtp_name_replace(weight_name: str, current_layer_id: int, mtp_layer_id: int):
    """Rename the MTP-specific weights of one layer to their MindSpore names.

    Args:
        weight_name (str): HuggingFace weight name.
        current_layer_id (int): Layer index that appears in `weight_name`.
        mtp_layer_id (int): Index used in the target ``mtp_*`` names.

    Returns:
        str: The renamed weight name (unchanged if no pattern matches).
    """
    source_prefix = f"model.layers.{current_layer_id}."
    fuser_prefix = f"model.mtp_hidden_fusers.{mtp_layer_id}."
    renames = (
        (source_prefix + "enorm", fuser_prefix + "norm_emb"),
        (source_prefix + "hnorm", fuser_prefix + "norm"),
        (source_prefix + "eh_proj", fuser_prefix + "dense"),
        (source_prefix + "shared_head.norm", f"model.mtp_norms.{mtp_layer_id}"),
    )
    for old, new in renames:
        weight_name = weight_name.replace(old, new)
    return weight_name
def load_data_pt(file_name):
    """Load every tensor of the safetensors file `file_name` onto the CPU."""
    return load_file(file_name, device="cpu")
def layers_model_file_map(file_path):
    """Build a map from layer index (or weight name) to the files that hold it.

    The mapping is read from ``model.safetensors.index.json`` in `file_path`.
    If that index does not exist, the first ``*.safetensors`` file (in sorted
    order) is opened and all of its keys are attributed to that file.

    Args:
        file_path (str): Directory containing the HuggingFace checkpoint.

    Returns:
        defaultdict(set): For weights named ``model.layers.<i>.*`` the key is
        the int ``<i>``; for every other weight the key is the full weight
        name.  Values are sets of file paths joined onto `file_path`.

    Raises:
        ValueError: If neither an index file nor any safetensors file exists.
    """
    layer_st_map = defaultdict(set)
    weight_map_file = os.path.join(file_path, "model.safetensors.index.json")
    if os.path.exists(weight_map_file):
        with open(weight_map_file, encoding="utf-8") as f:
            weights_map = json.load(f)["weight_map"]
    else:
        warnings.warn(f"Cannot find weight map file model.safetensors.index.json in path {file_path}, " \
                      f"Trying to load one safetensor file ...")
        files = sorted(glob(os.path.join(file_path, "*.safetensors")))
        if not files:
            raise ValueError(f"No safetensors files found in path {file_path}")
        # basename instead of split("/") so the result is also correct when
        # glob returns paths with a different separator.
        weight_file = os.path.basename(files[0])
        keys = load_data_pt(os.path.join(file_path, weight_file)).keys()
        weights_map = {k: weight_file for k in keys}

    for weight_key, value in weights_map.items():
        if weight_key.startswith("model.layers."):
            map_key = int(weight_key.split('model.layers.')[1].split('.')[0])
        else:
            map_key = weight_key
        layer_st_map[map_key].add(os.path.join(file_path, value))
    return layer_st_map
def read_matched_file_pt(layer_st_map, layer_list, is_first, is_last):
    """Load all tensors from the files that hold the requested layers.

    Args:
        layer_st_map: Mapping from layer index / weight name to a collection of file paths.
        layer_list: Layer indices whose files should be loaded.
        is_first (bool): Also load the file(s) holding ``model.embed_tokens.weight``.
        is_last (bool): Also load the file(s) holding ``model.norm.weight`` and ``lm_head.weight``.

    Returns:
        dict: Union of the contents of every selected file.  Because whole
        files are loaded, it may contain tensors of other layers as well.
    """
    lookup_keys = list(layer_list)
    if is_first:
        lookup_keys.append("model.embed_tokens.weight")
    if is_last:
        lookup_keys.append("model.norm.weight")
        lookup_keys.append("lm_head.weight")

    st_file_list = []
    for lookup_key in lookup_keys:
        st_file_list.extend(list(layer_st_map[lookup_key]))

    weights = {}
    # Each file is read once even if several requested weights live in it.
    for st_file in list(set(st_file_list)):
        weights.update(load_data_pt(st_file))
    return weights
def _mla_pt_to_ms(layer_id, pt_layer_weights, config):
    """Convert the attention (MLA) and layer-norm weights of one layer.

    The nine source tensors are removed from `pt_layer_weights`.  Three of
    them are split into two outputs each:

    * ``q_b_proj``  -> ``l2q_nope_proj`` and ``l2q_pe_proj``
    * ``kv_a_proj_with_mqa`` -> ``kv2l_latent_kv`` and ``kv2l_k_pe``
    * ``kv_b_proj`` -> ``lkv2kv_k_nope`` and ``lkv2kv_v``

    The remaining six are only renamed.

    Args:
        layer_id (int): Index of the layer being converted.
        pt_layer_weights (dict): Mapping of HuggingFace weight name to torch tensor; mutated.
        config (dict): Must provide 'n_head', 'qk_nope_head_dim',
            'qk_rope_head_dim' and 'v_head_dim'.

    Returns:
        defaultdict: Mapping of MindSpore weight name to torch tensor.
    """
    n_head = config['n_head']
    qk_nope_head_dim = config['qk_nope_head_dim']
    qk_rope_head_dim = config['qk_rope_head_dim']
    v_head_dim = config['v_head_dim']

    layer_prefix = f"model.layers.{layer_id}"
    hf_keys = {
        "q_a_proj": f"{layer_prefix}.self_attn.q_a_proj.weight",
        "kv_a_proj": f"{layer_prefix}.self_attn.kv_a_proj_with_mqa.weight",
        "o_proj": f"{layer_prefix}.self_attn.o_proj.weight",
        "q_a_layernorm": f"{layer_prefix}.self_attn.q_a_layernorm.weight",
        "kv_a_layernorm": f"{layer_prefix}.self_attn.kv_a_layernorm.weight",
        "q_b_proj": f"{layer_prefix}.self_attn.q_b_proj.weight",
        "kv_b_proj": f"{layer_prefix}.self_attn.kv_b_proj.weight",
        "input_norm": f"{layer_prefix}.input_layernorm.weight",
        "post_attn_norm": f"{layer_prefix}.post_attention_layernorm.weight",
    }
    # Pop everything up front, in the order listed above.
    tensors = {tag: pt_layer_weights.pop(hf_key) for tag, hf_key in hf_keys.items()}

    converted = defaultdict()

    # q_b_proj: view as (head, nope + rope, in) and split along the per-head axis.
    per_head_q = tensors["q_b_proj"].reshape(n_head, qk_nope_head_dim + qk_rope_head_dim, -1)
    q_nope, q_rope = per_head_q.split([qk_nope_head_dim, qk_rope_head_dim], dim=1)
    # Within each head, move the even-indexed rope rows in front of the odd-indexed ones.
    q_rope = q_rope.reshape(q_rope.shape[0], q_rope.shape[1] // 2, 2, -1).permute(0, 2, 1, 3)
    q_nope = q_nope.reshape(-1, q_nope.shape[-1])
    q_rope = q_rope.reshape(-1, q_rope.shape[-1])
    q_b_name = mla_name_replace(hf_keys["q_b_proj"])
    converted[q_b_name.replace(".l2q_proj.", ".l2q_nope_proj.")] = q_nope.clone()
    converted[q_b_name.replace(".l2q_proj.", ".l2q_pe_proj.")] = q_rope.clone()

    # kv_a_proj: the last qk_rope_head_dim rows are the rope part, the rest is the latent kv.
    kv_a_proj = tensors["kv_a_proj"]
    kv_lora_rank = kv_a_proj.shape[0] - qk_rope_head_dim
    latent_kv, k_rope = kv_a_proj.split([kv_lora_rank, qk_rope_head_dim], dim=0)
    # Same even-before-odd row reordering as for the query rope part.
    k_rope = k_rope.reshape(k_rope.shape[0] // 2, 2, -1).permute(1, 0, 2).reshape(-1, k_rope.shape[-1])
    kv_a_name = mla_name_replace(hf_keys["kv_a_proj"])
    converted[kv_a_name.replace(".kv2l.", ".kv2l_latent_kv.")] = latent_kv.clone()
    converted[kv_a_name.replace(".kv2l.", ".kv2l_k_pe.")] = k_rope.clone()

    # kv_b_proj: view as (head, nope + v, in) and split into the k_nope and v parts.
    per_head_kv = tensors["kv_b_proj"].reshape(n_head, qk_nope_head_dim + v_head_dim, -1)
    k_nope, v = per_head_kv.split([qk_nope_head_dim, v_head_dim], dim=1)
    k_nope = k_nope.reshape(-1, k_nope.shape[-1])
    v = v.reshape(-1, v.shape[-1])
    kv_b_name = mla_name_replace(hf_keys["kv_b_proj"])
    converted[kv_b_name.replace(".lkv2kv.", ".lkv2kv_k_nope.")] = k_nope
    converted[kv_b_name.replace(".lkv2kv.", ".lkv2kv_v.")] = v

    # The remaining tensors are copied under their new names.
    for tag in ("q_a_proj", "o_proj", "q_a_layernorm", "kv_a_layernorm", "input_norm", "post_attn_norm"):
        converted[mla_name_replace(hf_keys[tag])] = tensors[tag].clone()
    return converted
def _mlp_pt_to_ms(layer_id, pt_layer_weights, config):
    """Convert the feed-forward weights (dense MLP or MoE) of one layer.

    Layers with ``layer_id < config['first_k_dense_replace']`` are dense: their
    gate/up/down projections are only renamed.  All other layers are MoE: the
    router weight and bias are cut to the first `num_routed_experts` rows, the
    shared-expert projections are renamed, and the per-expert projections are
    stacked along a new leading axis.  With grouped GEMM the stacked tensors
    are additionally transposed on their last two axes and gate/up are
    concatenated along the last axis.

    Args:
        layer_id (int): Index of the layer being converted.
        pt_layer_weights (dict): Mapping of HuggingFace weight name to torch tensor; mutated.
        config (dict): Must provide 'num_routed_experts',
            'first_k_dense_replace' and 'use_grouped_gemm'.

    Returns:
        defaultdict: Mapping of MindSpore weight name to torch tensor.
    """
    num_routed_experts = config['num_routed_experts']
    first_k_dense_replace = config['first_k_dense_replace']
    use_grouped_gemm = config['use_grouped_gemm']

    converted = defaultdict()
    mlp_prefix = f"model.layers.{layer_id}.mlp"
    proj_names = ("gate_proj", "up_proj", "down_proj")

    if layer_id < first_k_dense_replace:
        dense_keys = [f"{mlp_prefix}.{proj}.weight" for proj in proj_names]
        dense_weights = [pt_layer_weights.pop(dense_key) for dense_key in dense_keys]
        for dense_key, dense_weight in zip(dense_keys, dense_weights):
            converted[mlp_name_replace(dense_key)] = dense_weight.clone()
        return converted

    # Router and shared experts.
    router_weight_key = f"{mlp_prefix}.gate.weight"
    router_bias_key = f"{mlp_prefix}.gate.e_score_correction_bias"
    shared_keys = [f"{mlp_prefix}.shared_experts.{proj}.weight" for proj in proj_names]

    router_weight = pt_layer_weights.pop(router_weight_key)[:num_routed_experts, :]
    router_bias = pt_layer_weights.pop(router_bias_key)[:num_routed_experts]
    shared_weights = [pt_layer_weights.pop(shared_key) for shared_key in shared_keys]

    # Routed experts: collect every expert's projections, then stack per projection.
    per_proj = {proj: [] for proj in proj_names}
    for expert_id in range(num_routed_experts):
        for proj in proj_names:
            expert_key = f"{mlp_prefix}.experts.{expert_id}.{proj}.weight"
            per_proj[proj].append(pt_layer_weights.pop(expert_key))
    expert_gate_proj = torch.stack(per_proj["gate_proj"], 0)
    expert_up_proj = torch.stack(per_proj["up_proj"], 0)
    expert_down_proj = torch.stack(per_proj["down_proj"], 0)

    converted[mlp_name_replace(router_weight_key, use_grouped_gemm)] = router_weight.clone()
    converted[mlp_name_replace(router_bias_key, use_grouped_gemm)] = router_bias.clone()
    for shared_key, shared_weight in zip(shared_keys, shared_weights):
        converted[mlp_name_replace(shared_key)] = shared_weight.clone()

    ffn_prefix = f"model.layers.{layer_id}.feed_forward.routed_experts.ffn"
    if use_grouped_gemm:
        # Grouped GEMM layout: swap the last two axes and fuse gate with up.
        gate_t = expert_gate_proj.transpose(1, 2)
        up_t = expert_up_proj.transpose(1, 2)
        down_t = expert_down_proj.transpose(1, 2)
        converted[f"{ffn_prefix}.w1"] = torch.cat((gate_t, up_t), -1).clone()
        converted[f"{ffn_prefix}.w2"] = down_t.clone()
    else:
        converted[f"{ffn_prefix}.w1.weight"] = expert_gate_proj.clone()
        converted[f"{ffn_prefix}.w3.weight"] = expert_up_proj.clone()
        converted[f"{ffn_prefix}.w2.weight"] = expert_down_proj.clone()
    return converted
def _mtp_pt_to_ms(layer_id, pt_layer_weights, config):
    """Convert the MTP-specific weights of one layer.

    The layer's own ``embed_tokens`` and ``shared_head.head`` tensors are
    removed from `pt_layer_weights` and discarded; ``enorm``, ``hnorm``,
    ``eh_proj`` and ``shared_head.norm`` are removed and returned under their
    MindSpore names.

    Args:
        layer_id (int): Index of the layer being converted.
        pt_layer_weights (dict): Mapping of HuggingFace weight name to torch tensor; mutated.
        config (dict): Must provide 'num_layers'; the MTP index is
            ``layer_id - config['num_layers']``.

    Returns:
        defaultdict: Mapping of MindSpore weight name to torch tensor.
    """
    mtp_layer_id = layer_id - config["num_layers"]
    layer_prefix = f"model.layers.{layer_id}"

    # Dropped on purpose: these two tensors are not part of the output.
    for discarded in ("embed_tokens.weight", "shared_head.head.weight"):
        pt_layer_weights.pop(f"{layer_prefix}.{discarded}")

    hf_keys = [
        f"{layer_prefix}.enorm.weight",
        f"{layer_prefix}.hnorm.weight",
        f"{layer_prefix}.eh_proj.weight",
        f"{layer_prefix}.shared_head.norm.weight",
    ]
    tensors = [pt_layer_weights.pop(hf_key) for hf_key in hf_keys]

    converted = defaultdict()
    for hf_key, tensor in zip(hf_keys, tensors):
        converted[mtp_name_replace(hf_key, layer_id, mtp_layer_id)] = tensor.clone()
    return converted
def _model_preprocess_pt_to_ms(pt_layer_weights):
    """Remove the token-embedding weight from `pt_layer_weights` and return it renamed.

    Returns:
        defaultdict: Single entry mapping the MindSpore embedding name to a
        copy of ``model.embed_tokens.weight``.
    """
    hf_key = "model.embed_tokens.weight"
    embedding = pt_layer_weights.pop(hf_key)
    converted = defaultdict()
    converted[plain_name_replace(hf_key)] = embedding.clone()
    return converted
def _model_postprocess_pt_to_ms(pt_layer_weights):
    """Remove the final norm and LM head from `pt_layer_weights` and return them renamed.

    Returns:
        defaultdict: Copies of ``model.norm.weight`` and ``lm_head.weight``
        keyed by the names produced by `plain_name_replace`.
    """
    hf_keys = ("model.norm.weight", "lm_head.weight")
    tensors = [pt_layer_weights.pop(hf_key) for hf_key in hf_keys]
    converted = defaultdict()
    for hf_key, tensor in zip(hf_keys, tensors):
        converted[plain_name_replace(hf_key)] = tensor.clone()
    return converted
def ms_ckpt_convertor(input_path, output_path, config):
    """Convert a HuggingFace checkpoint into one MindSpore ``ckpt`` file.

    All layers are converted one after another and accumulated in memory; the
    result is written once at the end.

    Args:
        input_path (str): Directory containing the HuggingFace safetensors files.
        output_path (str): Target file if it ends in ``.ckpt``; otherwise a
            directory in which ``checkpoints.ckpt`` is written.
        config (dict): Conversion settings; 'dtype', 'num_layers' and
            'num_nextn_predict_layers' are read here, the rest by the
            per-module converters.
    """
    if output_path.endswith(".ckpt"):
        saving_file = output_path
    else:
        saving_file = os.path.join(output_path, "checkpoints.ckpt")

    layer_st_map = layers_model_file_map(input_path)
    # weight_dequant casts its result to torch's default dtype.
    torch.set_default_dtype(torch.bfloat16)

    dtype = config["dtype"]
    num_layers = config["num_layers"]
    num_nextn_predict_layers = config["num_nextn_predict_layers"]
    total_num_layers = num_layers + num_nextn_predict_layers
    last_layer_id = total_num_layers - 1

    ms_weights = defaultdict()
    for layer_id in range(total_num_layers):
        # Evaluated independently: with a single layer it is both first and
        # last, and the files for the final norm / lm_head must be loaded too.
        is_first = layer_id == 0
        is_last = layer_id == last_layer_id
        pt_layer_weights = read_matched_file_pt(layer_st_map, [layer_id], is_first=is_first, is_last=is_last)
        pt_layer_weights = dequant_layer_weights(layer_id, pt_layer_weights)

        if is_first:
            ms_weights.update(_model_preprocess_pt_to_ms(pt_layer_weights))
        ms_weights.update(_mla_pt_to_ms(layer_id, pt_layer_weights, config))
        ms_weights.update(_mlp_pt_to_ms(layer_id, pt_layer_weights, config))
        # Layers past `num_layers` additionally carry MTP weights.
        if layer_id > num_layers - 1:
            ms_weights.update(_mtp_pt_to_ms(layer_id, pt_layer_weights, config))
        if is_last:
            ms_weights.update(_model_postprocess_pt_to_ms(pt_layer_weights))

    to_save_ckpt = []
    for name in list(ms_weights.keys()):
        # Pop so the torch tensor can be released once it has been converted.
        value = ms_weights.pop(name).to(torch.float32).numpy()
        tmp_dtype = dtype
        # Norms, the router weight and the top-k bias are kept in float32.
        # NOTE(review): the grouped-GEMM router name contains "router_dense",
        # which "router.dense" does not match, so that weight is stored in
        # `dtype` -- confirm this is intended.
        if "norm" in name or "router.dense" in name or "topk_bias" in name:
            tmp_dtype = ms.float32
        to_save_ckpt.append({'name': name, 'data': ms.Tensor(value, dtype=tmp_dtype)})

    print(f"Saving weights to file {saving_file}")
    ms.save_checkpoint(to_save_ckpt, saving_file, format="ckpt")
def ms_safetensors_convertor(input_path, output_path, config):
    """Convert a HuggingFace checkpoint into per-layer MindSpore safetensors files.

    One file ``ms-model-XXXXX-of-YYYYY.safetensors`` is written per layer, and
    ``param_name_map.json`` maps every converted weight name to its file.

    Args:
        input_path (str): Directory containing the HuggingFace safetensors files.
        output_path (str): Existing directory that receives the converted files.
        config (dict): Conversion settings; 'dtype', 'num_layers' and
            'num_nextn_predict_layers' are read here, the rest by the
            per-module converters.
    """
    layer_st_map = layers_model_file_map(input_path)
    # weight_dequant casts its result to torch's default dtype.
    torch.set_default_dtype(torch.bfloat16)

    dtype = config["dtype"]
    num_layers = config["num_layers"]
    num_nextn_predict_layers = config["num_nextn_predict_layers"]
    total_num_layers = num_layers + num_nextn_predict_layers
    last_layer_id = total_num_layers - 1

    converted_st_map = defaultdict()
    for layer_id in range(total_num_layers):
        # Evaluated independently: with a single layer it is both first and
        # last, and the files for the final norm / lm_head must be loaded too.
        is_first = layer_id == 0
        is_last = layer_id == last_layer_id
        pt_layer_weights = read_matched_file_pt(layer_st_map, [layer_id], is_first=is_first, is_last=is_last)
        pt_layer_weights = dequant_layer_weights(layer_id, pt_layer_weights)

        ms_layer_weights = defaultdict()
        if is_first:
            ms_layer_weights.update(_model_preprocess_pt_to_ms(pt_layer_weights))
        ms_layer_weights.update(_mla_pt_to_ms(layer_id, pt_layer_weights, config))
        ms_layer_weights.update(_mlp_pt_to_ms(layer_id, pt_layer_weights, config))
        # Layers past `num_layers` additionally carry MTP weights.
        if layer_id > num_layers - 1:
            ms_layer_weights.update(_mtp_pt_to_ms(layer_id, pt_layer_weights, config))
        if is_last:
            ms_layer_weights.update(_model_postprocess_pt_to_ms(pt_layer_weights))

        to_save_ckpt = []
        saving_file = f"ms-model-{layer_id+1:05d}-of-{total_num_layers:05d}.safetensors"
        for name in list(ms_layer_weights.keys()):
            # Pop so the torch tensor can be released once it has been converted.
            value = ms_layer_weights.pop(name).to(torch.float32).numpy()
            tmp_dtype = dtype
            # Norms, the router weight and the top-k bias are kept in float32.
            # NOTE(review): the grouped-GEMM router name contains
            # "router_dense", which "router.dense" does not match, so that
            # weight is stored in `dtype` -- confirm this is intended.
            if "norm" in name or "router.dense" in name or "topk_bias" in name:
                tmp_dtype = ms.float32
            to_save_ckpt.append({'name': name, 'data': ms.Tensor(value, dtype=tmp_dtype)})
            converted_st_map[name] = saving_file

        ms.save_checkpoint(to_save_ckpt, os.path.join(output_path, saving_file), format="safetensors")
        print(f"saving weights in layer-{layer_id} to file {saving_file}")

    converted_model_index_file = os.path.join(output_path, "param_name_map.json")
    with open(converted_model_index_file, "w", encoding="utf-8") as f:
        json_string = json.dumps(converted_st_map, default=lambda x: x.__dict__, sort_keys=False, indent=2)
        f.write(json_string)
    set_safe_mode_for_file_or_dir(converted_model_index_file)
def convert_pt_to_ms(input_path, output_path, config=None):
    """Convert a HuggingFace checkpoint to MindSpore format.

    Args:
        input_path (str): Directory containing the HuggingFace safetensors files.
        output_path (str): Output directory.  For the "ckpt" format this may
            instead be a file path ending in ``.ckpt``.
        config (dict, optional): Conversion settings; `default_config` is used
            when omitted.  ``config['save_format']`` selects the writer.

    Raises:
        ValueError: If ``config['save_format']`` is neither "ckpt" nor
            "safetensors".  Previously such a value wrote nothing and still
            reported success.
    """
    if config is None:
        config = default_config
    save_format = config['save_format']
    if save_format not in ("ckpt", "safetensors"):
        raise ValueError(f"Unsupported save_format {save_format!r}, expected 'ckpt' or 'safetensors'.")

    if save_format == "ckpt" and output_path.endswith(".ckpt"):
        # output_path names the checkpoint file itself: create only its parent
        # directory, otherwise a directory would occupy the file's path and
        # saving would fail.
        os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
    else:
        os.makedirs(output_path, exist_ok=True)

    print(f"Trying to convert huggingface checkpoint in '{input_path}'.", flush=True)
    if save_format == "ckpt":
        ms_ckpt_convertor(input_path, output_path, config)
    else:
        ms_safetensors_convertor(input_path, output_path, config)
    print("Finish converting Huggingface checkpoints into mindspore checkpoints!")
def convert_ms_to_gmm(input_path, output_path):
    """Transpose the routed-expert FFN weights of a MindSpore checkpoint and save it.

    Every parameter whose name contains
    ``feed_forward.routed_experts.ffn.w{1,2,3}.weight`` has its last two axes
    swapped; all other parameters are written unchanged.

    Args:
        input_path (str): Path of the MindSpore checkpoint to read.
        output_path (str): Path the converted checkpoint is written to.
    """
    routed_ffn_names = (
        'feed_forward.routed_experts.ffn.w1.weight',
        'feed_forward.routed_experts.ffn.w2.weight',
        'feed_forward.routed_experts.ffn.w3.weight',
    )
    params = ms.load_checkpoint(input_path)
    # Only existing keys are reassigned, so the dict keeps its size during iteration.
    for k, v in params.items():
        if any(name in k for name in routed_ffn_names):
            orig_tensor = ms.Tensor(v)
            gmm_tensor = orig_tensor.transpose((0, 2, 1))
            params[k] = ms.Parameter(gmm_tensor)
    # The original never wrote the result although the message below says it
    # was saved.
    ms.save_checkpoint(params, output_path)
    print(f"\rConvertion finished, the mindspore ckpt is saved in '{output_path}'.", flush=True)
def infer_name_replace(weight_name: str):
    """Translate a HuggingFace weight name into the name used by the inference model."""
    # Applied strictly in this order.  The 'model.tok_embeddings.weight' entry
    # relies on the 'embed_tokens.' entry having run before it.
    renames = (
        ('embed_tokens.', 'tok_embeddings.'),
        ('.self_attn.q_proj.', '.attention.q_proj.'),
        ('.self_attn.q_a_proj.', '.attention.q2l_proj.'),
        ('.self_attn.q_a_layernorm.', '.attention.lq_norm.'),
        ('.self_attn.q_b_proj.', '.attention.l2q_proj.'),
        ('.self_attn.kv_a_proj_with_mqa.', '.attention.kv2l.'),
        ('.self_attn.kv_a_layernorm.', '.attention.lkv_norm.'),
        ('.self_attn.kv_b_proj.', '.attention.lkv2kv.'),
        ('.self_attn.o_proj.', '.attention.wo.'),
        ('mlp.gate_proj.', 'feed_forward.w1.'),
        ('mlp.down_proj.', 'feed_forward.w2.'),
        ('mlp.up_proj.', 'feed_forward.w3.'),
        ('mlp.experts.', 'feed_forward.routed_experts.ffn.'),
        ('mlp.shared_experts.gate_proj.', 'feed_forward.shared_experts.w1.'),
        ('mlp.shared_experts.down_proj.', 'feed_forward.shared_experts.w2.'),
        ('mlp.shared_experts.up_proj.', 'feed_forward.shared_experts.w3.'),
        ('mlp.gate.weight', 'feed_forward.routed_experts.router.dense.weight'),
        ('mlp.gate.e_score_correction_bias',
         'feed_forward.routed_experts.router.e_score_correction_bias'),
        ('.input_layernorm.', '.attention_norm.'),
        ('.post_attention_layernorm.', '.ffn_norm.'),
        ('model.tok_embeddings.weight', 'model.tok_embeddings.embedding_weight'),
        ('model.norm.weight', 'model.norm_out.weight'),
    )
    for old, new in renames:
        weight_name = weight_name.replace(old, new)
    return weight_name
def infer_trans_rope_weight(weight):
    """Reorder the trailing rope rows of `weight` in place and return it.

    Along the second-to-last axis, the last ``infer_config['qk_rope_head_dim']``
    rows are rewritten so that every second row starting at the first comes
    first, followed by every second row starting at the next one.
    """
    rope_rows = infer_config['qk_rope_head_dim']
    first_of_pairs = weight[..., -rope_rows::2, :]
    second_of_pairs = weight[..., -rope_rows + 1::2, :]
    # concatenate builds a new array, so the views above are not overwritten
    # while they are still being read.
    reordered = np.concatenate([first_of_pairs, second_of_pairs], axis=-2)
    weight[..., -rope_rows:, :] = reordered
    return weight
def infer_process_moe_routed_expert_ffn_weight(params_dict, dst_ms_dir, layer, ms_meta):
    """Stack the routed-expert FFN weights of one layer and save them as three files.

    For each of gate (w1), down (w2) and up (w3) the per-expert arrays are
    stacked on a new leading axis, transposed on the last two axes and written
    to ``model_layer_{layer}_routed_experts_w{1,2,3}.safetensors``.  The
    consumed per-expert entries are removed from `params_dict` afterwards.

    Args:
        params_dict (dict): Mapping of HuggingFace weight name to MindSpore parameter; mutated.
        dst_ms_dir (str): Output directory.
        layer (int): Index of the layer being converted.
        ms_meta (dict): Receives ``{converted weight name: file name}`` entries.
    """
    expert_total = infer_config['num_routed_experts']
    hf_prefix = f"model.layers.{layer}.mlp.experts"
    ms_prefix = f"model.layers.{layer}.feed_forward.routed_experts.ffn"

    gate_arrays, down_arrays, up_arrays = [], [], []
    gate_keys, down_keys, up_keys = [], [], []
    # Overwritten with the dtype of the gate projection of the last expert visited.
    ffn_dtype = ms.bfloat16
    for expert in range(0, expert_total):
        gate_key = f"{hf_prefix}.{expert}.gate_proj.weight"
        down_key = f"{hf_prefix}.{expert}.down_proj.weight"
        up_key = f"{hf_prefix}.{expert}.up_proj.weight"
        ffn_dtype = params_dict[gate_key].dtype
        gate_arrays.append(params_dict[gate_key].asnumpy())
        down_arrays.append(params_dict[down_key].asnumpy())
        up_arrays.append(params_dict[up_key].asnumpy())
        gate_keys.append(gate_key)
        down_keys.append(down_key)
        up_keys.append(up_key)

    # w2 (down projection) is built, registered and saved first.
    ffn_w2_key = f"{ms_prefix}.w2.weight"
    stacked_down = ms.Tensor(np.stack(down_arrays, axis=0).transpose(0, 2, 1), ffn_dtype)
    params_w2 = {ffn_w2_key: ms.Parameter(stacked_down, name=ffn_w2_key)}
    dst_w2_name = f"model_layer_{layer}_routed_experts_w2.safetensors"
    ms_meta[ffn_w2_key] = dst_w2_name
    ms.save_checkpoint(params_w2, f"{dst_ms_dir}/{dst_w2_name}", format="safetensors")

    # w1 (gate projection) and w3 (up projection) follow.
    ffn_w1_key = f"{ms_prefix}.w1.weight"
    ffn_w3_key = f"{ms_prefix}.w3.weight"
    stacked_gate = ms.Tensor(np.stack(gate_arrays, axis=0).transpose(0, 2, 1), ffn_dtype)
    params_w1 = {ffn_w1_key: ms.Parameter(stacked_gate, name=ffn_w1_key)}
    stacked_up = ms.Tensor(np.stack(up_arrays, axis=0).transpose(0, 2, 1), ffn_dtype)
    params_w3 = {ffn_w3_key: ms.Parameter(stacked_up, name=ffn_w3_key)}
    dst_w1_name = f"model_layer_{layer}_routed_experts_w1.safetensors"
    dst_w3_name = f"model_layer_{layer}_routed_experts_w3.safetensors"
    ms_meta[ffn_w1_key] = dst_w1_name
    ms_meta[ffn_w3_key] = dst_w3_name
    ms.save_checkpoint(params_w1, f"{dst_ms_dir}/{dst_w1_name}", format="safetensors")
    ms.save_checkpoint(params_w3, f"{dst_ms_dir}/{dst_w3_name}", format="safetensors")

    # Drop the per-expert entries that have been consumed.
    for gate_key, down_key, up_key in zip(gate_keys, down_keys, up_keys):
        params_dict.pop(gate_key)
        params_dict.pop(down_key)
        params_dict.pop(up_key)
def infer_process_moe_shared_expert_ffn_weight(trans_params):
    """Move the shared-expert FFN weights of one layer into the output dict.

    Args:
        trans_params (tuple): ``(params_dict, ms_param, layer, ms_meta, dst_name)``.
            The gate/down/up projections are copied from `params_dict` into
            `ms_param` under the names ``...shared_experts.w1/w2/w3.weight``,
            `ms_meta` records `dst_name` for each of them, and the source
            entries are then removed from `params_dict`.
    """
    params_dict, ms_param, layer, ms_meta, dst_name = trans_params
    hf_prefix = f"model.layers.{layer}.mlp.shared_experts"
    ms_prefix = f"model.layers.{layer}.feed_forward.shared_experts"
    source = {
        "w1": f"{hf_prefix}.gate_proj.weight",
        "w2": f"{hf_prefix}.down_proj.weight",
        "w3": f"{hf_prefix}.up_proj.weight",
    }
    target = {tag: f"{ms_prefix}.{tag}.weight" for tag in source}

    # w2 is registered first, then w1 and w3.
    ms_param[target["w2"]] = params_dict[source["w2"]]
    ms_meta[target["w2"]] = dst_name
    ms_param[target["w1"]] = params_dict[source["w1"]]
    ms_param[target["w3"]] = params_dict[source["w3"]]
    ms_meta[target["w1"]] = dst_name
    ms_meta[target["w3"]] = dst_name

    for tag in ("w1", "w2", "w3"):
        params_dict.pop(source[tag])
def infer_process_dense_ffn_weight(trans_params):
    """Move the dense FFN weights of one layer into the output dict.

    Args:
        trans_params (tuple): ``(params_dict, ms_param, layer, ms_meta, dst_name)``.
            The gate/down/up projections are wrapped in new parameters named
            ``...feed_forward.w1/w2/w3.weight`` and stored in `ms_param`,
            `ms_meta` records `dst_name` for each of them, and the source
            entries are then removed from `params_dict`.
    """
    params_dict, ms_param, layer, ms_meta, dst_name = trans_params
    hf_prefix = f"model.layers.{layer}.mlp"
    ms_prefix = f"model.layers.{layer}.feed_forward"
    gate_key = f"{hf_prefix}.gate_proj.weight"
    down_key = f"{hf_prefix}.down_proj.weight"
    up_key = f"{hf_prefix}.up_proj.weight"

    gate = params_dict[gate_key]
    down = params_dict[down_key]
    up = params_dict[up_key]

    # w2 (down projection) is registered first, then w1 (gate) and w3 (up).
    new_w2_key = f"{ms_prefix}.w2.weight"
    ms_param[new_w2_key] = ms.Parameter(ms.Tensor(down, down.dtype), name=new_w2_key)
    ms_meta[new_w2_key] = dst_name

    new_w1_key = f"{ms_prefix}.w1.weight"
    new_w3_key = f"{ms_prefix}.w3.weight"
    ms_param[new_w1_key] = ms.Parameter(ms.Tensor(gate, gate.dtype), name=new_w1_key)
    ms_param[new_w3_key] = ms.Parameter(ms.Tensor(up, up.dtype), name=new_w3_key)
    ms_meta[new_w1_key] = dst_name
    ms_meta[new_w3_key] = dst_name

    for consumed_key in (gate_key, down_key, up_key):
        params_dict.pop(consumed_key)
def infer_convert_layer_weight(src_hf_dir, dst_ms_dir, layer, queue):
    """Convert all weights of one layer for inference and save them.

    Intended to run in a worker process: the mapping from converted weight
    name to output file name is reported back through `queue`.

    Args:
        src_hf_dir (str): Directory with the HuggingFace safetensors files and
            ``model.safetensors.index.json``.
        dst_ms_dir (str): Output directory.
        layer (int): Index of the layer to convert.
        queue: Object with a ``put`` method that receives the metadata dict.
    """
    print(f"..... start convert layer {layer} .......", flush=True)
    ms_meta = {}
    with open(os.path.join(src_hf_dir, "model.safetensors.index.json"), "r") as fp:
        hf_meta = json.load(fp).get('weight_map')
    # Collect every source file that holds at least one weight of this layer.
    safetensor_files = set()
    for param_key, param_path in hf_meta.items():
        if f"model.layers.{layer}." in param_key:
            safetensor_files.add(param_path)
    # Whole files are loaded, so params_dict may also hold weights of other
    # layers; those are skipped in the loop further down.
    params_dict = {}
    for ckpt in safetensor_files:
        src_path = f"{src_hf_dir}/{ckpt}"
        p = ms.load_checkpoint(src_path, format="safetensors")
        params_dict.update(p)
    ms_param = {}
    dst_name = f"model_layer_{layer}.safetensors"
    # The FFN helpers remove the entries they consume from params_dict.
    # NOTE(review): the threshold 3 is hard-coded here; it presumably mirrors
    # 'first_k_dense_replace' of the other conversion path -- confirm.
    if layer >= 3:
        infer_process_moe_routed_expert_ffn_weight(params_dict, dst_ms_dir, layer, ms_meta)
        infer_process_moe_shared_expert_ffn_weight((params_dict, ms_param, layer, ms_meta, dst_name))
    else:
        infer_process_dense_ffn_weight((params_dict, ms_param, layer, ms_meta, dst_name))
    num_head = infer_config['num_head']
    rope_dim = infer_config['rope_dim']
    kv_head_dim = infer_config['kv_head_dim']
    qk_nope_head_dim = infer_config['qk_nope_head_dim']
    v_head_dim = infer_config['v_head_dim']
    for key, value in params_dict.items():
        if not key.startswith(f"model.layers.{layer}."):
            continue
        # Remember the stored dtype; the branches below go through float32
        # numpy arrays and cast back to it.
        dtype = value.dtype
        ms_key = infer_name_replace(key)
        if "attention.l2q_proj.weight" in ms_key:
            # Reorder the rope rows of every head, then flatten the heads again.
            value = value.astype(np.float32).asnumpy()
            value = value.reshape(num_head, rope_dim, -1)
            weight = infer_trans_rope_weight(value)
            weight = weight.reshape(num_head * rope_dim, -1)
            ms_param[ms_key] = ms.Parameter(ms.Tensor(weight, dtype), name=ms_key)
        elif "attention.kv2l.weight" in ms_key:
            # Same rope-row reordering on the single (kv_head_dim, -1) matrix.
            value = value.astype(np.float32).asnumpy()
            value = value.reshape(kv_head_dim, -1)
            weight = infer_trans_rope_weight(value)
            ms_param[ms_key] = ms.Parameter(ms.Tensor(weight, dtype), name=ms_key)
        elif ".attention.lkv2kv." in ms_key:
            # Split per head into the k_nope rows and the v rows and store them
            # as two separate parameters; the combined name is not stored.
            value = value.astype(np.float32).asnumpy()
            lkv2kv_head = qk_nope_head_dim + v_head_dim
            value = value.reshape(num_head, lkv2kv_head, -1)
            value_k_nope, value_v = value[:, :qk_nope_head_dim, :], value[:, qk_nope_head_dim:, :]
            value_k_nope = value_k_nope.reshape(-1, value_k_nope.shape[-1])
            value_v = value_v.reshape(-1, value_v.shape[-1])
            name_k_nope = ms_key.replace(".attention.lkv2kv.", ".attention.lkv2kv_k_nope.")
            name_v = ms_key.replace(".attention.lkv2kv.", ".attention.lkv2kv_v.")
            ms_param[name_k_nope] = ms.Parameter(ms.Tensor(value_k_nope, dtype), name=name_k_nope)
            ms_param[name_v] = ms.Parameter(ms.Tensor(value_v, dtype), name=name_v)
            ms_meta[name_k_nope] = dst_name
            ms_meta[name_v] = dst_name
            # Skip the generic ms_meta registration below.
            continue
        else:
            ms_param[ms_key] = ms.Parameter(ms.Tensor(value, dtype), name=ms_key)
        ms_meta[ms_key] = dst_name
    dst_path = os.path.join(dst_ms_dir, dst_name)
    ms.save_checkpoint(ms_param, dst_path, format="safetensors")
    queue.put(ms_meta)
    print(f"..... end convert layer {layer} .......", flush=True)
def infer_convert_outer_weight(src_hf_dir, dst_ms_dir, ms_meta, param_json):
    """Convert every weight whose name does not contain ``model.layers.``.

    The converted weights are written to ``model.safetensors`` in `dst_ms_dir`.

    Args:
        src_hf_dir (str): Directory with the HuggingFace safetensors files.
        dst_ms_dir (str): Output directory.
        ms_meta (dict): Receives ``{converted weight name: "model.safetensors"}`` entries.
        param_json (str): File name of the weight index inside `src_hf_dir`.
    """
    with open(f"{src_hf_dir}/{param_json}", "r") as fp:
        hf_meta = json.load(fp)['weight_map']

    # Files that contain at least one non-layer weight.
    safetensor_files = {
        param_path for param_key, param_path in hf_meta.items()
        if "model.layers." not in param_key
    }

    params_dict = {}
    for ckpt in safetensor_files:
        params_dict.update(ms.load_checkpoint(f"{src_hf_dir}/{ckpt}", format="safetensors"))

    dst_name = "model.safetensors"
    ms_param = {}
    for key, value in params_dict.items():
        # The loaded files can also hold layer weights; those are handled per layer.
        if "model.layers." not in key:
            ms_key = infer_name_replace(key)
            ms_param[ms_key] = ms.Parameter(ms.Tensor(value), name=ms_key)
            ms_meta[ms_key] = dst_name

    ms.save_checkpoint(ms_param, f"{dst_ms_dir}/{dst_name}", format="safetensors")
def infer_convert_weight(src_hf_dir, dst_ms_dir, worker_num, ms_meta, arg):
    """Convert all weights for inference, `worker_num` layers at a time.

    The non-layer weights are converted in this process first.  The layers are
    then converted in batches of at most `worker_num` worker processes, and
    each worker's metadata dict is merged into `ms_meta`.

    Args:
        src_hf_dir (str): Directory with the HuggingFace safetensors files.
        dst_ms_dir (str): Output directory.
        worker_num (int): Maximum number of worker processes per batch.
        ms_meta (dict): Receives ``{converted weight name: file name}`` entries.
        arg: Object whose ``param_json`` attribute names the weight index file.

    Raises:
        ValueError: If `worker_num` is smaller than 1.
        RuntimeError: If a worker exits with a non-zero code or without
            reporting its metadata.
    """
    # Local import: only the exception raised by a timed-out Queue.get is needed.
    from queue import Empty

    if worker_num < 1:
        raise ValueError(f"worker_num must be a positive integer, got {worker_num}")

    infer_convert_outer_weight(src_hf_dir, dst_ms_dir, ms_meta, arg.param_json)
    layers = infer_config['total_layer_num']
    for batch_start in range(0, layers, worker_num):
        result_queue = multiprocessing.Queue()
        workers = []
        for layer in range(batch_start, min(batch_start + worker_num, layers)):
            worker = multiprocessing.Process(target=infer_convert_layer_weight,
                                             args=(src_hf_dir, dst_ms_dir, layer, result_queue))
            workers.append((layer, worker))
            worker.start()

        # Collect results BEFORE joining: a child that has put data on a queue
        # does not exit until that data is consumed, so joining first can
        # deadlock, and Queue.empty() is not a reliable "all done" signal.
        pending = len(workers)
        while pending:
            try:
                ms_meta.update(result_queue.get(timeout=1))
                pending -= 1
            except Empty:
                # Stop waiting once every worker is gone (e.g. one crashed).
                if not any(worker.is_alive() for _, worker in workers):
                    break

        for _, worker in workers:
            worker.join()

        # Pick up results that arrived between the last poll and worker exit.
        while pending:
            try:
                ms_meta.update(result_queue.get(timeout=1))
                pending -= 1
            except Empty:
                break

        failed_layers = [layer for layer, worker in workers if worker.exitcode != 0]
        if failed_layers or pending:
            raise RuntimeError(
                f"Layer conversion failed: workers for layers {failed_layers} exited abnormally, "
                f"{pending} result(s) missing.")
def infer_trans_ckpt_pt_to_ms(src_hf_dir, dst_ms_dir, worker_num, arg):
    """Entry point of the inference conversion.

    Creates `dst_ms_dir`, converts all weights and writes the resulting
    weight-name-to-file map to ``param_name_map.json`` in that directory.
    """
    os.makedirs(dst_ms_dir, exist_ok=True)
    name_to_file = {}
    infer_convert_weight(src_hf_dir, dst_ms_dir, worker_num, name_to_file, arg)

    map_path = f"{dst_ms_dir}/param_name_map.json"
    with open(map_path, "w") as map_file:
        json.dump(name_to_file, map_file, indent=4)
    set_safe_mode_for_file_or_dir(map_path)
if __name__ == "__main__":
    # Command-line entry point.  Three modes:
    #   --infer true        : per-layer conversion for the inference model
    #   --pre_ckpt_path P   : transpose routed-expert weights of an existing MindSpore ckpt
    #   otherwise           : HuggingFace -> MindSpore conversion driven by default_config
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_routed_experts', default=256, type=int)
    parser.add_argument('--torch_ckpt_path', default=None, type=str)
    parser.add_argument('--mindspore_ckpt_path', default=None, type=str)
    parser.add_argument('--use_grouped_gemm', default=True, type=str2bool)
    parser.add_argument('--pre_ckpt_path', default=None, type=str)
    parser.add_argument('--dtype', default='bf16', type=str, choices=['fp16', 'bf16', 'fp32'])
    parser.add_argument("--num_layers", default=61, type=int)
    parser.add_argument("--num_nextn_predict_layers", default=1, type=int)
    parser.add_argument("--first_k_dense_replace", default=3, type=int)
    parser.add_argument("--n_head", default=128, type=int)
    parser.add_argument("--qk_nope_head_dim", default=128, type=int)
    parser.add_argument("--qk_rope_head_dim", default=64, type=int)
    parser.add_argument("--v_head_dim", default=128, type=int)
    parser.add_argument("--save_format", default="safetensors", choices=["safetensors", "ckpt"])
    parser.add_argument("--infer", default=False, type=str2bool)
    parser.add_argument('--worker_num', default=4, type=int)
    parser.add_argument('--param_json', default="model.safetensors.index.json", type=str)
    args = parser.parse_args()

    # Both path options default to None; report a missing one as a usage error
    # instead of failing later inside os.path / os.makedirs.
    if args.mindspore_ckpt_path is None:
        parser.error("--mindspore_ckpt_path is required")
    if (args.infer or not args.pre_ckpt_path) and args.torch_ckpt_path is None:
        parser.error("--torch_ckpt_path is required unless --pre_ckpt_path is given without --infer")

    if args.infer:
        ms.set_device(device_target="CPU")
        infer_trans_ckpt_pt_to_ms(src_hf_dir=args.torch_ckpt_path,
                                  dst_ms_dir=args.mindspore_ckpt_path,
                                  worker_num=args.worker_num,
                                  arg=args)
    else:
        if args.pre_ckpt_path:
            convert_ms_to_gmm(input_path=args.pre_ckpt_path, output_path=args.mindspore_ckpt_path)
        else:
            # Overwrite every default with the command-line value of the same name.
            for key in default_config:
                default_config[key] = getattr(args, key, default_config[key])
            # --dtype arrives as a string; translate it to the MindSpore dtype.
            default_config['dtype'] = dtype_map.get(default_config['dtype'], default_config['dtype'])
            convert_pt_to_ms(input_path=args.torch_ckpt_path,
                             output_path=args.mindspore_ckpt_path, config=default_config)