import torch
from megatron.core.transformer.module import MegatronModule
from bridge.models.conversion.mapping_registry import MegatronMappingRegistry
from bridge.models.conversion.model_bridge import MegatronModelBridge
from bridge.models.conversion.param_mapping import AutoMapping
from mindspeed_mm.models.sora_model import SoRAModel
class WanTransformer3DModel():
pass
@MegatronModelBridge.register_bridge(source=WanTransformer3DModel, target=SoRAModel)
class Wan21Bridge(MegatronModelBridge):
"""
Megatron Bridge for Qwen2.5-VL Conditional Generation.
This bridge handles the conversion between HuggingFace Qwen2_5_VLForConditionalGeneration
and Megatron-Core GPTModel formats, including weight mappings and
configuration translation for vision-language models.
Example:
>>> from bridge import AutoBridge
>>> bridge = AutoBridge.from_hf_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
>>> provider = bridge.to_megatron_provider()
"""
def mapping_registry(self) -> MegatronMappingRegistry:
param_mappings = {
"time_embedding.0.bias": "condition_embedder.time_embedder.linear_1.bias",
"time_embedding.0.weight": "condition_embedder.time_embedder.linear_1.weight",
"time_embedding.2.bias": "condition_embedder.time_embedder.linear_2.bias",
"time_embedding.2.weight": "condition_embedder.time_embedder.linear_2.weight",
"time_projection.1.bias": "condition_embedder.time_proj.bias",
"time_projection.1.weight": "condition_embedder.time_proj.weight",
"text_embedding.linear_1.weight": "condition_embedder.text_embedder.linear_1.weight",
"text_embedding.linear_1.bias": "condition_embedder.text_embedder.linear_1.bias",
"text_embedding.linear_2.bias": "condition_embedder.text_embedder.linear_2.bias",
"text_embedding.linear_2.weight": "condition_embedder.text_embedder.linear_2.weight",
"patch_embedding.weight": "patch_embedding.weight",
"patch_embedding.bias": "patch_embedding.bias",
"head.modulation": "scale_shift_table",
"head.head.weight": "proj_out.weight",
"head.head.bias": "proj_out.bias",
"blocks.*.modulation": "blocks.*.scale_shift_table",
"blocks.*.self_attn.proj_q.weight": "blocks.*.attn1.to_q.weight",
"blocks.*.self_attn.proj_q.bias": "blocks.*.attn1.to_q.bias",
"blocks.*.self_attn.proj_k.weight": "blocks.*.attn1.to_k.weight",
"blocks.*.self_attn.proj_k.bias": "blocks.*.attn1.to_k.bias",
"blocks.*.self_attn.proj_v.weight": "blocks.*.attn1.to_v.weight",
"blocks.*.self_attn.proj_v.bias": "blocks.*.attn1.to_v.bias",
"blocks.*.self_attn.q_norm.weight": "blocks.*.attn1.norm_q.weight",
"blocks.*.self_attn.k_norm.weight": "blocks.*.attn1.norm_k.weight",
"blocks.*.self_attn.proj_out.weight": "blocks.*.attn1.to_out.0.weight",
"blocks.*.self_attn.proj_out.bias": "blocks.*.attn1.to_out.0.bias",
"blocks.*.norm3.weight": "blocks.*.norm2.weight",
"blocks.*.norm3.bias": "blocks.*.norm2.bias",
"blocks.*.cross_attn.proj_q.weight": "blocks.*.attn2.to_q.weight",
"blocks.*.cross_attn.proj_q.bias": "blocks.*.attn2.to_q.bias",
"blocks.*.cross_attn.proj_k.weight": "blocks.*.attn2.to_k.weight",
"blocks.*.cross_attn.proj_k.bias": "blocks.*.attn2.to_k.bias",
"blocks.*.cross_attn.proj_v.weight": "blocks.*.attn2.to_v.weight",
"blocks.*.cross_attn.proj_v.bias": "blocks.*.attn2.to_v.bias",
"blocks.*.cross_attn.q_norm.weight": "blocks.*.attn2.norm_q.weight",
"blocks.*.cross_attn.k_norm.weight": "blocks.*.attn2.norm_k.weight",
"blocks.*.cross_attn.proj_out.weight": "blocks.*.attn2.to_out.0.weight",
"blocks.*.cross_attn.proj_out.bias": "blocks.*.attn2.to_out.0.bias",
"blocks.*.ffn.0.weight": "blocks.*.ffn.net.0.proj.weight",
"blocks.*.ffn.0.bias": "blocks.*.ffn.net.0.proj.bias",
"blocks.*.ffn.2.weight": "blocks.*.ffn.net.2.weight",
"blocks.*.ffn.2.bias": "blocks.*.ffn.net.2.bias",
}
mapping_list = []
for megatron_param, hf_param in param_mappings.items():
mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param))
return MegatronMappingRegistry(*mapping_list)