from typing import cast, List



from checkpoint.common.converter import Converter

from checkpoint.common.permissions import set_directory_permissions

from checkpoint.vlm_model import hf_to_mm, mm_to_hf

from checkpoint.vlm_model.config import ConvertVppMMConfig, ConvertHFConfig

from checkpoint.vlm_model.hf_to_mm import PPStageSchema, text_schema

from checkpoint.vlm_model.operator import Operator, RenameOp, ExpertUpGateMergeOp, GLUSplit, ColSplit, RowSplit, QKVMergeOp, RelocateOp, ExpertSplitOp, UpGateMergeOp





# Pipeline-parallel stage layout for the vision tower, used together with
# ``text_schema`` when converting HF -> MM. The field meanings are presumably:
# ``firsts`` are name prefixes pinned to the first stage, ``lasts`` are prefixes
# pinned to the last stage, and ``middle`` is the prefix of the repeated layers
# spread across stages. Confirm against PPStageSchema in hf_to_mm.
vision_schema = PPStageSchema(
    middle='image_encoder.encoder.blocks.layers.',
    firsts=[
        'image_encoder.encoder.patch_embed.',
        'image_encoder.encoder.pos_embed.',
    ],
    lasts=['image_encoder.projector.'],
)





def create_qwen3_vl_ops(vit_embed_dim: int, vit_num_heads: int, llm_num_query_groups: int, llm_q_size: int,
                        llm_kv_size: int, num_hidden_layers: int, num_experts: int,
                        deepstack_visual_indexes: List[int]) -> List[Operator]:
    """Build the operator list that maps Qwen3-VL HF weight names to MindSpeed-MM names.

    Args:
        vit_embed_dim: Vision hidden size. Each of the three parts of the fused
            ViT qkv tensor is relocated with this size.
        vit_num_heads: Number of vision attention heads. Used as ``group`` for
            the ViT qkv relocation.
        llm_num_query_groups: Number of KV groups in the language model. Used as
            ``group`` for the LLM QKV merge.
        llm_q_size: Size of the Q slice per KV group.
        llm_kv_size: Size of each K and V slice per KV group.
        num_hidden_layers: Number of language-model decoder layers. One
            gate_up and one down expert-split op is emitted per layer.
        num_experts: Number of experts, forwarded to every ``ExpertSplitOp``.
        deepstack_visual_indexes: ViT block indices that host a deepstack
            merger. Entry ``i`` receives ``model.visual.deepstack_merger_list.{i}``.

    Returns:
        Operators in this order: the generic rename, the LLM QKV merge, the ViT
        qkv weight/bias relocations, per-layer expert splits (all gate_up, then
        all down), deepstack renames, the dense MLP up/gate merge, and the dense
        MLP rename.
    """
    ops = [
        RenameOp(
            (
                # Vision tower
                (r'model.visual.blocks.(\d+).norm1.bias', r'image_encoder.encoder.blocks.layers.(\d+).input_layernorm.bias'),
                (r'model.visual.blocks.(\d+).norm1.weight', r'image_encoder.encoder.blocks.layers.(\d+).input_layernorm.weight'),
                (r'model.visual.blocks.(\d+).norm2.bias', r'image_encoder.encoder.blocks.layers.(\d+).pre_mlp_layernorm.bias'),
                (r'model.visual.blocks.(\d+).norm2.weight',
                 r'image_encoder.encoder.blocks.layers.(\d+).pre_mlp_layernorm.weight'),
                (r'model.visual.blocks.(\d+).attn.qkv.weight',
                 r'image_encoder.encoder.blocks.layers.(\d+).self_attention.linear_qkv.weight'),
                (r'model.visual.blocks.(\d+).attn.qkv.bias',
                 r'image_encoder.encoder.blocks.layers.(\d+).self_attention.linear_qkv.bias'),
                (r'model.visual.patch_embed.proj', r'image_encoder.encoder.patch_embed.proj'),
                (r'model.visual.blocks.(\d+).attn.proj', r'image_encoder.encoder.blocks.layers.(\d+).self_attention.linear_proj'),
                (r'model.visual.blocks.(\d+).mlp.linear_fc', r'image_encoder.encoder.blocks.layers.(\d+).mlp.linear_fc'),
                (r'model.visual.merger.linear_fc', r'image_encoder.projector.encoder.linear_fc'),
                (r'model.visual.merger.norm', r'image_encoder.projector.layernorm'),
                (r'model.visual.pos_embed', r'image_encoder.encoder.pos_embed'),

                # Language model
                (r'model.language_model.layers.(\d+).mlp.gate.weight', r'text_decoder.decoder.layers.(\d+).mlp.router.weight'),
                (r'model.language_model.layers.(\d+).self_attn.q_norm.weight',
                 r'text_decoder.decoder.layers.(\d+).self_attention.q_layernorm.weight'),
                (r'model.language_model.layers.(\d+).self_attn.k_norm.weight',
                 r'text_decoder.decoder.layers.(\d+).self_attention.k_layernorm.weight'),
                (r'model.language_model.layers.(\d+).self_attn.o_proj.weight',
                 r'text_decoder.decoder.layers.(\d+).self_attention.linear_proj.weight'),
                (r'model.language_model.layers.(\d+).input_layernorm', r'text_decoder.decoder.layers.(\d+).input_layernorm'),
                (r'model.language_model.layers.(\d+).post_attention_layernorm',
                 r'text_decoder.decoder.layers.(\d+).pre_mlp_layernorm'),
                (r'model.language_model.embed_tokens.weight',
                 r'text_decoder.embedding.word_embeddings.weight'),
                (r'model.language_model.norm.weight',
                 r'text_decoder.decoder.final_layernorm.weight'),
                (r'lm_head', r'text_decoder.output_layer')
            )
        ),
        # The separate HF q/k/v projections are merged into one linear_qkv weight per KV group.
        QKVMergeOp(raw_names=(r"model.language_model.layers.(\d+).self_attn.q_proj.weight",
                              r"model.language_model.layers.(\d+).self_attn.k_proj.weight",
                              r"model.language_model.layers.(\d+).self_attn.v_proj.weight"),
                   new_name=r"text_decoder.decoder.layers.(\d+).self_attention.linear_qkv.weight",
                   group=llm_num_query_groups,
                   q_size=llm_q_size,
                   k_size=llm_kv_size,
                   v_size=llm_kv_size,
                   ),
        RelocateOp(name=r"image_encoder.encoder.blocks.layers.(\d+).self_attention.linear_qkv.weight",
                   new_name=r"model.visual.blocks.(\d+).attn.qkv.weight",
                   group=vit_num_heads,
                   split_size=[vit_embed_dim] * 3,  # ViT qkv is not GQA, so the three split parts have equal size
                   ),
        RelocateOp(name=r"image_encoder.encoder.blocks.layers.(\d+).self_attention.linear_qkv.bias",
                   new_name=r"model.visual.blocks.(\d+).attn.qkv.bias",
                   group=vit_num_heads,
                   split_size=[vit_embed_dim] * 3,  # ViT qkv is not GQA, so the three split parts have equal size
                   ),
    ]

    # Per-layer expert tensors are split into per-expert linear_fc1/linear_fc2
    # weights. All gate_up ops come first, then all down ops; this order is kept
    # from the original implementation.
    expert_split_ops = [
        ExpertSplitOp(raw_name=rf"model.language_model.layers.{idx}.mlp.experts.gate_up_proj",
                      new_name=rf"text_decoder.decoder.layers.{idx}.mlp.experts.local_experts.(\d+).linear_fc1.weight",
                      num_experts=num_experts)
        for idx in range(num_hidden_layers)
    ] + [
        ExpertSplitOp(raw_name=rf"model.language_model.layers.{idx}.mlp.experts.down_proj",
                      new_name=rf"text_decoder.decoder.layers.{idx}.mlp.experts.local_experts.(\d+).linear_fc2.weight",
                      num_experts=num_experts)
        for idx in range(num_hidden_layers)
    ]

    # Merger i in the HF list is attached to ViT block deepstack_visual_indexes[i].
    deepstack_rename_op = [
        RenameOp(
            (
                (rf'model.visual.deepstack_merger_list.{idx}.linear_fc',
                 rf'image_encoder.encoder.blocks.layers.{layer_idx}.deepstack_layer.encoder.linear_fc'),
                (rf'model.visual.deepstack_merger_list.{idx}.norm',
                 rf'image_encoder.encoder.blocks.layers.{layer_idx}.deepstack_layer.layernorm')
            )
        )
        for idx, layer_idx in enumerate(deepstack_visual_indexes)
    ]

    # Dense (non-expert) MLP: gate_proj and up_proj are merged into linear_fc1.
    dense_merge_op = [
        UpGateMergeOp(
            raw_names=[r"model.language_model.layers.(\d+).mlp.gate_proj.weight",
                       r"model.language_model.layers.(\d+).mlp.up_proj.weight"],
            new_name=r"text_decoder.decoder.layers.(\d+).mlp.linear_fc1.weight")
    ]

    # Dense MLP down_proj maps to linear_fc2. The trailing comma keeps this a
    # one-element tuple of pairs.
    dense_rename_op = [
        RenameOp(
            (
                (r"model.language_model.layers.(\d+).mlp.down_proj", r"text_decoder.decoder.layers.(\d+).mlp.linear_fc2"),
            )
        )
    ]

    return ops + expert_split_ops + deepstack_rename_op + dense_merge_op + dense_rename_op





# Tensor-parallel split strategy for each MindSpeed-MM parameter-name pattern.
# Entries are kept in their original order in case the consumer iterates them
# in sequence.
qwen3_vl_tp_patterns = {
    # Language model: embeddings, dense MLP and attention
    r"text_decoder.output_layer.weight": RowSplit,
    r"text_decoder.embedding.word_embeddings.weight": RowSplit,
    r'text_decoder.decoder.layers.(\d+).mlp.linear_fc1.weight': GLUSplit,
    r'text_decoder.decoder.layers.(\d+).mlp.linear_fc2.weight': ColSplit,
    r'text_decoder.decoder.layers.(\d+).self_attention.linear_qkv.weight': RowSplit,
    r'text_decoder.decoder.layers.(\d+).self_attention.linear_qkv.bias': RowSplit,
    r'text_decoder.decoder.layers.(\d+).self_attention.linear_proj.weight': ColSplit,
    # Vision encoder blocks
    r"image_encoder.encoder.blocks.layers.(\d+).self_attention.linear_proj.weight": ColSplit,
    r"image_encoder.encoder.blocks.layers.(\d+).self_attention.linear_qkv.bias": RowSplit,
    r"image_encoder.encoder.blocks.layers.(\d+).self_attention.linear_qkv.weight": RowSplit,
    r"image_encoder.encoder.blocks.layers.(\d+).mlp.linear_fc1.bias": RowSplit,
    r"image_encoder.encoder.blocks.layers.(\d+).mlp.linear_fc1.weight": RowSplit,
    r"image_encoder.encoder.blocks.layers.(\d+).mlp.linear_fc2.weight": ColSplit,
    # Vision projector
    r"image_encoder.projector.encoder.linear_fc1.bias": RowSplit,
    r"image_encoder.projector.encoder.linear_fc1.weight": RowSplit,
    r"image_encoder.projector.encoder.linear_fc2.weight": ColSplit,
    # Language model MoE experts
    r"text_decoder.decoder.layers.(\d+).mlp.experts.local_experts.(\d+).linear_fc1.weight": GLUSplit,
    r"text_decoder.decoder.layers.(\d+).mlp.experts.local_experts.(\d+).linear_fc2.weight": ColSplit,
    # Vision deepstack mergers
    r"image_encoder.encoder.blocks.layers.(\d+).deepstack_layer.encoder.linear_fc1.weight": RowSplit,
    r"image_encoder.encoder.blocks.layers.(\d+).deepstack_layer.encoder.linear_fc1.bias": RowSplit,
    r"image_encoder.encoder.blocks.layers.(\d+).deepstack_layer.encoder.linear_fc2.weight": ColSplit,
}





class ConvertVppMMConfigQwen3(ConvertVppMMConfig):
    """VPP HF -> MM conversion config that reads model sizes from a Qwen3-VL HF config."""

    def model_post_init(self, _context):
        """Fill ``common_model_config`` with head, layer and expert counts from ``hf_config.config``."""
        from transformers.models.qwen3_vl_moe import Qwen3VLMoeConfig

        hf_cfg = cast(Qwen3VLMoeConfig, self.hf_config.config)
        text_cfg = hf_cfg.text_config
        target = self.common_model_config

        target.num_key_value_heads = text_cfg.num_key_value_heads
        target.llm_num_layers = text_cfg.num_hidden_layers
        target.vit_num_layers = hf_cfg.vision_config.depth
        # A text config without a ``num_experts`` attribute (presumably a dense
        # model) is recorded as 0 experts.
        target.num_experts = getattr(text_cfg, 'num_experts', 0)
        target.tie_word_embeddings = hf_cfg.tie_word_embeddings





class Qwen3VLMegatronConverter(Converter):
    """Qwen3-VL checkpoint converter between HuggingFace and MindSpeed-MM formats."""

    @staticmethod
    def _create_ops(config) -> List[Operator]:
        """Derive attention/expert/deepstack sizes from an HF config and build the conversion ops.

        Args:
            config: HuggingFace Qwen3-VL (or Qwen3-VL-MoE) config that exposes
                ``text_config`` and ``vision_config``.

        Returns:
            The operator list produced by ``create_qwen3_vl_ops``.
        """
        from transformers.models.qwen3_vl import Qwen3VLConfig

        config = cast(Qwen3VLConfig, config)
        text_cfg = config.text_config
        vision_cfg = config.vision_config

        num_key_value_heads = text_cfg.num_key_value_heads
        # Use head_dim when present and not None; otherwise fall back to
        # hidden_size // num_attention_heads.
        head_dim = getattr(text_cfg, 'head_dim', None)
        if head_dim is None:
            head_dim = text_cfg.hidden_size // text_cfg.num_attention_heads
        # Per KV group, Q spans num_attention_heads / num_key_value_heads heads,
        # while K and V each span one head.
        llm_q_size = head_dim * text_cfg.num_attention_heads // num_key_value_heads
        llm_kv_size = head_dim
        # A missing num_experts means 0 experts.
        num_experts = getattr(text_cfg, 'num_experts', 0)
        # A missing or None deepstack_visual_indexes means no deepstack mergers.
        deepstack_visual_indexes = getattr(vision_cfg, 'deepstack_visual_indexes', None) or []
        return create_qwen3_vl_ops(vision_cfg.hidden_size,
                                   vision_cfg.num_heads,
                                   num_key_value_heads,
                                   llm_q_size,
                                   llm_kv_size,
                                   text_cfg.num_hidden_layers,
                                   num_experts,
                                   deepstack_visual_indexes
                                   )

    @staticmethod
    def hf_to_mm(cfg: ConvertVppMMConfigQwen3):
        """Convert HuggingFace model weights to MindSpeed-MM model weights."""
        ops = Qwen3VLMegatronConverter._create_ops(cfg.hf_config.config)
        # ``hf_to_mm`` here is the module imported at file level, not this
        # staticmethod: class attributes are not in scope inside method bodies.
        hf_to_mm.convert_hf_to_mm(cfg, ops, qwen3_vl_tp_patterns, [vision_schema, text_schema])
        # Security control: restrict permissions on the directory.
        set_directory_permissions(cfg.mm_dir)

    @staticmethod
    def mm_to_hf(cfg: ConvertHFConfig):
        """Convert MindSpeed-MM model weights to HuggingFace model weights."""
        ops = Qwen3VLMegatronConverter._create_ops(cfg.hf_config.config)
        # ``mm_to_hf`` refers to the file-level module, not this staticmethod.
        mm_to_hf.convert_mm_to_hf(cfg, ops, qwen3_vl_tp_patterns)
        # Security control: restrict permissions on the directory.
        # NOTE(review): in this direction cfg.mm_dir is presumably the MM *input*
        # checkpoint. If the intent is to restrict the converted HF output,
        # ConvertHFConfig's output-directory field should be passed instead. Confirm.
        set_directory_permissions(cfg.mm_dir)

    @staticmethod
    def resplit():
        """Re-split MindSpeed-MM model weights (not implemented for Qwen3-VL)."""
        pass