from itertools import accumulate
from typing import List, Dict, Callable
import numpy as np
from tqdm import tqdm
from checkpoint.common.mm_types import STATE_DICT_T, VPP_LAYER_NUM_T
from checkpoint.vlm_model.config import ConvertVppMMConfig
from checkpoint.vlm_model.hf_to_mm import (
PPRange, PPStageSchema,
load_from_hf, merge_llm_weights_to_state_dict, filter_vit_keys,
convert, split_by_ep, split_by_tp, save_by_vpp,
)
from checkpoint.vlm_model.operator import Operator
def partition_state_dict_by_pp_ldt(state_dict: STATE_DICT_T,
pp_ranges: List[PPRange],
stages: List[PPStageSchema]) -> List[STATE_DICT_T]:
global_pp_size = max(r.pp_size for r in pp_ranges)
pp_weights = []
for pp_rank in range(global_pp_size):
pp_weight = {}
for weight_name, weight_value in state_dict.items():
for modality_stage, modality_pp_range in zip(stages, pp_ranges):
if pp_rank >= modality_pp_range.pp_size:
continue
is_first_in_group = (pp_rank == modality_pp_range.first_layer_rank)
is_last_in_group = (pp_rank == modality_pp_range.last_layer_rank)
if is_first_in_group:
for name_start in modality_stage.firsts:
if weight_name.startswith(name_start):
pp_weight[weight_name] = weight_value
if is_last_in_group:
for name_start in modality_stage.lasts:
if weight_name.startswith(name_start):
pp_weight[weight_name] = weight_value
if weight_name.startswith(modality_stage.middle):
layer_start = modality_pp_range.start[pp_rank]
layer_end = modality_pp_range.end[pp_rank]
raw_layer_num, *remains = weight_name.replace(modality_stage.middle, "").split(".")
try:
raw_layer_num = int(raw_layer_num)
if layer_start <= raw_layer_num < layer_end:
new_layer_num = raw_layer_num - layer_start
new_weight_name = ".".join([modality_stage.middle[:-1], str(new_layer_num), *remains])
pp_weight[new_weight_name] = weight_value
except ValueError as e:
raise ValueError(
f"Failed to parse layer number from weight name: '{weight_name}'\n"
f"Modality: {modality_stage}, PP range: {modality_pp_range}\n"
f"Original error: {str(e)}"
) from e
if modality_stage.all_layer:
has_layers = modality_pp_range.start[pp_rank] < modality_pp_range.end[pp_rank]
if has_layers or is_first_in_group or is_last_in_group:
for name_start in modality_stage.all_layer:
if weight_name.startswith(name_start):
pp_weight[weight_name] = weight_value
pp_weights.append(pp_weight)
return pp_weights
def merge_vpp_index_ldt(vit_pipeline_num_layers: VPP_LAYER_NUM_T,
llm_pipeline_num_layers: VPP_LAYER_NUM_T,
audio_pipeline_num_layers: VPP_LAYER_NUM_T) -> List[PPRange]:
modalities_pp_range = []
modalities = [vit_pipeline_num_layers, llm_pipeline_num_layers, audio_pipeline_num_layers]
is_llm_flags = [False, True, False]
for modality, is_llm in zip(modalities, is_llm_flags):
modality_pp_flat = [item
for sublist in modality
for item in sublist]
if not modality_pp_flat:
continue
modality_pp_acc = list(accumulate(modality_pp_flat))
if is_llm:
pp_size = len(modality[0])
vpp_size = len(modality)
first_vpp_idx = next(i for i in range(vpp_size) if sum(modality[i]) > 0)
first_layer_rank = first_vpp_idx * pp_size
last_layer_rank = (vpp_size - 1) * pp_size
else:
first_layer_rank, last_layer_rank = np.nonzero(np.array(modality_pp_flat))[0][[0, -1]]
modalities_pp_range.append(PPRange(start=[0] + modality_pp_acc[:-1],
end=modality_pp_acc,
first_layer_rank=first_layer_rank,
last_layer_rank=last_layer_rank))
return modalities_pp_range
def convert_hf_to_mm_ldt(convert_config: ConvertVppMMConfig, ops: List[Operator],
tp_patterns: Dict[str, Callable],
stages: List[PPStageSchema]):
"""LDT版convert_hf_to_mm,使用LDT变体的PP切分和VPP索引逻辑。"""
pt_path = getattr(convert_config, 'pt_path', None)
parallel_config = convert_config.parallel_config
num_experts = convert_config.common_model_config.num_experts
state_dict = load_from_hf(convert_config.hf_config.hf_dir, pt_path)
if convert_config.common_model_config.llm_hf_dir is not None:
llm_state_dict = load_from_hf(convert_config.common_model_config.llm_hf_dir)
state_dict = merge_llm_weights_to_state_dict(state_dict, llm_state_dict)
if convert_config.save_vit_only:
filter_vit_keys(state_dict)
state_dict = convert(state_dict, ops, convert_config.common_model_config.tie_word_embeddings, parallel_config.is_pp())
if getattr(convert_config, 'save_lora_only', False):
state_dict = {k: v for k, v in state_dict.items() if "lora" in k}
ep_state_dicts = split_by_ep(state_dict, parallel_config.ep_size, _num_experts=num_experts)
ep_tp_state_dicts = []
for ep_state_dict in ep_state_dicts:
tp_state_dicts = split_by_tp(ep_state_dict, tp_patterns, parallel_config.tp_size,
parallel_config.vit_tp_size, parallel_config.audio_tp_size)
ep_tp_state_dicts.append(tp_state_dicts)
pp_ranges = merge_vpp_index_ldt(parallel_config.vit_pp_layers,
parallel_config.llm_pp_layers,
parallel_config.audio_pp_layers or [[]])
for ep_rank, tp_state_dicts in enumerate(tqdm(ep_tp_state_dicts, desc="ep step")):
for tp_rank, tp_state_dict in enumerate(tqdm(tp_state_dicts, desc="tp step")):
pp_state_dicts = partition_state_dict_by_pp_ldt(tp_state_dict, pp_ranges, stages)
save_by_vpp(pp_state_dicts, convert_config.mm_dir,
pp_and_vpp_size=(parallel_config.pp_size, parallel_config.vpp_size),
ep_size=parallel_config.ep_size, ep_rank=ep_rank, tp_rank=tp_rank)