"""
Convert QwenVL weight.
Support huggingface format.
"""
import argparse
import os
import torch
import numpy as np
from transformers import AutoModelForCausalLM
import mindspore as ms
from mindformers.tools import str2bool
QWEN_ATTENTION_WEIGHT_NAME = "attn.c_attn.weight"
QWEN_ATTENTION_BIAS_NAME = "attn.c_attn.bias"
QWEN_EMBEDDING_WEIGHT_NAME = "wte.embedding_weight"
QWEN_LM_HEAD_WEIGHT_NAME = "lm_head.weight"
def _qwen_name_replace(name: str):
"""replace huggingface parameter name to mindformers."""
name = name.replace(".h.", ".layers.")
name = name.replace(".wte.weight", ".wte.embedding_weight")
name = name.replace("attn.c_proj.", "attention.wo.")
name = name.replace("ln_1.", "attention_norm.")
name = name.replace("ln_2.", "ffn_norm.")
name = name.replace("mlp.w1.", "feed_forward.w1.")
name = name.replace("mlp.w2.", "feed_forward.w3.")
name = name.replace("mlp.c_proj.", "feed_forward.w2.")
name = "llm_model." + name
return name
def convert_qwen_attention_weight(name, value, ckpt_weights, dtype=ms.float16, use_qkv_concat=False):
"""convert attention weight of qwen"""
if not use_qkv_concat:
split_arr = np.array_split(value, 3)
attention_weight_names = ["attention.wq.weight", "attention.wk.weight", "attention.wv.weight"]
for index, split_arr_item in enumerate(split_arr):
cur_name = name.replace(QWEN_ATTENTION_WEIGHT_NAME, attention_weight_names[index])
ckpt_weights.append({"name": cur_name, "data": ms.Tensor(split_arr_item, dtype=dtype)})
else:
cur_name = name.replace(QWEN_ATTENTION_WEIGHT_NAME, "attention.w.weight")
ckpt_weights.append({"name": cur_name, "data": ms.Tensor(value, dtype=dtype)})
def convert_qwen_attention_bias(name, value, ckpt_weights, dtype=ms.float16, use_qkv_concat=False):
"""convert attention bias of qwen"""
if not use_qkv_concat:
split_arr = np.array_split(value, 3)
attention_bias_names = ["attention.wq.bias", "attention.wk.bias", "attention.wv.bias"]
for index, split_arr_item in enumerate(split_arr):
cur_name = name.replace(QWEN_ATTENTION_BIAS_NAME, attention_bias_names[index])
ckpt_weights.append({"name": cur_name, "data": ms.Tensor(split_arr_item, dtype=dtype)})
else:
cur_name = name.replace(QWEN_ATTENTION_BIAS_NAME, "attention.w.bias")
ckpt_weights.append({"name": cur_name, "data": ms.Tensor(value, dtype=dtype)})
def convert_vit_resampler_attention(name, value, ckpt_weights, dtype=ms.float16):
"""convert attention bias of vit resampler"""
if "ln" in name:
ms_name = name.replace("weight", "gamma")
ms_name = ms_name.replace("bias", "beta")
elif "in_proj_weight" in name:
value = np.array_split(value, 3)
ms_name = [name.replace("in_proj_weight", f"dense{i}.weight") for i in (1, 2, 3)]
elif "in_proj_bias" in name:
value = np.array_split(value, 3)
ms_name = [name.replace("in_proj_bias", f"dense{i}.bias") for i in (1, 2, 3)]
elif "out_proj" in name:
if "weight" in name:
value = np.transpose(value, (1, 0))
ms_name = name.replace("out_proj", "projection")
else:
ms_name = name
if not isinstance(ms_name, (tuple, list)):
cur_ms_name = (ms_name,)
else:
cur_ms_name = ms_name
if not isinstance(value, (tuple, list)):
cur_value = (value,)
else:
cur_value = value
for n, p in zip(cur_ms_name, cur_value):
if n != name:
print(f"name: {name}->{n}")
ckpt_weights.append({"name": n, "data": ms.Tensor(p, dtype=dtype)})
def convert_vit_transformer_attn(name, value, ckpt_weights, dtype=ms.float16, vit_num_head=16):
"""convert attention bias of vit transformer"""
if "in_proj" in name:
if value.shape[0] % (3 * vit_num_head) != 0:
raise ValueError(f"The 3 * vit_num_head({3 * vit_num_head}) must be divisible "
f"by value.shape[0]({value.shape[0]}).")
value = np.array_split(value, 3 * vit_num_head)
if "weight" in name:
value = [np.vstack(value[i::3]) for i in range(3)]
else:
value = [np.concatenate(value[i::3]) for i in range(3)]
ms_name = [name.replace("in_proj", f"dense{i}") for i in (1, 2, 3)]
elif "out_proj" in name:
if "weight" in name:
value = np.transpose(value, (1, 0))
ms_name = name.replace("out_proj", "projection")
else:
ms_name = name
if not isinstance(ms_name, (tuple, list)):
cur_ms_name = (ms_name,)
else:
cur_ms_name = ms_name
if not isinstance(value, (tuple, list)):
cur_value = (value,)
else:
cur_value = value
for n, p in zip(cur_ms_name, cur_value):
if n != name:
print(f"name: {name}->{n}")
ckpt_weights.append({"name": n, "data": ms.Tensor(p, dtype=dtype)})
def _vit_name_replace(name: str):
"""replace vit module name"""
if "ln" in name:
name = name.replace("weight", "gamma")
name = name.replace("bias", "beta")
elif "in_proj_weight" in name:
name = name.replace("in_proj_weight", "in_proj.weight")
elif "in_proj_bias" in name:
name = name.replace("in_proj_bias", "in_proj.bias")
elif "token_embedding.weight" in name:
name = name.replace("token_embedding.weight", "token_embedding.embedding_table")
return name
def convert_vit_weight(name, value, ckpt_weights, dtype, vit_num_head):
"""convert vit weights"""
name = name.replace("transformer.visual.", "")
name = "vision_encoder." + name
if "attn_pool" in name:
convert_vit_resampler_attention(name, value, ckpt_weights, dtype)
elif "attn" in name:
convert_vit_transformer_attn(name, value, ckpt_weights, dtype, vit_num_head)
else:
ms_name = _vit_name_replace(name)
if ms_name != name:
print(f"name: {name}->{ms_name}")
ckpt_weights.append({"name": ms_name, "data": ms.Tensor(value, dtype=dtype)})
def convert_pt_to_ms(input_path, output_path, torch_dtype=torch.float16, dtype=ms.float16, **kwargs):
"""Convert huggingface weights files to mindspore ckpt format."""
vit_num_head = kwargs.get("vit_num_head", 16)
enable_emb_opt = kwargs.get("enable_emb_opt", False)
emb_strategy = kwargs.get("emb_strategy", None)
use_qkv_concat = kwargs.get("use_qkv_concat", False)
model = AutoModelForCausalLM.from_pretrained(input_path, device_map="cpu")
ckpt_weights = []
for name, param in model.named_parameters():
print(f"Parameter (name={name}, shape={param.shape}, dtype={param.dtype}, requires_grad={param.requires_grad})")
try:
value = param.detach().numpy()
except TypeError:
value = param.detach().to(torch_dtype).cpu().numpy()
print(f"dtype: {param.dtype}->{value.dtype}")
if "visual" in name:
convert_vit_weight(name, value, ckpt_weights, dtype, vit_num_head)
else:
ms_name = _qwen_name_replace(name)
if ms_name != name:
print(f"name: {name}->{ms_name}")
if enable_emb_opt:
if emb_strategy is None:
raise ValueError("num_cards should be set when enable_emb_opt is True")
if QWEN_EMBEDDING_WEIGHT_NAME in ms_name or QWEN_LM_HEAD_WEIGHT_NAME in ms_name:
vocab_size = value.shape[0]
mult, remainder = divmod(vocab_size, 512 * emb_strategy)
if remainder > 0:
mult += 1
new_vocab_size = mult * 512 * emb_strategy
pad = np.zeros((new_vocab_size - vocab_size, value.shape[1]), dtype=value.dtype)
value = np.concatenate([value, pad], axis=0)
print(f"{ms_name} shape from {vocab_size} to {new_vocab_size}")
ckpt_weights.append({"name": ms_name, "data": ms.Tensor(value, dtype=dtype)})
continue
if QWEN_ATTENTION_WEIGHT_NAME in ms_name:
convert_qwen_attention_weight(ms_name, value, ckpt_weights, dtype, use_qkv_concat)
continue
if QWEN_ATTENTION_BIAS_NAME in ms_name:
convert_qwen_attention_bias(ms_name, value, ckpt_weights, dtype, use_qkv_concat)
continue
ckpt_weights.append({"name": ms_name, "data": ms.Tensor(value, dtype=dtype)})
print(f"Saving converted weights to {output_path}...")
ms.save_checkpoint(ckpt_weights, output_path)
print("Done")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Qwen convert script")
parser.add_argument("--torch_ckpt_dir",
required=True,
help="The torch checkpoint path.")
parser.add_argument("--mindspore_ckpt_path",
default="./run/qwen_7b_ms.ckpt",
help="The output checkpoint path.")
parser.add_argument("--dtype",
default="float16",
choices=["float16", "float32", "bfloat16"],
help="The data type of the converted weight.")
parser.add_argument("--enable_emb_opt",
default=None,
type=str2bool,
help="Enable embedding optimization.")
parser.add_argument("--emb_strategy",
default=None,
type=int,
help="parallel strategy for embedding")
parser.add_argument("--vit_num_head",
default=16,
type=int,
help="The number of head in ViT.")
parser.add_argument("--use_qkv_concat",
default=None,
type=str2bool,
help="Whether to use qkv concat in attention weight.")
args = parser.parse_args()
mindspore_ckpt_path = os.path.expanduser(args.mindspore_ckpt_path)
output_dir = os.path.dirname(mindspore_ckpt_path)
if output_dir and not os.path.exists(output_dir):
os.mkdir(output_dir)
mapping = {
"float16": ms.float16,
"float32": ms.float32,
"bfloat16": ms.bfloat16
}
convert_pt_to_ms(input_path=args.torch_ckpt_dir, output_path=mindspore_ckpt_path, torch_dtype=torch.float32,
dtype=mapping.get(args.dtype), vit_num_head=args.vit_num_head, enable_emb_opt=args.enable_emb_opt,
emb_strategy=args.emb_strategy, use_qkv_concat=args.use_qkv_concat)