from typing import List, cast
from checkpoint.common.converter import Converter
from checkpoint.common.permissions import set_directory_permissions
from checkpoint.vlm_model.config import ConvertVppMMConfig, ConvertHFConfig, ConvertResplitConfig
from checkpoint.vlm_model.hf_to_mm import PPStageSchema, text_schema, convert_hf_to_mm
from checkpoint.vlm_model.operator import (
ExpertUpGateMergeOp, Operator, UpGateMergeOp, QKVDirectMergeOp, RenameOp, GLUSplit,
ColSplit, RowSplit
)
def create_deepseek_vl_ops() -> List[Operator]:
    """Build the ordered list of weight-conversion operators for DeepSeek-VL.

    The list contains, in order: one RenameOp built from (source pattern,
    target pattern) pairs, two UpGateMergeOp entries (dense MLP and shared
    experts), one ExpertUpGateMergeOp (routed experts) and one
    QKVDirectMergeOp. The merge operators refer to the names produced by the
    rename table, so the rename operator is listed first.

    Returns:
        A new list of Operator instances on every call.
    """
    # Vision tower: 'vision.*' -> 'image_encoder.encoder.*'.
    vision_renames = (
        (r'vision.attn_pool.kv.bias', r'image_encoder.encoder.attn_pool.kv.bias'),
        (r'vision.attn_pool.kv.weight', r'image_encoder.encoder.attn_pool.kv.weight'),
        (r'vision.attn_pool.latent', r'image_encoder.encoder.attn_pool.latent'),
        (r'vision.attn_pool.mlp.fc1.bias', r'image_encoder.encoder.attn_pool.mlp.fc1.bias'),
        (r'vision.attn_pool.mlp.fc1.weight', r'image_encoder.encoder.attn_pool.mlp.fc1.weight'),
        (r'vision.attn_pool.mlp.fc2.bias', r'image_encoder.encoder.attn_pool.mlp.fc2.bias'),
        (r'vision.attn_pool.mlp.fc2.weight', r'image_encoder.encoder.attn_pool.mlp.fc2.weight'),
        (r'vision.attn_pool.norm.bias', r'image_encoder.encoder.attn_pool.norm.bias'),
        (r'vision.attn_pool.norm.weight', r'image_encoder.encoder.attn_pool.norm.weight'),
        (r'vision.attn_pool.proj.bias', r'image_encoder.encoder.attn_pool.proj.bias'),
        (r'vision.attn_pool.proj.weight', r'image_encoder.encoder.attn_pool.proj.weight'),
        (r'vision.attn_pool.q.bias', r'image_encoder.encoder.attn_pool.q.bias'),
        (r'vision.attn_pool.q.weight', r'image_encoder.encoder.attn_pool.q.weight'),
        (r'vision.blocks.(\d+).attn.proj.bias', r'image_encoder.encoder.blocks.(\d+).attn.proj.bias'),
        (r'vision.blocks.(\d+).attn.proj.weight', r'image_encoder.encoder.blocks.(\d+).attn.proj.weight'),
        (r'vision.blocks.(\d+).attn.qkv.bias', r'image_encoder.encoder.blocks.(\d+).attn.qkv.bias'),
        (r'vision.blocks.(\d+).attn.qkv.weight', r'image_encoder.encoder.blocks.(\d+).attn.qkv.weight'),
        (r'vision.blocks.(\d+).mlp.fc1.bias', r'image_encoder.encoder.blocks.(\d+).mlp.fc1.bias'),
        (r'vision.blocks.(\d+).mlp.fc1.weight', r'image_encoder.encoder.blocks.(\d+).mlp.fc1.weight'),
        (r'vision.blocks.(\d+).mlp.fc2.bias', r'image_encoder.encoder.blocks.(\d+).mlp.fc2.bias'),
        (r'vision.blocks.(\d+).mlp.fc2.weight', r'image_encoder.encoder.blocks.(\d+).mlp.fc2.weight'),
        (r'vision.blocks.(\d+).norm1.bias', r'image_encoder.encoder.blocks.(\d+).norm1.bias'),
        (r'vision.blocks.(\d+).norm1.weight', r'image_encoder.encoder.blocks.(\d+).norm1.weight'),
        (r'vision.blocks.(\d+).norm2.bias', r'image_encoder.encoder.blocks.(\d+).norm2.bias'),
        (r'vision.blocks.(\d+).norm2.weight', r'image_encoder.encoder.blocks.(\d+).norm2.weight'),
        (r'vision.norm.bias', r'image_encoder.encoder.norm.bias'),
        (r'vision.norm.weight', r'image_encoder.encoder.norm.weight'),
        (r'vision.patch_embed.proj.bias', r'image_encoder.encoder.patch_embed.proj.bias'),
        (r'vision.patch_embed.proj.weight', r'image_encoder.encoder.patch_embed.proj.weight'),
        (r'vision.pos_embed', r'image_encoder.encoder.pos_embed'),
    )

    # Projector: 'projector.*' -> 'image_encoder.projector.*'.
    projector_renames = (
        (r'projector.layers.(\d+).bias', r'image_encoder.projector.layers.(\d+).bias'),
        (r'projector.layers.(\d+).weight', r'image_encoder.projector.layers.(\d+).weight'),
    )

    # Language model: 'language.*' -> 'text_decoder.*'.
    language_renames = (
        (r'language.lm_head.weight', r'text_decoder.output_layer.weight'),
        (r'language.model.embed_tokens.weight', r'text_decoder.embedding.word_embeddings.weight'),
        (r'language.model.layers.(\d+).input_layernorm.weight',
         r'text_decoder.decoder.layers.(\d+).input_layernorm.weight'),
        (r'language.model.layers.(\d+).mlp.down_proj.weight',
         r'text_decoder.decoder.layers.(\d+).mlp.linear_fc2.weight'),
        (r'language.model.layers.(\d+).mlp.experts.(\d+).down_proj.weight',
         r'text_decoder.decoder.layers.(\d+).mlp.experts.local_experts.(\d+).linear_fc2.weight'),
        (r'language.model.layers.(\d+).mlp.experts.(\d+).gate_proj.weight',
         r'text_decoder.decoder.layers.(\d+).mlp.experts.local_experts.(\d+).linear_fc1_gate.weight'),
        (r'language.model.layers.(\d+).mlp.experts.(\d+).up_proj.weight',
         r'text_decoder.decoder.layers.(\d+).mlp.experts.local_experts.(\d+).linear_fc1_up.weight'),
        (r'language.model.layers.(\d+).mlp.gate.e_score_correction_bias',
         r'text_decoder.decoder.layers.(\d+).mlp.router.expert_bias'),
        (r'language.model.layers.(\d+).mlp.gate.weight',
         r'text_decoder.decoder.layers.(\d+).mlp.router.weight'),
        (r'language.model.layers.(\d+).mlp.gate_proj.weight',
         r'text_decoder.decoder.layers.(\d+).mlp.linear_fc1_gate.weight'),
        (r'language.model.layers.(\d+).mlp.shared_experts.down_proj.weight',
         r'text_decoder.decoder.layers.(\d+).mlp.shared_experts.linear_fc2.weight'),
        (r'language.model.layers.(\d+).mlp.shared_experts.gate_proj.weight',
         r'text_decoder.decoder.layers.(\d+).mlp.shared_experts.linear_fc1_gate.weight'),
        (r'language.model.layers.(\d+).mlp.shared_experts.up_proj.weight',
         r'text_decoder.decoder.layers.(\d+).mlp.shared_experts.linear_fc1_up.weight'),
        (r'language.model.layers.(\d+).mlp.up_proj.weight',
         r'text_decoder.decoder.layers.(\d+).mlp.linear_fc1_up.weight'),
        (r'language.model.layers.(\d+).post_attention_layernorm.weight',
         r'text_decoder.decoder.layers.(\d+).pre_mlp_layernorm.weight'),
        (r'language.model.layers.(\d+).self_attn.kv_a_layernorm.weight',
         r'text_decoder.decoder.layers.(\d+).self_attention.k_layernorm.weight'),
        (r'language.model.layers.(\d+).self_attn.kv_a_proj_with_mqa.weight',
         r'text_decoder.decoder.layers.(\d+).self_attention.kv_a_proj_with_mqa.weight'),
        (r'language.model.layers.(\d+).self_attn.kv_b_proj.weight',
         r'text_decoder.decoder.layers.(\d+).self_attention.linear_kvb.weight'),
        (r'language.model.layers.(\d+).self_attn.o_proj.weight',
         r'text_decoder.decoder.layers.(\d+).self_attention.linear_proj.weight'),
        # NOTE(review): q_proj (here) and q_a_proj (below) are both renamed to
        # '...self_attention.q_proj.weight'; presumably a given checkpoint
        # carries only one of the two -- confirm.
        (r'language.model.layers.(\d+).self_attn.q_proj.weight',
         r'text_decoder.decoder.layers.(\d+).self_attention.q_proj.weight'),
        (r'language.model.norm.weight', r'text_decoder.decoder.final_layernorm.weight'),
        (r'language.model.layers.(\d+).self_attn.q_a_proj.weight',
         r'text_decoder.decoder.layers.(\d+).self_attention.q_proj.weight'),
        (r'language.model.layers.(\d+).self_attn.q_b_proj.weight',
         r'text_decoder.decoder.layers.(\d+).self_attention.linear_qb.weight'),
        (r'language.model.layers.(\d+).self_attn.q_a_layernorm.weight',
         r'text_decoder.decoder.layers.(\d+).self_attention.q_layernorm.weight'),
    )

    # Concatenation keeps the pairs in a single tuple, in declaration order.
    rename_pairs = vision_renames + projector_renames + language_renames

    return [
        RenameOp(rename_pairs),
        # Dense MLP: gate + up -> linear_fc1.
        UpGateMergeOp(
            raw_names=[
                r"text_decoder.decoder.layers.(\d+).mlp.linear_fc1_gate.weight",
                r"text_decoder.decoder.layers.(\d+).mlp.linear_fc1_up.weight",
            ],
            new_name=r"text_decoder.decoder.layers.(\d+).mlp.linear_fc1.weight",
        ),
        # Shared experts: gate + up -> linear_fc1.
        UpGateMergeOp(
            raw_names=[
                r"text_decoder.decoder.layers.(\d+).mlp.shared_experts.linear_fc1_gate.weight",
                r"text_decoder.decoder.layers.(\d+).mlp.shared_experts.linear_fc1_up.weight",
            ],
            new_name=r"text_decoder.decoder.layers.(\d+).mlp.shared_experts.linear_fc1.weight",
        ),
        # Routed experts: per-expert gate + up -> linear_fc1.
        ExpertUpGateMergeOp(
            raw_names=[
                r"text_decoder.decoder.layers.(\d+).mlp.experts.local_experts.(\d+).linear_fc1_gate.weight",
                r"text_decoder.decoder.layers.(\d+).mlp.experts.local_experts.(\d+).linear_fc1_up.weight",
            ],
            new_name=r"text_decoder.decoder.layers.(\d+).mlp.experts.local_experts.(\d+).linear_fc1.weight",
        ),
        # Attention: q_proj + kv_a_proj_with_mqa -> linear_qkv.
        QKVDirectMergeOp(
            raw_names=(
                r"text_decoder.decoder.layers.(\d+).self_attention.q_proj.weight",
                r"text_decoder.decoder.layers.(\d+).self_attention.kv_a_proj_with_mqa.weight",
            ),
            new_name=r"text_decoder.decoder.layers.(\d+).self_attention.linear_qkv.weight",
        ),
    ]
# Maps parameter-name patterns (post-rename / post-merge names) to the split
# operator class used for that parameter. Passed as the third argument to
# convert_hf_to_mm by DeepSeekVLConverter.hf_to_mm.
# Entry order is kept as-is in case the consumer matches patterns in order.
deepseek_vl_tp_patterns = {
    # Output head and word embeddings.
    r"text_decoder.output_layer.weight": RowSplit,
    r"text_decoder.embedding.word_embeddings.weight": RowSplit,
    # Dense MLP.
    r"text_decoder.decoder.layers.(\d+).mlp.linear_fc1.weight": GLUSplit,
    r"text_decoder.decoder.layers.(\d+).mlp.linear_fc2.weight": ColSplit,
    # Self-attention projections.
    r"text_decoder.decoder.layers.(\d+).self_attention.linear_qb.weight": RowSplit,
    r"text_decoder.decoder.layers.(\d+).self_attention.linear_kvb.weight": RowSplit,
    r"text_decoder.decoder.layers.(\d+).self_attention.linear_kvb.bias": RowSplit,
    r"text_decoder.decoder.layers.(\d+).self_attention.linear_proj.weight": ColSplit,
    # Routed experts.
    r"text_decoder.decoder.layers.(\d+).mlp.experts.local_experts.(\d+).linear_fc1.weight": GLUSplit,
    r"text_decoder.decoder.layers.(\d+).mlp.experts.local_experts.(\d+).linear_fc2.weight": ColSplit,
    # Shared experts.
    r"text_decoder.decoder.layers.(\d+).mlp.shared_experts.linear_fc1.weight": GLUSplit,
    r"text_decoder.decoder.layers.(\d+).mlp.shared_experts.linear_fc2.weight": ColSplit,
}
# Pipeline-stage schema for the image encoder: name prefixes grouped into
# first / middle / last / all-layer buckets. Passed to convert_hf_to_mm
# together with text_schema.
vision_schema = PPStageSchema(
    firsts=[
        'image_encoder.encoder.patch_embed.',
        'image_encoder.encoder.pos_embed',
    ],
    middle='image_encoder.encoder.blocks.',
    lasts=[
        'image_encoder.encoder.norm',
        'image_encoder.encoder.attn_pool',
        'image_encoder.projector.',
    ],
    # 'view_seperator' spelling is kept verbatim: it is a runtime key,
    # presumably matching the checkpoint's parameter name -- confirm.
    all_layer=['image_newline', 'view_seperator'],
)
class ConvertVppMMConfigDeepseekVl2(ConvertVppMMConfig):
    """ConvertVppMMConfig variant that fills common model fields from a DeepSeek-VL2 HF config."""

    def model_post_init(self, _context):
        """Copy layer/head/expert settings from the HF config into common_model_config.

        Args:
            _context: Unused.
        """
        # Function-scope import: deepseek_vl2 is only needed when this config
        # class is actually instantiated.
        from deepseek_vl2.models.modeling_deepseek_vl_v2 import DeepseekVLV2Config

        hf_cfg = cast(DeepseekVLV2Config, self.hf_config.config)
        language_cfg = hf_cfg.language_config
        common = self.common_model_config

        common.num_key_value_heads = language_cfg.num_key_value_heads
        common.vit_num_layers = hf_cfg.vision_config.layers
        common.llm_num_layers = language_cfg.num_hidden_layers
        common.num_experts = language_cfg.n_routed_experts
        common.tie_word_embeddings = language_cfg.tie_word_embeddings
class DeepSeekVLConverter(Converter):
    """Conversion tool for DeepSeekVL model weights."""

    @staticmethod
    def _create_ops() -> List[Operator]:
        """Return the operator list produced by create_deepseek_vl_ops()."""
        return create_deepseek_vl_ops()

    @staticmethod
    def hf_to_mm(cfg: ConvertVppMMConfigDeepseekVl2):
        """Convert Hugging Face model weights to MindSpeed-MM model weights.

        Runs convert_hf_to_mm with the DeepSeek-VL operators, the split
        patterns and the vision/text stage schemas, then calls
        set_directory_permissions on cfg.mm_dir.
        """
        convert_hf_to_mm(
            cfg,
            DeepSeekVLConverter._create_ops(),
            deepseek_vl_tp_patterns,
            [vision_schema, text_schema],
        )
        set_directory_permissions(cfg.mm_dir)

    @staticmethod
    def mm_to_hf(cfg: ConvertHFConfig):
        """Convert MindSpeed-MM model weights to Hugging Face model weights."""
        # NOTE(review): not implemented -- this is currently a silent no-op.

    @staticmethod
    def resplit(cfg: ConvertResplitConfig):
        """Re-split MindSpeed-MM model weights."""
        # NOTE(review): not implemented -- this is currently a silent no-op.