import json
import os
import re
import shutil
from pathlib import Path
from typing import List
import torch
from safetensors.torch import save_file
from checkpoint.common.constant import LATEST_TXT, MEGATRON_CKPT_NAME, IMAGE_ENCODER, AUDIO_ENCODER, TEXT_DECODER, \
LORA_CKPT_NAME
from checkpoint.common.mm_types import STATE_DICT_T, PP_LAYER_NUM_T
from checkpoint.vlm_model.config import ConvertHFConfig
from checkpoint.vlm_model.hf_to_mm import load_from_hf
from checkpoint.vlm_model.operator import Operator, TP_PATTERN_T
def save_by_index_json(_state_dicts, _save_dir):
    """Write each shard as ``model-XXXXX-of-YYYYY.safetensors`` into ``_save_dir``.

    Shards are numbered from 1 in the order they appear in ``_state_dicts``.
    ``YYYYY`` is the total number of shards, which mirrors the HF sharded-checkpoint naming.

    Args:
        _state_dicts: sequence of ``{name: tensor}`` mappings, one per shard.
        _save_dir: output directory (``str`` or ``Path``). It is created if missing,
            consistent with ``save_safetensors``.
    """
    save_dir = Path(_save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    metadata = {'format': 'pt'}
    total = len(_state_dicts)
    for index, state_dict in enumerate(_state_dicts, start=1):
        name = f'model-{index:05}-of-{total:05}.safetensors'
        save_file(state_dict, save_dir.joinpath(name), metadata=metadata)
def save_safetensors(_state_dicts, _save_dir):
    """Save one state dict as ``LORA_CKPT_NAME`` inside ``_save_dir``.

    Despite the plural parameter name, ``_state_dicts`` is passed as a single
    mapping straight to safetensors' ``save_file``. The output directory
    (including missing parents) is created first.
    """
    target_dir = Path(_save_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    save_file(_state_dicts, target_dir / LORA_CKPT_NAME, metadata={'format': 'pt'})
def split_by_index_json(state_dict: STATE_DICT_T, hf_dir: Path) -> List[STATE_DICT_T]:
    """Split a flat state dict into shards mirroring the original HF checkpoint layout.

    Reads ``model.safetensors.index.json`` from ``hf_dir``. Every key of its
    ``weight_map`` is placed into the shard named there, so that element ``i`` of
    the result holds the tensors of shard number ``i + 1``.

    Args:
        state_dict: merged weights. It must contain every key listed in the weight_map.
            Keys that are not listed there are not written to any shard.
        hf_dir: directory of the original HF checkpoint (``str`` or ``Path``).

    Returns:
        A list of per-shard state dicts ordered by shard number. A shard number that
        never appears in the weight_map (but is below the highest one seen) yields an
        empty dict at that position.

    Raises:
        ValueError: if the index file is missing or a shard number is smaller than 1.
        KeyError: if a weight_map key is absent from ``state_dict``.
    """
    index_json_path = Path(hf_dir).joinpath('model.safetensors.index.json')
    if not index_json_path.exists():
        raise ValueError(f"safetensors.index.json not in {index_json_path}")
    weight_map = json.loads(index_json_path.read_text(encoding='utf-8')).get('weight_map', {})

    return_dicts = []
    for key, shard_name in weight_map.items():
        # Shard files are named like "model-00001-of-00004.safetensors". Look for the
        # "<n>-of-<total>" part so prefixes containing dashes still parse, and fall back
        # to the legacy "second dash-separated field" rule otherwise.
        match = re.search(r'(\d+)-of-\d+', shard_name)
        index = int(match.group(1)) if match else int(shard_name.split('-')[1])
        if index < 1:
            # Index 0 would silently alias the last shard through list index -1.
            raise ValueError(f"Invalid shard number {index} in shard file name: {shard_name}")
        while index > len(return_dicts):
            return_dicts.append({})
        return_dicts[index - 1][key] = state_dict[key]
    return return_dicts
def copy_files_except_suffix(source_path: Path, target_path: Path, except_suffix: str = '.safetensors'):
    """Copy every file under ``source_path`` except those ending in ``except_suffix``.

    Sub-directories are included, and each file keeps its path relative to
    ``source_path`` under ``target_path``. ``target_path`` is created if missing.
    Files that already live inside ``target_path`` are skipped, so a target nested
    in (or equal to) the source is never copied into itself.

    Args:
        source_path: directory to copy from (``str`` or ``Path``).
        target_path: directory to copy into (``str`` or ``Path``).
        except_suffix: file suffix to leave out, compared against ``Path.suffix``
            (i.e. the last extension only).
    """
    source_path = Path(source_path)
    target_path = Path(target_path)
    target_path.mkdir(parents=True, exist_ok=True)
    resolved_target = target_path.resolve()

    def inside_target(path: Path) -> bool:
        # Resolve the parent directory rather than the file itself, so that a
        # symlinked file is judged by where it sits, not by where it points.
        parent = path.parent.resolve()
        return parent == resolved_target or resolved_target in parent.parents

    # Snapshot the listing before copying. Walking lazily while writing into a target
    # nested in the source would pick up the freshly copied files and copy them again.
    pending = [item for item in source_path.rglob('*')
               if item.is_file() and item.suffix != except_suffix and not inside_target(item)]
    for item in pending:
        relative_path = item.relative_to(source_path)
        destination = target_path / relative_path
        destination.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(item, destination)
        print(f"Copied: {item} -> {destination}")
def _mm_rank_dir(save_dir: Path, tp_rank: int, pp_rank: int, ep_rank: int,
                 pp_size: int, ep_size: int) -> Path:
    """Return the ``mp_rank_*`` sub-directory of ``save_dir`` for one (tp, pp, ep) rank.

    The pp and ep components appear in the name only when the corresponding
    parallel size is greater than one.
    """
    name = f"mp_rank_{tp_rank:02}"
    if pp_size > 1:
        name += f"_{pp_rank:03}"
    if ep_size > 1:
        name += f"_{ep_rank:03}"
    return save_dir.joinpath(name)


def _load_mm_model_weights(pt_path: Path) -> dict:
    """Load the ``'model'`` entry of one checkpoint file onto the CPU."""
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects;
    # only point this converter at checkpoints from a trusted source.
    return torch.load(pt_path, map_location='cpu', weights_only=False)['model']


def load_from_mm(load_dir: Path,
                 vit_pp_list: PP_LAYER_NUM_T,
                 llm_pp_list: PP_LAYER_NUM_T,
                 tp_size: int = 1,
                 audio_pp_list: PP_LAYER_NUM_T = None,
                 ep_size: int = 1,
                 num_experts: int = 1) -> List[STATE_DICT_T]:
    """Load a MindSpeed-MM checkpoint and return one state dict per TP rank.

    The iteration directory is taken from the ``LATEST_TXT`` tracker in ``load_dir``:
    ``"release"`` maps to ``release/`` and a number ``N`` maps to ``iter_{N:07}/``.
    For every TP rank, all pipeline stages (and, when ``ep_size > 1``, all expert
    ranks) are loaded and combined into one dict. Layer numbers are rewritten from
    stage-local to global with ``rename_pp_parameter``, and expert ids from rank-local
    to global with ``rename_pp_ep_parameter``. Entries whose tensor is ``None`` are
    dropped.

    Args:
        load_dir: checkpoint root directory (``str`` or ``Path``).
        vit_pp_list / llm_pp_list / audio_pp_list: per-stage layer counts of each
            module. The number of stages read is the longest of these lists.
        tp_size: number of TP ranks to read.
        ep_size: expert-parallel degree. Values above 1 enable the EP directory layout.
        num_experts: total number of experts, used to compute global expert ids.

    Returns:
        A list with ``tp_size`` state dicts, ordered by TP rank.
    """
    # Imported only for its side effects (the name is unused here); presumably it
    # patches megatron so the checkpoints can be unpickled — confirm before removing.
    import mindspeed.megatron_adaptor

    load_dir = Path(load_dir)
    # The tracker file may end with a newline; without strip() a "release\n" entry
    # would fall through to int() and raise ValueError.
    save_iteration = load_dir.joinpath(LATEST_TXT).read_text().strip()
    if save_iteration == "release":
        save_dir = load_dir.joinpath(save_iteration)
    else:
        save_dir = load_dir.joinpath(f"iter_{int(save_iteration):07}")

    global_pp_size = max(
        len(vit_pp_list),
        len(llm_pp_list),
        len(audio_pp_list) if audio_pp_list else 0
    )
    use_ep = ep_size > 1
    ep_ranks = range(ep_size) if use_ep else range(1)

    state_dicts = []
    for tp_rank in range(tp_size):
        pp_state_dict = {}
        for pp_rank in range(global_pp_size):
            for ep_rank in ep_ranks:
                rank_dir = _mm_rank_dir(save_dir, tp_rank, pp_rank, ep_rank, global_pp_size, ep_size)
                pt_path = rank_dir.joinpath(MEGATRON_CKPT_NAME)
                print(str(pt_path).center(100, '_'))
                for param, tensor in _load_mm_model_weights(pt_path).items():
                    if tensor is None:
                        continue
                    if use_ep:
                        new_key = rename_pp_ep_parameter(param, vit_pp_list, llm_pp_list, audio_pp_list,
                                                         pp_rank, ep_rank, ep_size, num_experts)
                    else:
                        new_key = rename_pp_parameter(param, vit_pp_list, llm_pp_list, audio_pp_list, pp_rank)
                    pp_state_dict[new_key] = tensor
        state_dicts.append(pp_state_dict)
    return state_dicts
def merge_by_tp(tp_state_dicts: List[STATE_DICT_T], patterns: TP_PATTERN_T, tp_size: int = 0, vit_tp_size: int = 0,
                audio_tp_size: int = 0) -> STATE_DICT_T:
    """Merge the weights of multiple TP shards back into full weights.

    Every key of rank 0's dict is processed as follows:

    * If no regex in ``patterns`` matches the key (``re.match``), rank 0's tensor is
      used as-is.
    * Otherwise the merger of the first matching pattern is applied. The shard count
      comes from the module prefix of the key: ``vit_tp_size`` for ``IMAGE_ENCODER``,
      ``audio_tp_size`` for ``AUDIO_ENCODER`` and ``tp_size`` for ``TEXT_DECODER``.
      A size <= 0 merges all shards, 1 takes rank 0's tensor, and n > 1 merges the
      first n shards.
    * Pattern-matched keys with none of those prefixes merge all shards. Previously
      such keys were silently dropped from the result.

    Args:
        tp_state_dicts: one state dict per TP rank, ordered by rank.
        patterns: mapping of key regex to an object exposing ``merge(list_of_tensors)``.
        tp_size / vit_tp_size / audio_tp_size: per-module TP degrees (see above).

    Returns:
        The merged state dict. With a single shard, that shard's dict itself is returned.
    """
    if not tp_state_dicts:
        return {}
    if len(tp_state_dicts) == 1:
        return tp_state_dicts[0]

    tp_config = {
        IMAGE_ENCODER: vit_tp_size,
        AUDIO_ENCODER: audio_tp_size,
        TEXT_DECODER: tp_size
    }

    merged_dict = {}
    for key, first_value in tp_state_dicts[0].items():
        merger = next((m for pattern, m in patterns.items() if re.match(pattern, key)), None)
        if merger is None:
            # Not declared as TP-sharded: keep rank 0's copy.
            merged_dict[key] = first_value
            continue
        # Unknown module prefix -> 0, i.e. merge every shard.
        size = next((s for prefix, s in tp_config.items() if key.startswith(prefix)), 0)
        if size == 1:
            merged_dict[key] = first_value
        else:
            shards = tp_state_dicts if size <= 0 else tp_state_dicts[:size]
            # Only read the ranks actually merged, so ranks beyond `size` need not hold the key.
            merged_dict[key] = merger.merge([sd[key] for sd in shards])
    return merged_dict
def rename_pp_ep_parameter(param_name: str,
                           vit_pp_list: List[int],
                           llm_pp_list: List[int],
                           audio_pp_list: List[int] = None,
                           pp_index: int = 0,
                           ep_rank: int = 0,
                           ep_size: int = 1,
                           num_experts: int = 16) -> str:
    """Rename a parameter from (pp stage, ep rank)-local to global numbering.

    The layer number is first made global with ``rename_pp_parameter``. Then, for
    keys containing ``local_experts``, the expert id that follows the
    ``local_experts`` component is shifted by
    ``ep_rank * (num_experts // ep_size)``. If no component equals
    ``local_experts``, the id is taken from dot-separated position 7, the layout
    ``<module>.<sub>.layers.<L>.mlp.experts.local_experts.<E>...``.

    Raises:
        ValueError: if the expert id cannot be located or parsed, or if
            ``num_experts`` is not divisible by ``ep_size``. An uneven split would
            silently map different experts onto the same global id.
    """
    pp_key = rename_pp_parameter(param_name, vit_pp_list, llm_pp_list, audio_pp_list, pp_index)
    if "local_experts" not in pp_key:
        return pp_key

    if num_experts % ep_size != 0:
        raise ValueError(f"num_experts ({num_experts}) must be divisible by ep_size ({ep_size}) "
                         f"to rename expert parameter {pp_key}")
    offset = ep_rank * (num_experts // ep_size)

    parts = pp_key.split(".")
    try:
        expert_pos = parts.index("local_experts") + 1
    except ValueError:
        expert_pos = 7  # legacy fixed position
    if expert_pos >= len(parts):
        raise ValueError(f"Invalid key format: {pp_key}")
    parts[expert_pos] = str(offset + int(parts[expert_pos]))
    return ".".join(parts)
def rename_pp_parameter(param_name: str,
                        vit_pp_list: List[int],
                        llm_pp_list: List[int],
                        audio_pp_list: List[int] = None,
                        pp_index: int = 0) -> str:
    """Turn a stage-local layer index in ``param_name`` into a global one.

    Each ``*_pp_list`` holds per-stage layer counts. The offset for a module is the
    sum of the counts of the stages before ``pp_index % len(list)``. An empty or
    ``None`` list gives offset 0. The first ``.<digits>`` of a name that starts with
    one of the known ``...layers.<n>`` prefixes is replaced by ``n + offset``; any
    other name is returned unchanged.
    """
    def layer_offset(layers_per_stage):
        if not layers_per_stage:
            return 0
        stage = pp_index % len(layers_per_stage)
        return sum(layers_per_stage[:stage])

    vit_shift = layer_offset(vit_pp_list)
    llm_shift = layer_offset(llm_pp_list)
    audio_shift = layer_offset(audio_pp_list)

    rules = (
        (r'^image_encoder\.encoder\.blocks\.layers\.(\d+)', vit_shift),
        (r'^image_encoder\.encoder\.encoder\.layers\.(\d+)', vit_shift),
        (r'^text_decoder\.decoder\.layers\.(\d+)', llm_shift),
        (r'^audio_encoder\.encoder\.blocks\.layers\.(\d+)', audio_shift),
    )
    for regex, shift in rules:
        hit = re.match(regex, param_name)
        if hit is None:
            continue
        global_layer = shift + int(hit.group(1))
        return re.sub(r'\.\d+', f'.{global_layer}', param_name, count=1)
    return param_name
def convert_mm_to_hf(convert_config: ConvertHFConfig,
                     ops: List[Operator],
                     tp_patterns: TP_PATTERN_T,
                     merge_source: bool = False):
    """Convert a MindSpeed-MM checkpoint back into a sharded HF checkpoint.

    Steps:
    1. Load every TP rank (all pp/ep ranks merged) with ``load_from_mm``.
    2. Merge the TP shards with ``merge_by_tp``.
    3. Undo each operator's transformation via ``op.revert``.
    4. Optionally overlay the result on the original HF weights (``merge_source``).
    5. Re-shard according to the HF index file.
    6. Copy the non-safetensors files and write the shards to ``save_hf_dir``.

    ``ep_size`` defaults to 1 when the parallel config lacks it. ``num_experts``
    comes from ``config.text_config.num_experts`` and defaults to 1 when either
    attribute is missing.
    """
    parallel = convert_config.parallel_config
    model_config = convert_config.hf_config.config
    tp_sizes = (parallel.tp_size, parallel.vit_tp_size, parallel.audio_tp_size)
    ep_size = getattr(parallel, 'ep_size', 1)
    num_experts = getattr(getattr(model_config, 'text_config', None), 'num_experts', 1)

    shard_dicts = load_from_mm(convert_config.mm_dir, parallel.vit_pp_layers, parallel.llm_pp_layers,
                               max(tp_sizes), parallel.audio_pp_layers, ep_size, num_experts)
    # tp_sizes is ordered (tp, vit, audio), matching merge_by_tp's positional parameters.
    merged = merge_by_tp(shard_dicts, tp_patterns, *tp_sizes)
    for operator in ops:
        operator.revert(merged)

    hf_dir = convert_config.hf_config.hf_dir
    if merge_source:
        # Converted weights take precedence over the original HF ones.
        merged = {**load_from_hf(hf_dir), **merged}
    output_shards = split_by_index_json(merged, hf_dir)
    copy_files_except_suffix(hf_dir, convert_config.save_hf_dir)
    save_by_index_json(output_shards, convert_config.save_hf_dir)
def convert_lora_mm_to_hf(convert_config: ConvertHFConfig,
                          ops: List[Operator],
                          tp_patterns: TP_PATTERN_T):
    """Convert a MindSpeed-MM LoRA checkpoint into a single HF safetensors file.

    The TP ranks are loaded with ``load_from_mm`` and merged with ``merge_by_tp``.
    Each operator's ``revert`` is applied, and the result is written by
    ``save_safetensors`` to ``convert_config.save_hf_dir``.
    """
    parallel_config = convert_config.parallel_config
    max_tp_size = max(parallel_config.tp_size, parallel_config.vit_tp_size, parallel_config.audio_tp_size)
    state_dicts = load_from_mm(convert_config.mm_dir, parallel_config.vit_pp_layers, parallel_config.llm_pp_layers,
                               max_tp_size, parallel_config.audio_pp_layers)
    # Pass the per-module TP degrees, as convert_mm_to_hf does. With merge_by_tp's
    # defaults, every pattern-matched key would be merged across all max_tp_size
    # ranks, even for a module whose own TP degree is smaller.
    state_dict = merge_by_tp(state_dicts, tp_patterns, parallel_config.tp_size,
                             parallel_config.vit_tp_size, parallel_config.audio_tp_size)
    for op in ops:
        op.revert(state_dict)
    save_safetensors(state_dict, convert_config.save_hf_dir)