import torch
from megatron.core.transformer.module import MegatronModule
from bridge.models.conversion.mapping_registry import MegatronMappingRegistry
from bridge.models.conversion.model_bridge import MegatronModelBridge
from bridge.models.conversion.param_mapping import (
    AutoMapping,
    GatedMLPMapping,
    QKVMapping,
    ReplicatedMapping,
    WeightReshapeMapping,
)
from mindspeed_mm.models.vlm_model import VLMModel


class Qwen3VLForConditionalGeneration():
    pass


@MegatronModelBridge.register_bridge(source=Qwen3VLForConditionalGeneration, target=VLMModel)
class Qwen3VLBridge(MegatronModelBridge):

    def mapping_registry(self) -> MegatronMappingRegistry:
        mapping_list = []

        mapping_list.extend(
            [
                ReplicatedMapping(
                    megatron_param="model.lm_head.weight",
                    hf_param="lm_head.weight",
                ),
                ReplicatedMapping(
                    megatron_param="model.model.**",
                    hf_param="model.**",
                ),
            ]
        )

        return MegatronMappingRegistry(*mapping_list)


class Qwen3VLMoeForConditionalGeneration():
    pass


@MegatronModelBridge.register_bridge(source=Qwen3VLMoeForConditionalGeneration, target=VLMModel)
class Qwen3VLMoEBridge(MegatronModelBridge):

    def mapping_registry(self) -> MegatronMappingRegistry:
        mapping_list = []

        mapping_list.extend(
            [
                ReplicatedMapping(
                    megatron_param="model.lm_head.weight",
                    hf_param="lm_head.weight",
                ),
                WeightReshapeMapping(
                    megatron_param="model.model.language_model.layers.*.mlp.experts.down_proj",
                    hf_param="model.language_model.layers.*.mlp.experts.down_proj",
                ),
                WeightReshapeMapping(
                    megatron_param="model.model.language_model.layers.*.mlp.experts.gate_up_proj",
                    hf_param="model.language_model.layers.*.mlp.experts.gate_up_proj",
                ),
                ReplicatedMapping(
                    megatron_param="model.model.**",
                    hf_param="model.**",
                ),
            ]
        )

        return MegatronMappingRegistry(*mapping_list)