import torch
from torch import nn

from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.spec_utils import ModuleSpec

from mindspeed_mm.models.common.module import MultiModalModule


class InternVLMLP(MultiModalModule):
    """Two-layer MLP with an optional LayerNorm applied to its input.

    The input feature width is
    ``config.vit_hidden_size * int(1 / config.downsample_ratio) ** 2`` and the
    output width is ``config.llm_hidden_size``.
    NOTE(review): judging by the config names this looks like a vision-to-LLM
    projector -- confirm against the caller.
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: ModuleSpec,
    ):
        """Build the normalization and projection layers from ``config``.

        Args:
            config: Must expose ``downsample_ratio``, ``vit_hidden_size`` and
                ``llm_hidden_size``. An optional ``norm`` attribute (treated
                as ``True`` when absent) controls whether the input LayerNorm
                is created.
            submodules: Accepted as part of the constructor signature; it is
                not read anywhere in this class.
        """
        super().__init__(config=config)

        # int() truncates the reciprocal of the ratio before it is squared.
        scale = int(1 / config.downsample_ratio)
        in_features = config.vit_hidden_size * scale ** 2
        out_features = config.llm_hidden_size

        # The attribute is left as None when normalization is switched off so
        # that forward() can skip it.
        if getattr(config, "norm", True):
            self.norm = nn.LayerNorm(normalized_shape=in_features)
        else:
            self.norm = None
        self.linear_fc1 = nn.Linear(in_features=in_features, out_features=out_features)
        self.activation_func = nn.GELU()
        self.linear_fc2 = nn.Linear(in_features=out_features, out_features=out_features)

    def forward(
        self,
        hidden_state,
    ):
        """Apply the optional norm, then ``linear_fc1 -> GELU -> linear_fc2``.

        Args:
            hidden_state: Input passed to the first layer; its last dimension
                has to match the input width the layers were built with.

        Returns:
            The output of ``linear_fc2``.
        """
        features = hidden_state if self.norm is None else self.norm(hidden_state)
        return self.linear_fc2(self.activation_func(self.linear_fc1(features)))