import torch
from torch import nn
from torch.nn import LayerNorm
from transformers.activations import ACT2FN

from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.spec_utils import ModuleSpec

from mindspeed_mm.models.common.module import MultiModalModule


class GlmMLP(MultiModalModule):
    def __init__(
        self,
        config: TransformerConfig,
        submodules: ModuleSpec,
    ):
        super().__init__(config=config)

        dim = config.out_hidden_size
        context_dim = config.intermediate_size
        hidden_act = config.hidden_act
        bias = False

        self.proj = nn.Linear(dim, dim, bias=bias)
        self.post_projection_norm = LayerNorm(dim)
        self.gate_proj = nn.Linear(dim, context_dim, bias=bias)
        self.up_proj = nn.Linear(dim, context_dim, bias=bias)
        self.down_proj = nn.Linear(context_dim, dim, bias=bias)
        self.act1 = nn.GELU()
        self.act_fn = ACT2FN[hidden_act]

    def forward(
        self,
        hidden_state,
    ):
        hidden_state = self.proj(hidden_state)
        hidden_state = self.act1(self.post_projection_norm(hidden_state))
        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))