"""
@File : hf_to_mm.py
@Time : 2025/01/14
@Desc    : Convert a Qwen2-VL Hugging Face model into a MindSpeed-MM model
Hugging Face model directory:
Qwen2-VL-7B-Instruct/
├── chat_template.json
├── config.json
├── configuration.json
├── generation_config.json
├── LICENSE
├── merges.txt
├── model-00001-of-00005.safetensors
├── model-00002-of-00005.safetensors
├── model-00003-of-00005.safetensors
├── model-00004-of-00005.safetensors
├── model-00005-of-00005.safetensors
├── model.safetensors.index.json
├── preprocessor_config.json
├── README.md
├── tokenizer_config.json
├── tokenizer.json
└── vocab.json
MindSpeed-MM model directory (here a model saved from tp1/pp4 training):
Qwen2-VL-7B-Instruct/
├── latest_checkpointed_iteration.txt
└── release
├── mp_rank_00_000
│ └── model_optim_rng.pt
├── mp_rank_00_001
│ └── model_optim_rng.pt
├── mp_rank_00_002
│ └── model_optim_rng.pt
└── mp_rank_00_003
└── model_optim_rng.pt
"""
import re
from dataclasses import dataclass
from itertools import accumulate
from pathlib import Path
from typing import Callable, Any, List, Dict, Optional, Union, Tuple
import numpy as np
import torch
from torch.distributed.checkpoint.state_dict_saver import _save_state_dict
from torch.distributed.checkpoint import FileSystemWriter
from safetensors.torch import load_file
from tqdm import tqdm
from checkpoint.common.constant import LATEST_TXT, MEGATRON_CKPT_NAME, IMAGE_ENCODER, AUDIO_ENCODER, TEXT_DECODER
from checkpoint.common.mm_types import STATE_DICT_T, VPP_LAYER_NUM_T
from checkpoint.vlm_model.config import ConvertVppMMConfig, ConvertTorchDCPConfig
from checkpoint.vlm_model.operator import Operator, TieOp, TP_PATTERN_T
# Key under which the checkpoint format version is stored in the dicts saved by this module.
CHECKPOINT_VERSION_KEY = "checkpoint_version"
# Version value this converter writes alongside the model weights.
CHECKPOINT_VERSION_VALUE = 3.0
@dataclass
class PPStageSchema:
    """Weight-name prefixes that decide how one modality is spread over PP stages.

    Different modules (vit / lm / audio) use different weight names, and the
    weights belonging to the first stage, the last stage and the per-layer
    transformer blocks have to be told apart when partitioning.
    """
    # Prefixes of weights kept only on the stage that holds the modality's first layer.
    firsts: List[str]
    # Prefixes of weights kept only on the stage that holds the modality's last layer.
    lasts: List[str]
    # Prefix of the per-layer weights; the layer index follows it directly, and the
    # partitioning code strips its final character (expected to be ".") when renaming.
    middle: str
    # Optional prefixes of weights that are copied to every stage.
    all_layer: Optional[List[str]] = None
# Weight-name layout of the text decoder (LLM).
text_schema = PPStageSchema(
    firsts=["text_decoder.embedding."],
    lasts=[
        "text_decoder.decoder.final_layernorm.",
        "text_decoder.output_layer.",
    ],
    middle="text_decoder.decoder.layers.",
)

# Weight-name layout of the image encoder (ViT plus projector).
vision_schema = PPStageSchema(
    firsts=["image_encoder.encoder.patch_embed."],
    lasts=["image_encoder.projector."],
    middle="image_encoder.encoder.blocks.layers.",
)

# Weight-name layout of the audio encoder.
audio_schema = PPStageSchema(
    firsts=[
        "audio_encoder.encoder.conv",
        "audio_encoder.encoder.audio_bos_eos_token",
    ],
    lasts=[
        "audio_encoder.encoder.proj",
        "audio_encoder.encoder.ln_post",
    ],
    middle="audio_encoder.encoder.blocks.layers.",
)
@dataclass
class PPRange:
    """Layer ranges of one modality across the ranks of a pipeline-parallel (PP) group.

    start: for each PP rank, the index of the first transformer layer it holds (inclusive)
    end: for each PP rank, the index one past the last transformer layer it holds (exclusive)
    first_layer_rank: the PP rank that holds the first transformer layer
    last_layer_rank: the PP rank that holds the last transformer layer
    """
    start: List[int]
    end: List[int]
    first_layer_rank: int
    last_layer_rank: int

    @property
    def pp_size(self) -> int:
        # `start` has one entry per PP rank, so its length is the PP size.
        return len(self.start)
def partition_state_dict_by_pp(state_dict: STATE_DICT_T,
                               pp_ranges: List[PPRange],
                               stages: List[PPStageSchema]) -> List[STATE_DICT_T]:
    """Split a state dict into one dict per pipeline-parallel (PP) rank.

    ``stages`` and ``pp_ranges`` are consumed pairwise, one pair per modality.
    The number of output ranks is the largest ``pp_size`` among ``pp_ranges``;
    a modality with a smaller PP size is replicated by mapping the global rank
    onto it with ``pp_rank % pp_size`` (e.g. ViT PP=1, audio PP=2, LLM PP=4).

    For every rank and modality:
      * weights starting with one of ``firsts`` are kept when the local rank
        equals ``first_layer_rank``;
      * weights starting with one of ``lasts`` are kept when the local rank
        equals ``last_layer_rank``;
      * weights starting with ``middle`` are kept when their layer index lies
        in ``[start, end)`` of the local rank, and the index is renumbered so
        that the rank's first layer becomes 0;
      * weights starting with one of ``all_layer`` are always kept.

    Args:
        state_dict: mapping of weight name to value.
        pp_ranges: per-modality layer ranges, aligned with ``stages``.
        stages: per-modality weight-name prefixes.

    Returns:
        A list with one state dict per PP rank.

    Raises:
        ValueError: if the name component following a ``middle`` prefix is not
            an integer layer index.
    """
    global_pp_size = max(pp_range.pp_size for pp_range in pp_ranges)
    pp_weights = []
    for pp_rank in range(global_pp_size):
        # Placement depends only on (rank, modality), so resolve it once per
        # rank instead of once per weight.
        placements = []
        for stage, pp_range in zip(stages, pp_ranges):
            # start/end/first_layer_rank/last_layer_rank are all indexed by the
            # modality-local rank, so the same index must be used for all of
            # them; the modulo replicates modalities with a smaller PP size.
            local_pp_idx = pp_rank % pp_range.pp_size
            edge_prefixes = []
            if local_pp_idx == pp_range.first_layer_rank:
                edge_prefixes.extend(stage.firsts)
            if local_pp_idx == pp_range.last_layer_rank:
                edge_prefixes.extend(stage.lasts)
            placements.append((stage, pp_range, tuple(edge_prefixes),
                               pp_range.start[local_pp_idx], pp_range.end[local_pp_idx]))

        pp_weight = {}
        for weight_name, weight_value in state_dict.items():
            for stage, pp_range, edge_prefixes, layer_start, layer_end in placements:
                # str.startswith with an empty tuple is simply False.
                if weight_name.startswith(edge_prefixes):
                    pp_weight[weight_name] = weight_value
                if weight_name.startswith(stage.middle):
                    # Strip only the leading prefix; the next component is the layer index.
                    raw_layer_num, *remains = weight_name[len(stage.middle):].split(".")
                    try:
                        layer_num = int(raw_layer_num)
                    except ValueError as e:
                        raise ValueError(
                            f"Failed to parse layer number from weight name: '{weight_name}'\n"
                            f"Modality: {stage}, PP range: {pp_range}\n"
                            f"Original error: {str(e)}"
                        ) from e
                    if layer_start <= layer_num < layer_end:
                        # Renumber so that layers on each rank start from 0.
                        new_weight_name = ".".join(
                            [stage.middle[:-1], str(layer_num - layer_start), *remains])
                        pp_weight[new_weight_name] = weight_value
                if stage.all_layer and weight_name.startswith(tuple(stage.all_layer)):
                    pp_weight[weight_name] = weight_value
        pp_weights.append(pp_weight)
    return pp_weights
def save_by_vpp(state_dicts: List[Dict[str, torch.Tensor]],
                save_root_dir: Path,
                iteration: Optional[Union[str, int]] = 'release',
                pp_and_vpp_size: Tuple[int, int] = (1, 1),
                ep_size: int = 1,
                tp_rank: int = 0,
                ep_rank: int = 0):
    """Write one checkpoint file per pipeline rank for a given tp/ep rank.

    Each file goes to ``<save_root_dir>/<iter>/<rank_dir>/MEGATRON_CKPT_NAME``:
      * ``<iter>`` is ``iteration`` itself when it is a string, otherwise
        ``iter_`` followed by the number zero-padded to 7 digits;
      * ``<rank_dir>`` is ``mp_rank_<tp:02d>``, followed by ``_<pp:03d>`` when
        ``pp_size > 1`` and ``_<ep:03d>`` when ``ep_size > 1``.

    With ``vpp_size > 1`` the saved dict holds ``model0``, ``model1``, ...
    taken from ``state_dicts[vpp_idx * pp_size + pp_rank]``; otherwise it holds
    a single ``model`` entry. The checkpoint version is stored alongside, and
    ``str(iteration)`` is finally written to the ``LATEST_TXT`` file.

    Args:
        state_dicts: per-stage state dicts, ordered vpp-major / pp-minor.
        save_root_dir: root directory of the converted checkpoint.
        iteration: iteration tag (string) or iteration number.
        pp_and_vpp_size: ``(pp_size, vpp_size)``.
        ep_size: expert-parallel size.
        tp_rank: tensor-parallel rank being saved.
        ep_rank: expert-parallel rank being saved.
    """
    pp_size, vpp_size = pp_and_vpp_size
    for pp_rank in tqdm(range(pp_size), desc="pp step"):
        rank_dir = f"mp_rank_{tp_rank:02d}"
        if pp_size > 1:
            rank_dir += f"_{pp_rank:03d}"
        if ep_size > 1:
            rank_dir += f"_{ep_rank:03d}"

        if isinstance(iteration, str):
            iter_dir = iteration
        else:
            iter_dir = f"iter_{iteration:07d}"

        target_dir = save_root_dir / iter_dir / rank_dir
        target_dir.mkdir(parents=True, exist_ok=True)

        if vpp_size > 1:
            # One entry per virtual pipeline chunk living on this pp rank.
            payload = {}
            for vpp_idx in range(vpp_size):
                payload[f'model{vpp_idx}'] = state_dicts[vpp_idx * pp_size + pp_rank]
        else:
            payload = {'model': state_dicts[pp_rank]}
        payload[CHECKPOINT_VERSION_KEY] = CHECKPOINT_VERSION_VALUE
        torch.save(payload, target_dir / MEGATRON_CKPT_NAME)

    (save_root_dir / LATEST_TXT).write_text(str(iteration))
def save_by_dcp(state_dict: STATE_DICT_T,
                save_root_dir: Path,
                iteration: Union[str, int] = 'release'):
    """Save a state dict with torch distributed checkpoint (DCP) under ``save_root_dir``.

    The checkpoint is written to ``<save_root_dir>/<iter>``, where ``<iter>``
    is ``iteration`` itself when it is a string, otherwise ``iter_`` followed
    by the number zero-padded to 7 digits. ``str(iteration)`` is then written
    to the ``LATEST_TXT`` file in ``save_root_dir``.

    Args:
        state_dict: the model weights to save.
        save_root_dir: root directory of the converted checkpoint.
        iteration: iteration tag (string) or iteration number.
    """
    if isinstance(iteration, str):
        iter_dir = iteration
    else:
        iter_dir = f"iter_{iteration:07d}"
    target_dir = save_root_dir / iter_dir
    target_dir.mkdir(parents=True, exist_ok=True)

    payload = {'model': state_dict, 'checkpoint_version': CHECKPOINT_VERSION_VALUE}
    writer = FileSystemWriter(target_dir)
    # NOTE(review): no_dist=True presumably lets the save run without an
    # initialised process group -- confirm against the torch DCP documentation.
    _save_state_dict(payload, storage_writer=writer, no_dist=True)

    (save_root_dir / LATEST_TXT).write_text(str(iteration))
def split_by_tp(state_dict: STATE_DICT_T, patterns: TP_PATTERN_T, tp_size: int = 0, vit_tp_size: int = 0,
                audio_tp_size: int = 0) -> List[STATE_DICT_T]:
    """Split a state dict into per-rank dicts according to tensor parallelism (TP).

    The number of returned dicts is the largest of the three TP sizes. A weight
    split into fewer pieces than that is repeated cyclically over the ranks,
    and weights that match no pattern are cloned to every rank.

    :param state_dict: the original state dict
    :param patterns: mapping from a regex (matched against the start of the
        weight name) to a splitter object exposing ``split(tp_size, value)``
    :param tp_size: default TP size
    :param vit_tp_size: TP size of the image encoder (0 means "not set")
    :param audio_tp_size: TP size of the audio encoder (0 means "not set")
    :return: list of per-rank state dicts
    :raises ValueError: if a non-zero TP size does not divide the largest one
    """
    if tp_size == 1 and vit_tp_size <= 1 and audio_tp_size <= 1:
        return [state_dict.copy()]

    all_sizes = (tp_size, vit_tp_size, audio_tp_size)
    widest = max(all_sizes)
    for size in all_sizes:
        if size != 0 and widest % size != 0:
            raise ValueError('TP segmentation of multiple modules does not meet the requirements')

    shards = [{} for _ in range(widest)]

    def pick_tp_size(name):
        """Return the TP size to split ``name`` with, or None if no rule applies."""
        if vit_tp_size != 0 and name.startswith(IMAGE_ENCODER):
            return vit_tp_size
        if audio_tp_size != 0 and name.startswith(AUDIO_ENCODER):
            return audio_tp_size
        if vit_tp_size == 0 or audio_tp_size == 0 or name.startswith(TEXT_DECODER):
            return tp_size
        return None

    for name, tensor in state_dict.items():
        handled = False
        for pattern, splitter in patterns.items():
            if not re.match(pattern, name):
                continue
            chosen_size = pick_tp_size(name)
            if chosen_size is None:
                # No rule for this weight: keep looking at the remaining patterns.
                continue
            pieces = splitter.split(chosen_size, tensor)
            for rank, shard in enumerate(shards):
                # Fewer pieces than ranks -> pieces are reused cyclically.
                shard[name] = pieces[rank % len(pieces)].clone()
            handled = True
            break
        if not handled:
            # Unsplit weights are replicated on every rank.
            for shard in shards:
                shard[name] = tensor.clone()
    return shards
def split_by_ep(_state_dict: STATE_DICT_T, _ep_size: int = 1, _num_experts: int = 0) -> List[Dict[str, torch.Tensor]]:
    """Split a state dict into one dict per expert-parallel (EP) rank.

    Keys containing ``local_experts`` are expert weights: the integer component
    right after ``local_experts`` is the global expert index. Each EP rank owns
    a contiguous block of ``_num_experts // _ep_size`` experts, and the index is
    rewritten to be local to that block. All other weights are shared by every
    rank. Expert indices outside ``[0, _num_experts)`` are dropped.

    Args:
        _state_dict: mapping of weight name to value.
        _ep_size: expert-parallel size.
        _num_experts: total number of experts (0 means the model has none).

    Returns:
        ``[_state_dict]`` unchanged when ``_ep_size == 1`` or
        ``_num_experts == 0``; otherwise a list of ``_ep_size`` new dicts.

    Raises:
        ValueError: if ``_num_experts`` is not divisible by ``_ep_size`` (the
            trailing experts would otherwise be silently lost), or if the
            expert index in a key is not an integer.
    """
    if _ep_size == 1 or _num_experts == 0:
        return [_state_dict]
    if _num_experts % _ep_size != 0:
        raise ValueError(
            f"num_experts ({_num_experts}) must be divisible by ep_size ({_ep_size})"
        )

    per_ep_rank_experts = _num_experts // _ep_size
    ep_state_dicts = [{} for _ in range(_ep_size)]
    for key, value in _state_dict.items():
        if "local_experts" not in key:
            # Non-expert weights are shared by every EP rank.
            for ep_state_dict in ep_state_dicts:
                ep_state_dict[key] = value
            continue

        parts = key.split(".")
        try:
            # The expert index is the component right after "local_experts".
            expert_pos = parts.index("local_experts") + 1
        except ValueError:
            # "local_experts" is only a substring of some component:
            # fall back to the historical fixed position.
            expert_pos = 7
        expert_idx = int(parts[expert_pos])

        owner_rank, local_expert_idx = divmod(expert_idx, per_ep_rank_experts)
        if 0 <= owner_rank < _ep_size:
            parts[expert_pos] = str(local_expert_idx)
            ep_state_dicts[owner_rank][".".join(parts)] = value
    return ep_state_dicts
def merge_llm_weights_to_state_dict(vl_state_dict: STATE_DICT_T, llm_state_dict: STATE_DICT_T) -> STATE_DICT_T:
    """Replace the LLM part of a vision-language state dict with other LLM weights.

    Every key of ``vl_state_dict`` starting with ``model`` or ``visual.merger``
    is removed, then all entries of ``llm_state_dict`` are added.
    ``vl_state_dict`` is modified in place and also returned.
    """
    replaced_prefixes = ('model', "visual.merger")
    stale_keys = [name for name in vl_state_dict if name.startswith(replaced_prefixes)]
    for name in stale_keys:
        del vl_state_dict[name]
    vl_state_dict.update(llm_state_dict)
    return vl_state_dict
def filter_vit_keys(_state_dict: STATE_DICT_T):
    """Drop the LLM-related entries in place, keeping only keys that start with ``visual``."""
    non_vit_keys = [name for name in _state_dict if not name.startswith("visual")]
    for name in non_vit_keys:
        del _state_dict[name]
def load_from_hf(hf_dir: Path, pt_path: Optional[Path] = None) -> STATE_DICT_T:
    """Load model weights onto the CPU.

    Args:
        hf_dir: directory whose ``*.safetensors`` files are loaded and merged
            when ``pt_path`` is not given.
        pt_path: optional file saved with ``torch.save``; when given it is
            loaded instead of the safetensors files.

    Returns:
        A new dict mapping weight names to tensors.
    """
    state_dict = {}
    if pt_path:
        # NOTE(security): torch.load may unpickle arbitrary objects; only load
        # files from a trusted source.
        # map_location keeps the tensors on the CPU. Passing device='cpu' to
        # dict.update() would instead add a bogus 'device' entry to the weights.
        state_dict.update(torch.load(pt_path, map_location='cpu'))
    else:
        # Sorted so the shards are merged in a deterministic order.
        for safe_path in sorted(hf_dir.glob("*.safetensors")):
            state_dict.update(load_file(str(safe_path), device='cpu'))
    return state_dict
def merge_pp_index(vit_pipeline_num_layers: List[int], llm_pipeline_num_layers: List[int]) -> List[Tuple[int, int]]:
    """Return, for each pipeline rank, the pair ``(vit layer count, llm layer count)``.

    The two lists are paired positionally; the result is as long as the shorter one.
    """
    return list(zip(vit_pipeline_num_layers, llm_pipeline_num_layers))
def merge_vpp_index(vit_pipeline_num_layers: VPP_LAYER_NUM_T,
                    llm_pipeline_num_layers: VPP_LAYER_NUM_T,
                    audio_pipeline_num_layers: VPP_LAYER_NUM_T) -> List[PPRange]:
    """Build the per-rank layer ranges of each modality from its layer counts.

    Each argument is a nested list of layer counts that is flattened in order.
    The running sum of the flattened counts gives every entry its ``[start, end)``
    layer range, and the first / last entries with a non-zero count become
    ``first_layer_rank`` / ``last_layer_rank`` (stored as plain ``int``).

    A modality whose flattened list is empty is skipped, so the result may hold
    fewer than three ranges; the order is vit, llm, audio.

    Raises:
        ValueError: if a modality lists stages but every stage has zero layers.
    """
    modalities_pp_range = []
    for modality in (vit_pipeline_num_layers, llm_pipeline_num_layers, audio_pipeline_num_layers):
        layer_counts = [count for stage_counts in modality for count in stage_counts]
        if not layer_counts:
            continue
        occupied = [idx for idx, count in enumerate(layer_counts) if count != 0]
        if not occupied:
            raise ValueError(
                f"Pipeline layer configuration {modality} assigns no layer to any stage"
            )
        layer_ends = list(accumulate(layer_counts))
        modalities_pp_range.append(PPRange(start=[0] + layer_ends[:-1],
                                           end=layer_ends,
                                           first_layer_rank=occupied[0],
                                           last_layer_rank=occupied[-1]))
    return modalities_pp_range
def convert(state_dict: STATE_DICT_T, ops: List[Operator], is_tie: bool, is_pp: bool) -> STATE_DICT_T:
    """Apply the conversion operators to ``state_dict`` in order.

    When ``is_tie`` and ``is_pp`` are both true, a ``TieOp`` from
    ``text_decoder.embedding.word_embeddings.weight`` to
    ``text_decoder.output_layer.weight`` is applied after the given operators.
    NOTE(review): presumably this materialises the tied output layer because
    the embedding and the output layer end up on different pipeline stages --
    confirm against the TieOp implementation.

    Args:
        state_dict: weights to convert; modified by the operators.
        ops: operators to apply, each exposing ``apply(state_dict)``. The list
            itself is left untouched.
        is_tie: whether the model ties its word embeddings.
        is_pp: whether pipeline parallelism is used.

    Returns:
        The same ``state_dict`` object after all operators were applied.
    """
    # Work on a copy so repeated calls with the same list do not pile up TieOps
    # in the caller's operator list.
    pending_ops = list(ops)
    if is_tie and is_pp:
        pending_ops.append(TieOp(raw_name='text_decoder.embedding.word_embeddings.weight',
                                 new_name='text_decoder.output_layer.weight'))
    for op in pending_ops:
        op.apply(state_dict)
    return state_dict
def convert_hf_to_mm(convert_config: ConvertVppMMConfig, ops: List[Operator], tp_patterns: Dict[str, Callable],
                     stages: List[PPStageSchema]):
    """Convert Hugging Face weights and save them split by ep / tp / pp(vpp).

    Steps:
      1. load the weights (from ``pt_path`` when the config has one, otherwise
         from the safetensors files of the HF directory);
      2. optionally replace the LLM weights with those from ``llm_hf_dir``;
      3. optionally keep only the ViT weights (``save_vit_only``);
      4. apply the conversion operators;
      5. optionally keep only weights whose name contains ``lora``
         (``save_lora_only``);
      6. split by expert parallelism, then tensor parallelism, then pipeline
         stage, and save every resulting shard.

    Args:
        convert_config: conversion configuration (paths and parallel settings).
        ops: conversion operators applied to the loaded weights.
        tp_patterns: regex -> splitter mapping used for the TP split.
        stages: per-modality weight-name prefixes used for the PP split.
    """
    pt_path = getattr(convert_config, 'pt_path', None)
    parallel_cfg = convert_config.parallel_config
    model_cfg = convert_config.common_model_config
    num_experts = model_cfg.num_experts

    weights = load_from_hf(convert_config.hf_config.hf_dir, pt_path)
    if model_cfg.llm_hf_dir is not None:
        weights = merge_llm_weights_to_state_dict(weights, load_from_hf(model_cfg.llm_hf_dir))
    if convert_config.save_vit_only:
        filter_vit_keys(weights)
    weights = convert(weights, ops, model_cfg.tie_word_embeddings, parallel_cfg.is_pp())
    if getattr(convert_config, 'save_lora_only', False):
        weights = {name: tensor for name, tensor in weights.items() if "lora" in name}

    # Outer list: one entry per ep rank; inner list: one state dict per tp rank.
    ep_tp_shards = [
        split_by_tp(ep_shard, tp_patterns, parallel_cfg.tp_size, parallel_cfg.vit_tp_size,
                    parallel_cfg.audio_tp_size)
        for ep_shard in split_by_ep(weights, parallel_cfg.ep_size, _num_experts=num_experts)
    ]

    # A missing audio configuration is passed as one empty stage list.
    pp_ranges = merge_vpp_index(parallel_cfg.vit_pp_layers,
                                parallel_cfg.llm_pp_layers,
                                parallel_cfg.audio_pp_layers or [[]])

    for ep_rank, tp_shards in enumerate(tqdm(ep_tp_shards, desc="ep step")):
        for tp_rank, tp_shard in enumerate(tqdm(tp_shards, desc="tp step")):
            stage_shards = partition_state_dict_by_pp(tp_shard, pp_ranges, stages)
            save_by_vpp(stage_shards, convert_config.mm_dir,
                        pp_and_vpp_size=(parallel_cfg.pp_size, parallel_cfg.vpp_size),
                        ep_size=parallel_cfg.ep_size, ep_rank=ep_rank, tp_rank=tp_rank)
def convert_hf_to_mm_dcp(convert_config: ConvertTorchDCPConfig, ops: List[Operator]):
    """Convert Hugging Face weights and save them as a torch DCP checkpoint.

    The weights are loaded from the HF directory, the LLM part is optionally
    replaced by the weights found in ``llm_hf_dir``, the conversion operators
    are applied (without pipeline parallelism) and the result is saved under
    ``convert_config.mm_dir``.

    Args:
        convert_config: conversion configuration (paths and model settings).
        ops: conversion operators applied to the loaded weights.
    """
    model_cfg = convert_config.common_model_config
    weights = load_from_hf(convert_config.hf_config.hf_dir)
    if model_cfg.llm_hf_dir is not None:
        weights = merge_llm_weights_to_state_dict(weights, load_from_hf(model_cfg.llm_hf_dir))
    weights = convert(weights, ops, model_cfg.tie_word_embeddings, is_pp=False)
    save_by_dcp(weights, convert_config.mm_dir)