# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Ernie VL model"""
import re
import math
import itertools
from dataclasses import dataclass
from collections import defaultdict
from copy import deepcopy
from functools import partial
from typing import List, Optional, Tuple, Union

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

from transformers.activations import ACT2FN
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from .configuration_ernie4_5_vl import (
    DFNRopeVisionTransformerConfig,
    Ernie4_5_MoEConfig,
    Ernie4_5_VLMoEConfig,
)

# Module-level logger obtained through the transformers logging utilities.
logger = logging.get_logger(__name__)


# Names exported by a star-import of this module.
__all__ = [
    "Ernie4_5_VLMoeForConditionalGeneration",
    "DFNRopeVisionTransformerPreTrainedModel",
    "VariableResolutionResamplerModel",
]


class TokenType:
    """Integer ids used to tag each token with its modality."""

    # text -> 0, image -> 1, video -> 2
    text, image, video = 0, 1, 2


class UniqueNameGuard:
    """Hands out numbered, prefix-qualified names; usable as a context manager."""

    def __init__(self, prefix=""):
        # `counter` maps each requested base name to the last index issued for it.
        self.prefix = prefix
        self.counter = {}

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Nothing to release; returning None lets any exception propagate.
        return None

    def get_unique_name(self, name):
        """Return ``f"{prefix}{name}_{k}"`` where k counts prior requests for ``name``."""
        index = self.counter.get(name, -1) + 1
        self.counter[name] = index
        return f"{self.prefix}{name}_{index}"


class RopeEmbedding(nn.Module):
    """
    Rotary Position Embedding (RoPE) implementation for transformer models.

    RoPE encodes absolute positional information with rotation matrices and
    naturally incorporates relative position information in self-attention.

    Args:
        head_dim (int): Dimension size of each attention head
        compression_ratio (float, optional): Sequence length compression ratio. Defaults to 1.0.
        base (int, optional): Base value for frequency calculation. Defaults to 10000.
        freq_allocation (int, optional): Number of rotary frequencies (out of
            ``head_dim // 2``) that ``apply_rotary_3d`` takes from the time
            position; the remaining ones are interleaved between the height
            and width positions. Defaults to 0 (``apply_rotary_3d`` unusable).

    Attributes:
        head_dim (int): Dimension size of each attention head
        compression_ratio (float): Sequence length compression factor
        base (int): Base value for frequency calculation
        freq_allocation (int): Number of frequencies allocated to time
    """

    def __init__(self, head_dim, compression_ratio=1.0, base=10000, freq_allocation=0):
        """
        Initialize RoPE embedding layer.

        Args:
            head_dim: Dimension of each attention head
            compression_ratio: Scaling factor for position indices
            base: Base value for frequency calculation
            freq_allocation: Number of frequencies allocated to the time axis
        """
        super().__init__()
        self.head_dim = head_dim
        self.compression_ratio = compression_ratio
        self.base = base

        # num of freq allocated to time
        self.freq_allocation = freq_allocation

    def forward(self, seq_length, position_ids=None):
        """
        Compute the sin/cos table for the given sequence length or positions.

        Args:
            seq_length (int): Number of positions to generate when
                ``position_ids`` is None (ignored otherwise).
            position_ids (Tensor, optional): Custom position indices; the
                sequence length is then taken from its last dimension.

        Returns:
            Tensor: Detached float32 table viewed as
            ``[-1, 1, seq_length, head_dim]``. The first ``head_dim // 2``
            entries of the last axis hold sin values, the rest hold cos values.
        """
        # Build the frequencies next to position_ids so that non-CPU position
        # ids do not trigger a device mismatch; default device otherwise.
        device = None if position_ids is None else position_ids.device
        indices = torch.arange(
            0, self.head_dim, 2, dtype=torch.float32, device=device
        )
        indices = 1 / self.base ** (indices / self.head_dim)
        if position_ids is None:
            position_ids = torch.arange(
                0, seq_length, 1, dtype=torch.float32
            ).unsqueeze(1)
            position_ids = position_ids / self.compression_ratio
            sinusoid_inp = position_ids * indices.unsqueeze(0)
        else:
            position_ids = position_ids / self.compression_ratio
            seq_length = position_ids.shape[-1]
            sinusoid_inp = position_ids.unsqueeze(-1).to(
                torch.float32
            ) * indices.unsqueeze(0)
        pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], dim=-1)
        pos_emb = pos_emb.view(-1, 1, seq_length, self.head_dim)
        return pos_emb.detach()

    @staticmethod
    def _rotate(x, cos_pos, sin_pos):
        """Apply the pairwise rotation to ``x`` in float32.

        rotate_half(x) = [-x1, x0, -x3, x2, ..., -x(d-1), x(d-2)]
        result         = x * cos_pos + rotate_half(x) * sin_pos
        """
        rotate_half = torch.stack(
            [-x[:, :, :, 1::2], x[:, :, :, 0::2]], dim=-1
        ).reshape(x.shape)
        return (x.to(torch.float32) * cos_pos) + (
            rotate_half.to(torch.float32) * sin_pos
        )

    def _gather_thw(self, table, batch_indices, position_ids):
        """Pick per-token sin (or cos) values from ``table`` for the t/h/w axes.

        The last ``freq_allocation`` frequencies are looked up with the time
        position (``position_ids[..., 0]``); of the remaining frequencies the
        even-indexed ones use the height position (``[..., 1]``) and the
        odd-indexed ones use the width position (``[..., 2]``). Height and
        width values are re-interleaved and the time values are appended.
        """
        spatial_end = self.head_dim // 2 - self.freq_allocation
        part_t = table[batch_indices, position_ids[..., 0], :, -self.freq_allocation :]
        part_h = table[batch_indices, position_ids[..., 1], :, :spatial_end:2]
        part_w = table[batch_indices, position_ids[..., 2], :, 1:spatial_end:2]
        part_hw = torch.stack([part_h, part_w], dim=-1).reshape(
            part_h.shape[:-1] + (part_h.shape[-1] * 2,)
        )
        return torch.cat([part_hw, part_t], dim=-1)

    def apply_rotary(self, rp, q, k):
        """
        Apply rotary position embeddings to queries and keys.

        Args:
            rp (Tensor): Rotary position embeddings
            q (Tensor): Query tensor [batch, heads, seq_len, dim]
            k (Tensor): Key tensor [batch, heads, seq_len, dim]

        Returns:
            Tuple[Tensor, Tensor]: Rotated queries and keys (float32)
        """
        sin, cos = torch.chunk(rp, 2, dim=-1)
        # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
        sin_pos = torch.stack([sin, sin], dim=-1).reshape(rp.shape)
        # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
        cos_pos = torch.stack([cos, cos], dim=-1).reshape(rp.shape)
        query = self._rotate(q, cos_pos, sin_pos)
        key = self._rotate(k, cos_pos, sin_pos)
        return query, key

    def apply_rotary_3d(self, rp, q, k, position_ids):
        """
        Apply 3D (time / height / width) rotary embeddings to queries and keys.

        Args:
            rp: [1, max_seqlen, 1, head_dim] sin/cos table
            q: [bsz, seqlen, head, head_dim]
            k: [bsz, seqlen, head, head_dim]
            position_ids: [bsz, seqlen, 3], last axis ordered (t, h, w)

        Returns:
            Tuple[Tensor, Tensor]: Rotated queries and keys (float32), on the
            device of ``q``.
        """
        current_device = q.device
        sin, cos = torch.chunk(rp, 2, dim=-1)
        assert (
            position_ids.shape[:1] == q.shape[:1]
        ), "position_ids and q must share the batch dimension"
        assert (
            self.freq_allocation != 0
        ), "freq_allocation must be non-zero for 3D rotary embeddings"

        bsz = position_ids.shape[0]
        index_device = position_ids.device
        # Column vector so it broadcasts against the [bsz, seqlen] position indices.
        batch_indices = torch.arange(bsz, device=index_device).unsqueeze(-1)
        sin = sin.tile(bsz, 1, 1, 1).to(device=index_device)
        cos = cos.tile(bsz, 1, 1, 1).to(device=index_device)

        sin_thw = self._gather_thw(sin, batch_indices, position_ids)
        cos_thw = self._gather_thw(cos, batch_indices, position_ids)

        # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
        sin_pos = (
            torch.stack([sin_thw, sin_thw], dim=-1)
            .reshape(sin_thw.shape[:3] + (sin_thw.shape[-1] * 2,))
            .to(current_device)
        )
        # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
        cos_pos = (
            torch.stack([cos_thw, cos_thw], dim=-1)
            .reshape(cos_thw.shape[:3] + (cos_thw.shape[-1] * 2,))
            .to(current_device)
        )

        query = self._rotate(q, cos_pos, sin_pos)
        key = self._rotate(k, cos_pos, sin_pos)
        return query, key


class Ernie4_5_MLP(nn.Module):
    """
    Gated feed-forward block: ``down_proj(silu(gate_proj(x)) * up_proj(x))``.
    """

    def __init__(self, config, layer_idx=0):
        """
        Build the three linear projections from the model configuration.

        Args:
            config (Ernie4_5_Config): Supplies ``hidden_size``,
                ``intermediate_size`` and ``use_bias``.
            layer_idx (int): Index of current layer (default: 0); not used here.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        narrow, wide = self.hidden_size, self.intermediate_size
        with_bias = config.use_bias
        # Creation order (gate, up, down) is kept so parameter init is unchanged.
        self.gate_proj = nn.Linear(narrow, wide, bias=with_bias)
        self.up_proj = nn.Linear(narrow, wide, bias=with_bias)
        self.down_proj = nn.Linear(wide, narrow, bias=with_bias)

    def forward(self, x):
        """
        Run the gated MLP.

        Args:
            x (Tensor): Input whose last dimension is ``hidden_size``.

        Returns:
            Tensor: Output with the same leading dimensions and last dimension
            ``hidden_size``, on the device of the layer weights.
        """
        # Inputs may arrive on another device; follow the weights.
        x = x.to(self.gate_proj.weight.data.device)
        gate = F.silu(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)


class Ernie4_5_Attention(nn.Module):
    """Multi-headed causal self-attention with 3D rotary embeddings and optional GQA.

    Tensors flowing between the methods of this class use the layout
    ``[batch, seq_len, heads, head_dim]``; the attention kernels transpose to
    ``[batch, heads, seq_len, head_dim]`` internally.
    """

    def __init__(self, config, layer_idx=0):
        """Initialize the attention layer.

        Args:
            config (Ernie4_5_Config): Model configuration.
            layer_idx (int, optional): Index in transformer stack. Defaults to 0.
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.is_gqa = (
            self.num_key_value_heads is not None
            and self.num_key_value_heads != self.num_heads
        )

        self.freq_allocation = getattr(config, "freq_allocation", 0)
        assert (
            self.freq_allocation is not None
        ), "freq_allocation must be provided if rope_3d is on."

        # Under tensor parallelism each rank owns a slice of the heads.
        if config.tensor_parallel_degree > 1:
            assert (
                self.num_heads % config.tensor_parallel_degree == 0
            ), f"num_heads: {self.num_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}"
            self.num_heads = self.num_heads // config.tensor_parallel_degree
            if self.is_gqa:
                assert (
                    self.num_key_value_heads % config.tensor_parallel_degree == 0
                ), f"num_heads: {self.num_key_value_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}"
                self.num_key_value_heads = (
                    self.num_key_value_heads // config.tensor_parallel_degree
                )
        q_hidden_size = self.head_dim * self.num_heads
        if self.is_gqa:
            logger.info(
                f"use GQA - num_heads: {self.num_heads}- num_key_value_heads: {self.num_key_value_heads}"
            )
            assert (
                self.num_heads % self.num_key_value_heads == 0
            ), f"num_heads: {self.num_heads}, num_key_value_heads: {self.num_key_value_heads}"
            kv_hidden_size = self.head_dim * self.num_key_value_heads
        else:
            kv_hidden_size = self.head_dim * self.num_heads

        self.q_proj = nn.Linear(self.hidden_size, q_hidden_size, bias=config.use_bias)
        self.k_proj = nn.Linear(self.hidden_size, kv_hidden_size, bias=config.use_bias)
        self.v_proj = nn.Linear(self.hidden_size, kv_hidden_size, bias=config.use_bias)

        self.o_proj = nn.Linear(
            self.hidden_size,
            self.hidden_size,
            bias=config.use_bias,
        )

        self.rotary_emb = RopeEmbedding(
            self.head_dim,
            compression_ratio=config.compression_ratio,
            base=config.rope_theta,
            freq_allocation=self.freq_allocation,
        )
        self.config = config
        # Select the attention kernel once; both share the same call signature.
        if self.config.use_flash_attention:
            self.attn_func = self._flash_attention_wrapper
        else:
            self.attn_func = self.core_attn

    def forward(
        self,
        hidden_states,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        attn_mask_start_row_indices: Optional[torch.Tensor] = None,
        position_ids: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        token_type_ids: Optional[Tuple[torch.Tensor]] = None,  # MLLM
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute attention outputs.

        Args:
            hidden_states (torch.Tensor): Input tensor [bsz, seq_len, hidden_size]
            past_key_value (Optional[Tuple[torch.Tensor, torch.Tensor]]): Cached key/value states
            attention_mask (Optional[torch.Tensor]): Attention mask tensor
            attn_mask_start_row_indices (Optional[torch.Tensor]): Variable length attention indices
            position_ids (Optional[torch.Tensor]): Position indices for RoPE (required)
            output_attentions (bool): Return attention weights if True
            use_cache (bool): Cache key/value states if True
            token_type_ids: Accepted for interface compatibility; not used by
                this layer.

        Returns:
            Tuple containing:
                - attention_output: [bsz, seq_len, hidden_size]
                - attention_weights: Optional attention probabilities
                - updated_key_value_cache: Optional updated cache
        """
        bsz, q_len, _ = hidden_states.shape
        # Split the projections into heads: [bsz, q_len, heads, head_dim].
        query_states = self.q_proj(hidden_states).reshape(
            [bsz, q_len, -1, self.head_dim]
        )
        key_states = self.k_proj(hidden_states).reshape([bsz, q_len, -1, self.head_dim])
        value_states = self.v_proj(hidden_states).reshape(
            [bsz, q_len, -1, self.head_dim]
        )

        attn_output, attn_weights, past_key_value = self.rope_attn(
            query_states=query_states,
            key_states=key_states,
            value_states=value_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            past_key_value=past_key_value,
            use_cache=use_cache,
            attn_mask_start_row_indices=attn_mask_start_row_indices,
        )
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

    def repeat_kv(self, hidden_states, n_rep):
        """
        This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
        num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
        """
        batch, num_key_value_heads, slen, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states[:, :, None, :, :].expand(
            batch, num_key_value_heads, n_rep, slen, head_dim
        )
        return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

    def _flash_attention_wrapper(
        self,
        q,
        k,
        v,
        attention_mask=None,
        attn_mask_start_row_indices=None,
        seq_length=None,
    ):
        """Attention through ``F.scaled_dot_product_attention`` (flash backend).

        Args:
            q (torch.Tensor): Query tensor [bsz, q_len, heads, head_dim]
            k (torch.Tensor): Key tensor [bsz, kv_len, kv_heads, head_dim]
            v (torch.Tensor): Value tensor [bsz, kv_len, kv_heads, head_dim]
            attention_mask (Optional[torch.Tensor]): Unused by this kernel
            attn_mask_start_row_indices (Optional[torch.Tensor]): Unused by this kernel
            seq_length (Optional[int]): Unused by this kernel
        Returns:
            Tuple[torch.Tensor, None]: Attention output [bsz, q_len, heads * head_dim];
            attention weights are not available from this path.
        """
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # scaled_dot_product_attention applies dropout whenever dropout_p > 0,
        # regardless of module mode, so it must be disabled explicitly in eval.
        dropout_p = (
            getattr(self.config, "attention_probs_dropout_prob", 0.0)
            if self.training
            else 0.0
        )

        # NOTE(review): core_attn's net scale is 1/sqrt(head_dim) for any
        # scale_qk_coeff (it multiplies the coefficient back in), whereas this
        # path also divides by scale_qk_coeff. The two agree only when
        # scale_qk_coeff == 1.0 -- confirm which is intended.
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            out = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=None,
                dropout_p=dropout_p,
                # Causal masking only when no cached keys precede the queries.
                is_causal=q.shape[-2] == k.shape[-2],
                scale=1
                / (getattr(self.config, "scale_qk_coeff", 1.0) * self.head_dim**0.5),
                enable_gqa=self.is_gqa,
            )
        out = out.transpose(1, 2)
        out = out.contiguous().view(out.size(0), out.size(1), -1)

        return out, None

    def core_attn(
        self,
        q,
        k,
        v,
        attention_mask=None,
        attn_mask_start_row_indices=None,
        seq_length=None,
    ):
        """Standard causal self-attention implementation.

        Args:
            q (torch.Tensor): Query tensor [bsz, q_len, heads, head_dim]
            k (torch.Tensor): Key tensor [bsz, kv_len, kv_heads, head_dim]
            v (torch.Tensor): Value tensor [bsz, kv_len, kv_heads, head_dim]
            attention_mask (Optional[torch.Tensor]): Unused by this kernel
            attn_mask_start_row_indices (Optional[torch.Tensor]): Unused by this kernel
            seq_length (Optional[int]): Unused by this kernel

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Attention output
            [bsz, q_len, heads * head_dim] and attention weights
            [bsz, heads, q_len, kv_len].
        """
        origin_dtype = q.dtype

        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        qk_coeff = getattr(self.config, "scale_qk_coeff", 1.0)
        # Queries are pre-divided by coeff * sqrt(head_dim); the coefficient is
        # multiplied back after the float32 cast below.
        q = q / (qk_coeff * (self.head_dim**0.5))

        # Handle GQA case - repeat k and v heads to match q heads
        if self.is_gqa:
            # [batch, num_key_value_heads, seq_len, head_dim] -> [batch, num_heads, seq_len, head_dim]
            repeat_factor = self.num_heads // self.num_key_value_heads
            k = self.repeat_kv(k, repeat_factor)
            v = self.repeat_kv(v, repeat_factor)

        product = torch.matmul(q, k.transpose(-2, -1))

        product = product.to(torch.float32)
        if qk_coeff != 1.0:
            product = product * qk_coeff

        # Causal mask sized to the actual [q_len, kv_len] score matrix. With a
        # KV cache the queries are the last q_len positions, so query row i may
        # attend to keys 0 .. (kv_len - q_len + i). For q_len == kv_len this is
        # the usual strictly-upper-triangular mask.
        q_len, kv_len = product.size(-2), product.size(-1)
        mask = torch.triu(
            torch.ones((q_len, kv_len), dtype=torch.bool, device=product.device),
            diagonal=kv_len - q_len + 1,
        )
        product = product.masked_fill(mask, float("-inf"))
        weights = F.softmax(product, dim=-1)

        weights = weights.to(origin_dtype)

        if getattr(self.config, "attention_probs_dropout_prob", 0.0) > 0:
            weights = F.dropout(
                weights,
                self.config.attention_probs_dropout_prob,
                training=self.training,
            )

        out = torch.matmul(weights, v)

        # combine heads
        out = out.permute(0, 2, 1, 3)
        out = out.contiguous().view(out.size(0), out.size(1), -1)

        return out, weights

    def rope_attn(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        position_ids,
        output_attentions=False,
        past_key_value=None,
        use_cache=False,
        attn_mask_start_row_indices=None,
    ):
        """Attention computation with rotary embeddings.

        Args:
            query_states (torch.Tensor): Query states [bsz, q_len, heads, head_dim]
            key_states (torch.Tensor): Key states [bsz, q_len, kv_heads, head_dim]
            value_states (torch.Tensor): Value states [bsz, q_len, kv_heads, head_dim]
            attention_mask (Optional[torch.Tensor]): Attention mask
            position_ids (torch.Tensor): Position indices [bsz, seq_len, 3]; required
            output_attentions (bool): Return attention weights
            past_key_value (Optional[Tuple[torch.Tensor, torch.Tensor]]): Cached states
            use_cache (bool): Cache new states
            attn_mask_start_row_indices (Optional[torch.Tensor]): Variable length indices

        Returns:
            Tuple containing:
                - attention_output: Result tensor
                - attention_weights: Optional weights
                - updated_key_value_cache: Optional cache
        """

        query_states_dtype = query_states.dtype

        assert position_ids is not None, "rope3d requires pos-id"
        # The sin/cos table must cover every position index that can be looked up.
        kv_seq_len = position_ids.max() + 1
        if past_key_value is not None:
            # Decoding with a cache: only the newest token is rotated.
            position_ids = position_ids[:, -1:, :]

        cos_sin = self.rotary_emb(kv_seq_len).permute([0, 2, 1, 3])
        query_states, key_states = self.rotary_emb.apply_rotary_3d(
            cos_sin, query_states, key_states, position_ids
        )

        # apply_rotary_3d computes in float32; restore the working dtype.
        query_states = query_states.to(query_states_dtype)
        key_states = key_states.to(query_states_dtype)
        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=1)
            value_states = torch.cat([past_key_value[1], value_states], dim=1)

        # shape: [2, b, s, kvh, d]
        past_key_value = [key_states, value_states] if use_cache else None
        seq_length = query_states.shape[1]
        attn_output, attn_weights = self.attn_func(
            query_states,
            key_states,
            value_states,
            attention_mask,
            attn_mask_start_row_indices,
            seq_length,
        )

        return attn_output, attn_weights, past_key_value


class FusedDropoutImpl(nn.Module):
    """
    Dropout followed by a residual addition.

    Args:
        prob (float): Dropout probability (between 0 and 1)
        mode (str): Dropout mode name; accepted by the constructor but not
            stored or used by this implementation.

    Attributes:
        prob (float): The dropout probability
        dropout (nn.Dropout): Dropout layer built with ``p=prob``
    """

    def __init__(self, prob, mode):
        """
        Create the dropout layer.

        Args:
            prob (float): Dropout probability (0 means no dropout)
            mode (str): Ignored.
        """
        super().__init__()
        self.prob = prob
        self.dropout = nn.Dropout(p=prob)

    def forward(self, x, y):
        """
        Return ``dropout(x) + y`` (dropout is skipped when ``prob`` is not positive).

        Args:
            x (Tensor): Tensor that dropout is applied to
            y (Tensor): Residual tensor added afterwards

        Returns:
            Tensor: Sum of the (possibly dropped-out) ``x`` and ``y``
        """
        branch = self.dropout(x) if self.prob > 0 else x
        return branch + y


class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization (RMSNorm).

    Scales each feature vector by the reciprocal of its root-mean-square
    (no mean-centering) and then by a learned per-feature weight.
    """

    def __init__(self, config):
        """
        Create the learnable scale and read the epsilon.

        Args:
            config (Ernie4_5_Config): Supplies ``hidden_size`` and ``rms_norm_eps``.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        initial_scale = torch.ones(self.hidden_size, dtype=torch.get_default_dtype())
        self.weight = nn.Parameter(initial_scale)
        self.variance_epsilon = config.rms_norm_eps

    def forward(self, hidden_states):
        """
        Normalize ``hidden_states`` over the last dimension.

        Args:
            hidden_states (Tensor): Input whose last dimension is ``hidden_size``

        Returns:
            Tensor: Normalized tensor of the same shape, cast to the dtype of
            ``self.weight`` before the final scaling.
        """
        # Mean of squares is accumulated in float32.
        mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalized = inv_rms * hidden_states
        return normalized.to(self.weight.dtype) * self.weight


class Ernie4_5_MoeMLP(Ernie4_5_MLP):
    """Mixture of Experts (MoE) variant of ERNIE's MLP layer.

    Same gated MLP as the parent class, with optional dropout applied to the
    gated activation before the down projection.
    """

    def __init__(self, config, layer_idx=0):
        """Initialize the MoE MLP layer.

        Args:
            config (Ernie4_5_MoEConfig): Configuration for MoE architecture.
            layer_idx (int): Index of current layer in transformer stack
        """

        # Work on a private copy so the caller's config is not mutated when
        # model parallelism is disabled for the expert FFN.
        if getattr(config, "disable_ffn_model_parallel", False):
            config = deepcopy(config)
            config.tensor_parallel_degree = 1

        super().__init__(config, layer_idx=layer_idx)
        self.moe_dropout_prob = config.moe_dropout_prob

    def forward(self, x):
        """Forward pass through MoE MLP layer.

        Args:
            x (torch.Tensor): Input tensor whose last dimension is hidden_size

        Returns:
            torch.Tensor: Output tensor with the same shape as the input
        """
        current_device = self.gate_proj.weight.data.device
        x = x.to(current_device)
        x = F.silu(self.gate_proj(x)) * self.up_proj(x)
        if self.moe_dropout_prob > 0:
            # F.dropout defaults to training=True; tie it to the module mode so
            # no activations are dropped during evaluation/inference.
            x = F.dropout(input=x, p=self.moe_dropout_prob, training=self.training)
        return self.down_proj(x)


def masked_fill(x, mask, value):
    """
    Return a tensor equal to ``x`` except where ``mask`` is True, where it holds ``value``.
    """
    replacement = torch.full_like(x, value)
    return torch.where(mask, replacement, x)


def _squared_l2_norm(x: torch.Tensor) -> torch.Tensor:
    """Return half of the sum of squared elements of ``x``, i.e. 0.5 * sum(x^2)."""
    squared = x * x
    return 0.5 * squared.sum()


@torch.no_grad()
def compute_optimal_transport(M, r, c, lam=1.0, epsilon=1e-8, max_iters: int = 10):
    """
    Computes optimal transport matrix and Sinkhorn distance using Sinkhorn-Knopp algorithm.

    Args:
        M (Tensor): 2-D cost matrix of shape [n, m].
        r (Tensor): Target row marginals, length n.
        c (Tensor): Target column marginals, length m.
        lam (float): Temperature dividing the negated costs before the
            initial row-wise softmax.
        epsilon (float): Stop once the row sums change by less than this
            between two consecutive iterations.
        max_iters (int): Upper bound on the number of scaling iterations.

    Returns:
        Tuple[Tensor, Tensor]: The transport matrix ``P`` (NaNs replaced by
        zeros) and the Sinkhorn distance ``sum(P * M)`` as a 0-dim tensor.
    """
    # Unpacking also enforces that M is 2-D.
    n, _num_cols = M.shape
    P = F.softmax(-M / lam, dim=1)  # each row is normalized across its columns
    u = torch.zeros(n, dtype=torch.float32, device=M.device)

    for _ in range(max_iters):
        row_sums = P.sum(1)
        if (u - row_sums).abs().max() < epsilon:
            break
        u = row_sums
        # Alternately rescale rows towards r and columns towards c; the small
        # constant guards against division by zero.
        P *= (r / (u + 1e-8)).unsqueeze(1)
        P *= (c / (P.sum(0) + 1e-8)).unsqueeze(0)

    P = torch.where(~P.isnan(), P, torch.zeros_like(P))
    # Previously the leaked loop variable was returned here instead of the
    # distance promised by the docstring.
    distance = torch.sum(P * M)
    return P, distance


class Top2Gate(nn.Module):
    """
    Gate module implementing Top2Gating as described in Gshard paper.

    The gate projects token features onto per-expert logits and, in
    ``top2_gating``, picks two experts per token subject to a per-expert
    capacity. When ``config.multimodel_experts`` is set, ``config.moe_num_experts``
    is a sequence (one expert count per group) and one gate weight is kept per
    group (``weight``, ``weight_1``, ...).
    """

    def __init__(self, config, layer_idx: int, group=None, gate_weight=None) -> None:
        """
        Initialize the gate.

        Args:
            config: Model configuration containing MoE parameters
            layer_idx: Index of this layer in the model
            group: Distributed communication group
            gate_weight: Optional pre-existing gate weight tensor
        """
        super().__init__()
        self.config = config

        self.model_dim = config.hidden_size
        # An int, or a sequence of per-group counts when multimodel_experts is set.
        self.num_experts = config.moe_num_experts
        self.num_experts_tensor = (
            sum(config.moe_num_experts)
            if config.multimodel_experts
            else config.moe_num_experts
        )

        self.cap = config.moe_capacity
        self.group = group

        self.layer_idx = layer_idx

        self.sinkhorn_2gate = config.sinkhorn_2gate
        self.sinkhorn_temp = config.sinkhorn_temp
        self.use_correction_bias = config.moe_use_aux_free  # true
        self.use_token_type_bias = config.get("moe_use_token_type_bias", False)

        self.act = partial(F.softmax, dim=-1)  # [S,E]

        self.no_jitter = True
        self.expert_drop = False
        self.eye_matrix = None
        self.eye_matrix_size = None
        self.norm_gate_logits = config.moe_norm_gate_logits  # true
        self.one = torch.ones([], dtype=torch.float32)

        self.moe_aux_loss_lambda = torch.tensor(config.moe_aux_loss_lambda).to(
            dtype=torch.float32
        )
        self.moe_z_loss_lambda = torch.tensor(config.moe_z_loss_lambda).to(
            dtype=torch.float32
        )
        self.moe_orthogonal_loss_lambda = torch.tensor(
            config.moe_orthogonal_loss_lambda
        ).to(dtype=torch.float32)

        # Scalars are promoted to 1-element vectors so they can later be
        # expanded to one value per expert group.
        if self.moe_aux_loss_lambda.ndim == 0:
            self.moe_aux_loss_lambda = self.moe_aux_loss_lambda.unsqueeze(0)
        if self.moe_z_loss_lambda.ndim == 0:
            self.moe_z_loss_lambda = self.moe_z_loss_lambda.unsqueeze(0)
        if self.moe_orthogonal_loss_lambda.ndim == 0:
            self.moe_orthogonal_loss_lambda = self.moe_orthogonal_loss_lambda.unsqueeze(
                0
            )

        self.experts_type_ids = None

        self.eps = torch.tensor([1e-12]).to(dtype=torch.float32)
        if config.multimodel_experts:
            if config.get("moe_use_hard_gate", False):
                self.num_experts_list = []
                self.experts_type_mask = []
                # hard-gate + group_experts requires computing the different
                # parts of gate_logits separately
                experts_ids = torch.zeros(
                    [sum(self.num_experts)], dtype=torch.int64
                ).reshape((1, -1))
                offset = 0
                for i, expert_num in enumerate(self.num_experts):
                    experts_ids[:, offset : offset + expert_num] = i
                    offset += expert_num
                self.experts_type_ids = experts_ids.reshape([-1])
                logger.info(
                    f"use moe_use_hard_gate, experts_ids: {self.experts_type_ids}"
                )
                for i, expert_num in enumerate(self.num_experts):
                    self.experts_type_mask.append(
                        self.experts_type_ids == i,
                    )
                    self.num_experts_list.append(expert_num)
            else:
                # Without group_experts, hard-gate behaviour relies on
                # token_type_bias.
                assert (
                    not config.moe_group_experts
                ), "group_experts must use hard_gate when multimodel_experts is True"
        else:
            self.num_experts_list = [self.num_experts]

        if gate_weight is not None:
            self.weight = gate_weight

            # Use the value already resolved with a default above; the raw
            # config attribute is optional and may be absent.
            assert (
                not self.use_token_type_bias
            ), "gate_weights is from outside, token_type_bias can't be used"
            logger.info("moe use gate_weight from outside")
            # use fp32 precision in amp
            self._cast_to_low_precision = False
            # Misspelled twin of the flag above, kept for backward compatibility.
            self._cast_to_low_precison = False
        else:
            self._create_gate_parameter()
        logger.info(
            f"{config.moe_gate}: w/ capacity: {self.cap} experts:{self.num_experts} "
            f"use_token_type_bias:{self.use_token_type_bias} "
            f"gate_act:{config.moe_gate_act} "
            f"norm_gate_logits={self.norm_gate_logits} use_correction_bias={self.use_correction_bias}"
        )

    def _create_gate_parameter(self):
        """
        Create gate weight parameter.

        With ``multimodel_experts`` one fp32 parameter of shape
        [model_dim, num_experts_of_group] is registered per expert group
        (``weight`` for group 0, ``weight_{i}`` otherwise) and the loss
        lambdas are expanded to one value per group. Otherwise a single
        ``weight`` of shape [model_dim, num_experts] is created.
        """
        if self.config.multimodel_experts:
            # support setting lambda for each expert group
            self.moe_z_loss_lambda = self.moe_z_loss_lambda.expand(
                len(self.num_experts)
            )
            self.moe_aux_loss_lambda = self.moe_aux_loss_lambda.expand(
                len(self.num_experts)
            )
            self.moe_orthogonal_loss_lambda = self.moe_orthogonal_loss_lambda.expand(
                len(self.num_experts)
            )

            for i, num_experts in enumerate(self.num_experts):
                if i == 1:
                    with UniqueNameGuard(f"mm_gate_{self.layer_idx}_"):
                        p = nn.Parameter(
                            torch.empty(
                                self.model_dim,
                                num_experts,
                                dtype=torch.float32,
                                device="cpu",
                            )
                        )
                        nn.init.xavier_uniform_(p)  # Common initialization
                else:
                    p = nn.Parameter(
                        torch.empty(
                            self.model_dim,
                            num_experts,
                            dtype=torch.float32,
                            device="cpu",
                        )
                    )
                    nn.init.xavier_uniform_(p)  # Common initialization
                self.register_parameter(
                    "weight" if i == 0 else f"weight_{i}",
                    p,
                )
        else:
            self.weight = nn.Parameter(
                torch.empty(self.model_dim, self.num_experts, dtype=torch.float32)
            )
            nn.init.xavier_uniform_(self.weight)  # Common initialization
        # use fp32 precision in amp
        self._cast_to_low_precision = False
        # Misspelled twin of the flag above, kept for backward compatibility.
        self._cast_to_low_precison = False

    def get_gate_weight(self, transform_weight, is_multimodel=True):
        """
        Return the gate weight used to compute the logits.

        When ``multimodel_experts`` is enabled (and ``is_multimodel`` is True)
        the per-group weights are concatenated along the last dimension into
        one matrix; otherwise ``self.weight`` is returned unchanged.

        Args:
            transform_weight: bool, documented as interleaving the multimodal
                weights by local-expert id. It is not used by this
                implementation.
            is_multimodel: Whether to merge the weights of all expert groups.
        """
        if not is_multimodel or not self.config.multimodel_experts:
            return self.weight
        else:
            return torch.cat(
                [
                    getattr(self, "weight" if i == 0 else f"weight_{i}")
                    for i in range(len(self.num_experts))
                ],
                -1,
            )

    def forward(
        self,
        input: torch.Tensor,
        token_type_ids: torch.Tensor = None,
        transform_weight: bool = True,
        correction_bias: torch.Tensor = None,
    ) -> Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor, None, torch.Tensor]:
        """
        Forward pass through the gate.

        Args:
            input: Input tensor of shape [Seq, Dim]
            token_type_ids: Token type IDs tensor of shape [Seq]
                (not used by this implementation)
            transform_weight: Whether to transform weights for multimodal experts
            correction_bias: Bias tensor for correction

        Returns:
            tuple: (capacity, dispatch_mask, combine_weights, scatter_index, router_loss, logits)
            ``router_loss`` is always None here; ``combine_weights`` is cast
            back to the dtype of ``input``.
        """
        orig_dtype = input.dtype
        current_device = input.device
        weight = self.get_gate_weight(transform_weight)

        # The projection is always done in fp32. F.linear expects
        # [out_features, in_features], hence the transpose of the
        # [model_dim, num_experts] gate weight.
        logits = F.linear(
            input.to(dtype=torch.float32, device=current_device),
            weight.T.to(dtype=torch.float32, device=current_device),
        )

        (
            capacity,
            dispatch_mask,
            combine_weights,
            scatter_index,
            _l_aux,
            _l_zloss,
        ) = self.top2_gating(
            logits,
            correction_bias=(
                correction_bias.to(device=current_device)
                if correction_bias is not None
                else None
            ),
        )

        combine_weights = combine_weights.to(orig_dtype)
        return capacity, dispatch_mask, combine_weights, scatter_index, None, logits

    def get_capacity(self, num_tokens, cap_factor=None, is_multimodel=True):
        """
        Calculate capacity based on number of tokens.

        Args:
            num_tokens: Number of input tokens
            cap_factor: Optional capacity factor override
            is_multimodel: Whether to count the experts of all groups (only
                relevant when ``multimodel_experts`` is enabled)

        Returns:
            int: Calculated capacity, ``int(cap * num_tokens // num_experts)``
        """
        if is_multimodel and self.config.multimodel_experts:
            # Sum the configured per-group counts directly: num_experts_list is
            # only populated on the hard-gate path and holds the same values.
            num_experts = sum(self.num_experts)
        elif isinstance(self.num_experts, (list, tuple)):
            num_experts = self.num_experts[0]
        else:
            num_experts = self.num_experts
        if cap_factor is not None:
            cap = cap_factor
        else:
            # self.cap holds three factors: training, evaluation, and
            # evaluation with fewer tokens than experts.
            if self.training:
                cap = self.cap[0]
            elif num_tokens < num_experts:  # seqlen < num_expert
                cap = self.cap[2]
            else:
                cap = self.cap[1]
        # capacity = 2S/E
        capacity = int(cap * num_tokens // num_experts)
        assert (
            capacity > 0
        ), f"requires capacity to be > 0. cap={cap}, num_tokens={num_tokens}"
        return capacity

    def top2_gating(self, logits, cap=None, correction_bias=None):
        """
        Implement Top2 gating mechanism.

        Args:
            logits: Input logits tensor of shape [Seq, Experts]
            cap: Optional capacity override
            correction_bias: Bias tensor for correction; it only influences
                which experts are selected, not the combine weights

        Returns:
            tuple: (capacity, dispatch_masks, combine_weights, scatter_indexes, loss_aux, loss_z)

        Note:
        capacity: The maximum number of tokens each expert can receive.
        dispatch_masks: Bool tensor [Seq, 2]; column 0 marks tokens kept by their
        first-choice expert, column 1 tokens kept by their second-choice expert.
        combine_weights: Tensor [Seq, 2] with the gate weight of the first and
        second choice of every token.
        scatter_indexes: Tensor [Seq, 2] with ``expert * capacity + slot`` for the
        first and second choice of every token.
        loss_aux: Auxiliary loss (always None in this implementation).
        loss_z: Z loss (always None in this implementation).
        """
        gates = self.act(logits)

        # gates has shape of SE
        assert logits.ndim == 2, logits.shape
        num_tokens = gates.shape[0]
        num_experts = gates.shape[1]
        # capacity = 2S/E
        capacity = self.get_capacity(logits.shape[0], cap)
        current_device = logits.device

        # Create a mask for 1st's expert per token
        score_for_argmax = (
            gates + correction_bias.unsqueeze(0)
            if correction_bias is not None
            else gates
        )
        indices1_s = torch.argmax(score_for_argmax, dim=1)
        mask1 = F.one_hot(indices1_s, num_classes=num_experts).to(
            dtype=torch.int64, device=current_device
        )  # [0,1]
        # Boolean view of the untruncated first-choice mask, used to exclude
        # the first choice when picking the second one.
        mask1_bool = mask1.to(dtype=torch.bool, device=current_device)

        # Create a mask for 2nd's expert per token using Gumbel-max trick
        # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
        if self.training and not self.no_jitter:
            gumbels = (
                -torch.empty_like(
                    logits,
                    device=current_device,
                )
                .exponential_()
                .log()
            )  # ~Gumbel(0,1)
            logits_w_noise = logits + gumbels
        else:
            logits_w_noise = logits

        logits_except1 = logits_w_noise.masked_fill(mask1_bool, float("-inf"))
        # NOTE(review): with a correction bias the masked first choice has
        # probability 0 after softmax but still receives its bias, so a large
        # enough bias could select the same expert again -- confirm intended.
        score_for_argmax = (
            self.act(logits_except1) + correction_bias.unsqueeze(0)
            if correction_bias is not None
            else logits_except1
        )
        indices2_s_original = torch.argmax(score_for_argmax, dim=1)

        if self.training and self.sinkhorn_2gate:
            # Balance the second choice with optimal transport: every token
            # carries equal mass and every expert receives mass proportional
            # to the capacity left after the first choice.
            r = (
                torch.ones(num_tokens, dtype=torch.float32, device=current_device)
                / num_tokens
            )
            c_mask_sum = mask1.to(dtype=torch.float32, device=current_device).sum(0)
            c = capacity - c_mask_sum
            c = torch.maximum(c, torch.zeros_like(c, device=current_device))
            c_sum = c.sum()
            if c_sum > 0:
                c = c / c_sum
            else:  # Avoid division by zero if all experts are full from top-1
                c = torch.ones_like(c, device=current_device) / num_experts

            pi, _ = compute_optimal_transport(
                -logits_except1.to(dtype=torch.float32, device=current_device).detach(),
                r,
                c,
                lam=self.sinkhorn_temp,
            )
            pi = pi.masked_fill(mask1_bool, float("-inf"))
            indices2_s = torch.argmax(pi, dim=1)
        else:
            indices2_s = indices2_s_original

        # Use the width of the logits: self.num_experts is a list when
        # multimodel_experts is enabled and is not a valid num_classes.
        mask2 = F.one_hot(indices2_s, num_classes=num_experts).to(
            dtype=torch.int64, device=current_device
        )

        # Compute locations in capacity buffer
        locations1 = (
            torch.cumsum(mask1, dim=0) - 1
        )  # [0,1,1,0,1,0,0] -> [0,0,0,0,1,1,1,]
        locations2 = torch.cumsum(mask2, dim=0) - 1
        # Update 2nd's location by accounting for locations of 1st
        locations2 += torch.sum(mask1, dim=0, keepdim=True)

        # Remove locations outside capacity from mask
        mask1 = mask1 * (locations1 < capacity).to(
            dtype=torch.int64, device=current_device
        )  # [0,1,1,0,0,0,0]
        mask2 = mask2 * (locations2 < capacity).to(
            dtype=torch.int64, device=current_device
        )

        # Store the capacity location for each token
        locations1_s = torch.sum(locations1 * mask1, dim=1)
        locations2_s = torch.sum(locations2 * mask2, dim=1)

        # Normalize gate probabilities
        mask1_float = mask1.to(dtype=torch.float32, device=current_device)
        mask2_float = mask2.to(dtype=torch.float32, device=current_device)
        gates1_s = (gates * mask1_float).sum(dim=-1)
        gates2_s = (gates * mask2_float).sum(dim=-1)

        if self.norm_gate_logits:
            denom_s = gates1_s + gates2_s  # [0.2, 0.3]
            # Avoid divide-by-zero
            denom_s = torch.clamp(denom_s, min=1e-6)
            gates1_s /= denom_s
            gates2_s /= denom_s
        if self.training and self.expert_drop:
            # Randomly drop the second expert, more often when its weight is small.
            gates2_s = torch.where(
                2 * gates2_s < torch.rand_like(gates2_s, device=current_device),
                torch.zeros_like(gates2_s, device=current_device),
                gates2_s,
            )

        # Calculate combine_weights and dispatch_mask
        gates1 = gates1_s.unsqueeze(1) * mask1_float
        gates2 = gates2_s.unsqueeze(1) * mask2_float

        # Every row of gates1/gates2 has at most one non-zero entry, so the
        # row maximum yields both the weight and the chosen expert id.
        combine1_weight, expert1_index = torch.max(gates1, dim=-1, keepdim=True)
        scatter1_index = expert1_index.squeeze(-1) * capacity + locations1_s
        scatter1_index = scatter1_index.to(dtype=torch.int64, device=current_device)
        dispatch1_mask = combine1_weight.to(
            dtype=torch.bool, device=current_device
        ).detach()

        combine2_weight, expert2_index = torch.max(gates2, dim=-1, keepdim=True)
        scatter2_index = expert2_index.squeeze(-1) * capacity + locations2_s
        scatter2_index = scatter2_index.to(dtype=torch.int64, device=current_device)
        dispatch2_mask = combine2_weight.to(
            dtype=torch.bool, device=current_device
        ).detach()

        return (
            capacity,
            torch.cat((dispatch1_mask, dispatch2_mask), 1),
            torch.cat((combine1_weight, combine2_weight), 1),
            torch.stack((scatter1_index, scatter2_index), 1),
            None,
            None,
        )

    def _cal_orthogonal_loss_opt_each_weight(self, weight, use_group):
        """
        Calculate optimized orthogonal loss for each weight.

        Rows of ``weight`` are L2-normalized, then the squared deviation of
        their Gram matrix from the identity is averaged over the first two
        dimensions of that matrix.

        Args:
            weight: 2-D weight tensor
            use_group: Whether to use expert groups; if set, the rows are split
                into ``config.moe_k`` groups and the Gram matrix is computed
                per group

        Returns:
            Tensor: Calculated orthogonal loss
        """
        if weight.dtype != torch.float32:
            weight = weight.to(torch.float32)

        wnorm = torch.norm(weight, p=2, dim=1)
        weight = weight / torch.maximum(wnorm, self.eps.to(weight.device)).unsqueeze(1)

        if use_group:
            weight = weight.reshape(
                [self.config.moe_k, -1, weight.shape[1]]
            )  # [K, E/K, H]
            eye_matrix = torch.eye(
                weight.shape[1], dtype=weight.dtype, device=weight.device
            ).unsqueeze(0)
        else:
            eye_matrix = torch.eye(
                weight.shape[0], dtype=weight.dtype, device=weight.device
            )

        # Transpose only the last two dims: identical to .T for the 2-D case
        # and a proper batched Gram matrix for the grouped 3-D case, where .T
        # would reverse all dimensions.
        weight_matmul = torch.matmul(weight, weight.transpose(-1, -2))

        orthogonal_loss = weight_matmul - eye_matrix
        orthogonal_loss = _squared_l2_norm(orthogonal_loss) / (
            orthogonal_loss.size(0) * orthogonal_loss.size(1)
        )
        return orthogonal_loss


class TopKGate(Top2Gate):
    """
    Fused version of TopK gate for improved performance.

    Unlike ``Top2Gate.forward`` this gate only produces the routing logits;
    expert selection and dispatching are left to the caller.
    """

    def forward(
        self,
        input: torch.Tensor,
        token_type_ids=None,
        transform_weight=True,
        is_multimodel=True,
    ) -> torch.Tensor:
        """
        Forward pass for fused gate.

        Args:
            input: Input tensor
            token_type_ids: Token type IDs; required when
                ``use_token_type_bias`` is enabled, where they index ``self.bias``
            transform_weight: Whether to transform weights
            is_multimodel: Whether to use the merged weights of all expert
                groups (forwarded to ``get_gate_weight``)

        Returns:
            torch.Tensor: fp32 gate logits
        """
        current_device = input.device
        weight = self.get_gate_weight(transform_weight, is_multimodel=is_multimodel)

        # The projection is always done in fp32.
        logits = F.linear(
            input.to(dtype=torch.float32, device=current_device),
            weight.T.to(dtype=torch.float32, device=current_device),
        )
        if self.use_token_type_bias:
            # NOTE(review): self.bias is not created anywhere in the visible
            # part of this class hierarchy -- confirm where it is defined.
            assert token_type_ids is not None
            assert (
                token_type_ids.max() < self.bias.shape[0]
            ), f"token_type_ids {token_type_ids.max()} >= bias shape {self.bias.shape[0]}"
            bias = self.bias[token_type_ids]  # one bias row per token
            logits = logits + bias

        return logits


# Registry mapping the lower-cased gate name to its gate implementation.
gate_class = {
    "top2": Top2Gate,
    "topk": TopKGate,
}


def get_gate(
    config: Ernie4_5_MoEConfig,
    expert: nn.Module,
    layer_idx: int,
) -> Tuple[nn.Module, nn.ModuleList, Optional[nn.Module], Optional[nn.ModuleList]]:
    """Initialize and distribute MoE (Mixture of Experts) components.

    Creates gate layer and distributed expert network for MoE architecture.

    Args:
        config (Ernie4_5_MoEConfig): Configuration for MoE architecture
        expert: Iterable of ``(experts_num, fc)`` pairs, one per expert group.
            ``fc`` is either a single prototype expert, which is deep-copied
            until the group holds ``experts_num`` experts, or a sequence of
            already built experts.
        layer_idx (int): Index of current layer in transformer stack

    Returns:
        tuple: ``(gate, experts, lm_gate, lm_experts)``
            - gate: Initialized gate layer for routing
            - experts: ModuleList containing the experts of all groups
            - lm_gate: ``gate`` when multimodal experts with hard gating are
              configured, otherwise None
            - lm_experts: With multimodal hard gating, the first
              ``config.moe_num_experts[0]`` experts if there are more than two
              experts in total, else all experts; otherwise None
    """
    moe_num_experts = (
        sum(config.moe_num_experts)
        if config.multimodel_experts
        else config.moe_num_experts
    )
    experts = nn.ModuleList([])

    for expert_id, (experts_num, fc) in enumerate(expert):
        if hasattr(fc, "__len__"):
            experts_to_append = fc
        else:
            # Single prototype: replicate it to fill the whole group.
            experts_to_append = [fc]
            if expert_id == 1:
                with UniqueNameGuard("_mm_deepcopy"):
                    for _ in range(experts_num - 1):
                        experts_to_append.append(deepcopy(fc))
            else:
                for _ in range(experts_num - 1):
                    experts_to_append.append(deepcopy(fc))
        for ex in experts_to_append:
            for p in ex.parameters():
                # Different `expert_type` can have different intermediate-size
                p.expert_type = f"expert_type_{expert_id}"
        # Every expert of the group is local; the previous placeholder branch
        # (`i // experts_num != 0`) could never be taken.
        for i in range(experts_num):
            experts.append(experts_to_append[i])

    assert (
        len(experts) == moe_num_experts
    ), f"experts.len={len(experts)} != moe_num_experts={moe_num_experts}"
    logger.info(f"MOE-GATE:-{config.moe_gate}")

    gate = gate_class[config.moe_gate.lower()](config, layer_idx=layer_idx)

    if config.multimodel_experts and config.moe_use_hard_gate:
        lm_gate = gate
        if moe_num_experts > 2:
            lm_experts = experts[: config.moe_num_experts[0]]
        else:
            lm_experts = experts
    else:
        lm_gate, lm_experts = None, None

    logger.info(f"LM-experts-{lm_experts} -- experts-{experts}")

    return gate, experts, lm_gate, lm_experts


class MoEStatics(nn.Module):
    """
    Stores MoE (Mixture of Experts) statistics
    and expert usage information.
    """

    def __init__(self, config, layer_idx):
        """
        Initialize MoE statistics tracking.

        Creates ``e_score_correction_bias``, a frozen fp32 parameter of shape
        [num_expert_groups, num_experts_per_group], and ``expert_usage``, a
        plain int64 tensor of the same shape. Both start at zero.

        Args:
            config: Model configuration containing MoE parameters
            layer_idx: Index of the MoE layer in the model

        Raises:
            AssertionError: If multimodal expert groups differ in size.
        """
        super().__init__()
        self._cast_to_low_precision = False
        # Misspelled twin of the flag above, kept for backward compatibility.
        self._cast_to_low_precison = False
        num_experts = (
            config.moe_num_experts[0]
            if config.multimodel_experts
            else config.moe_num_experts
        )
        if config.multimodel_experts:
            # The statistics are stored as one rectangular table, so every
            # group needs the same number of experts.
            assert (
                len(set(config.moe_num_experts)) == 1
            ), f"assume expert group has same size, got: {config.moe_num_experts}"

        with UniqueNameGuard(f"mm_layer_{layer_idx}_"):
            num_experts_groups = (
                len(config.moe_num_experts) if config.multimodel_experts else 1
            )
            p = nn.Parameter(
                torch.zeros(num_experts_groups, num_experts, dtype=torch.float32),
                requires_grad=False,
            )
            self.e_score_correction_bias = p
            # Plain tensor attribute: not registered as a parameter or buffer.
            p = torch.zeros(num_experts_groups, num_experts, dtype=torch.int64)
            self.expert_usage = p


def dispatching(x, dispatch_mask, scatter_index, num_experts, capacity):
    """
    Scatter tokens into the flat per-expert capacity buffer.

    For each of the two routing choices, tokens whose dispatch mask is set
    are added into row ``scatter_index[token, choice]`` of a zero buffer;
    masked-out tokens contribute zeros. The two buffers are summed.

    Args:
        x (Tensor): Input tensor of shape [Seq, Dim]
        dispatch_mask (Tensor): Dispatching mask of shape [Seq, 2]
        scatter_index (Tensor): Scatter indices of shape [Seq, 2]
        num_experts (int): Number of experts
        capacity (int): Capacity per expert

    Returns:
        Tensor: Dispatched output tensor of shape [Expert*Capacity, Dim]
    """
    orig_dtype = x.dtype
    hidden = x.shape[-1]
    routed = []
    for route in (0, 1):
        keep = dispatch_mask[:, route].unsqueeze(-1).to(orig_dtype)  # [seq, 1]
        updates = x * keep  # [seq, dim]
        buffer = torch.zeros(
            num_experts * capacity, hidden, dtype=orig_dtype, device=x.device
        )
        # scatter_add needs one target row index per element of `updates`.
        index = scatter_index[:, route].unsqueeze(-1).expand(-1, hidden)
        routed.append(buffer.scatter_add(0, index, updates))

    output = routed[0] + routed[1]
    return output if output.dtype == orig_dtype else output.to(orig_dtype)


def combining(x, combine_weights, scatter_index):
    """
    Gather each token's expert outputs and mix them with its combine weights.

    Args:
        x (Tensor): Input tensor of shape [num_experts * capacity, dim]
        combine_weights (Tensor): Combination weights of shape [seq, 2]
        scatter_index (Tensor): Scatter indices of shape [seq, 2]

    Returns:
        Tensor: Combined output tensor of shape [seq, dim]
    """
    hidden = x.shape[-1]
    # All operands are moved to the device that holds the scatter indices.
    device = scatter_index.device
    top_k = combine_weights.shape[-1]

    flat_index = scatter_index.reshape([-1])
    gathered = x.to(device)[flat_index].reshape([-1, top_k, hidden])  # [seq, k, dim]
    weights = combine_weights.unsqueeze(1).to(device)  # [seq, 1, k]

    # [seq, 1, k] @ [seq, k, dim] -> [seq, 1, dim]
    mixed = torch.matmul(weights, gathered)
    return mixed.squeeze(1)


class MOELayer(nn.Module):
    """
    Mixture of Experts layer implementation based on GShard paper.
    """

    def __init__(
        self,
        gate: nn.Module,
        experts: List[nn.Module],
        layer_idx: int,
        shared_experts: Optional[List[nn.Module]] = None,
        group=None,
        recompute: bool = False,
        k: int = 2,
        all_to_all_dropout: float = 0,
        group_experts: bool = False,
        moe_statics=None,
        moe_num_experts=None,
    ):
        """
        Set up the MoE layer from an already built gate and expert set.

        Args:
            gate: Gate network for expert selection; its ``config`` attribute
                is reused as this layer's config
            experts: Expert networks, either an ``nn.ModuleList`` or a single
                fused expert module
            layer_idx: Index of this layer in the model
            shared_experts: Optional shared experts, stored as given
            group: Distributed communication group
            recompute: Whether to enable recomputation (not used here)
            k: Number of experts to select per token
            all_to_all_dropout: Dropout rate for all-to-all communication
            group_experts: Whether to group experts
            moe_statics: MoE statistics tracking object; when given, its
                correction bias is used for expert selection
            moe_num_experts: Expert count, or a list/tuple of per-group counts;
                more than one entry enables the multimodal expert layout
        """
        super().__init__()
        self.gate = gate
        self.layer_idx = layer_idx

        # Experts are stored as given; anything that is not a ModuleList is
        # treated as a fused expert implementation and only logged.
        if not isinstance(experts, nn.ModuleList):
            logger.info(f"using fused experts, type={type(experts)}")
        self.experts = experts
        self.shared_experts = shared_experts

        self.group = group
        self.k = k
        self.all_to_all_dropout = all_to_all_dropout
        self.use_correction_bias = moe_statics is not None
        self.moe_statics = moe_statics
        if self.use_correction_bias:
            logger.info(
                f"using correction bias, aux-coef:{self.gate.config.moe_aux_loss_lambda}"
            )
            assert self.gate.config.moe_use_aux_free

        # Single-process layout: every expert is local to rank 0.
        self.world_size = 1
        self.rank = 0

        has_expert_groups = isinstance(moe_num_experts, (tuple, list))
        self.multimodal_experts = has_expert_groups and len(moe_num_experts) > 1
        self.num_local_experts = len(self.experts) // self.world_size
        if self.multimodal_experts:
            self.num_local_multimodal_experts = [
                group_size // self.world_size for group_size in moe_num_experts
            ]
            # Prefix sums: group i owns experts [index[i], index[i + 1]).
            self.multimodal_expert_index = [0]
            self.multimodal_expert_index += list(itertools.accumulate(moe_num_experts))

        self.input_preprocess = None
        self.output_postprocess = None
        self.group_experts = group_experts
        self.config = self.gate.config
        self.zero = torch.tensor(0).float()

    def forward_experts(self, dispatched_input):
        """
        Run every local expert on its slice of the dispatched tokens.

        Args:
            dispatched_input: Dispatched tokens; reshaped internally to
                [world_size, num_local_experts, capacity, dim]

        Returns:
            Tensor: Expert outputs stacked along dim 1, i.e.
            [world_size, num_local_experts, capacity, out_dim]
        """
        # Select the experts owned by this rank.
        if self.multimodal_experts:
            true_experts = []
            bounds = self.multimodal_expert_index
            for modal, per_rank in enumerate(self.num_local_multimodal_experts):
                modal_experts = self.experts[bounds[modal] : bounds[modal + 1]]
                first = self.rank * per_rank
                true_experts.extend(modal_experts[first : first + per_rank])
        else:
            first = self.rank * self.num_local_experts
            true_experts = self.experts[first : first + self.num_local_experts]

        hidden = dispatched_input.shape[-1]
        dispatched_input = dispatched_input.reshape(
            [self.world_size, self.num_local_experts, -1, hidden]
        )
        current_device = dispatched_input.device
        # Put the expert axis first so that it can be split per expert.
        per_expert = dispatched_input.permute(1, 0, 2, 3).contiguous()

        if isinstance(self.experts, nn.ModuleList):
            chunks = per_expert.unbind(0)
            assert len(chunks) == len(
                true_experts
            ), f"{len(chunks)}, {len(true_experts)}"
            expert_outputs = [
                expert(chunk) for chunk, expert in zip(chunks, true_experts)
            ]
        else:
            # Fused experts consume all chunks in a single call.
            orig_shape = per_expert.shape
            fused_out = self.experts(
                per_expert.reshape(orig_shape[0], -1, orig_shape[-1])
            )
            fused_out = fused_out.reshape(orig_shape[:-1] + (fused_out.shape[-1],))
            expert_outputs = list(fused_out.unbind(0))

        # Experts may live on other devices; bring the results back.
        expert_outputs = [out.to(current_device) for out in expert_outputs]
        return torch.stack(expert_outputs, dim=1)

    def moe_gate_dispatch(
        self,
        x: torch.Tensor,  # [S, H]   float16 / float32 / bfloat16
        gate_logits: torch.Tensor,  # [S, E]   float32
        k: int,
        capacity: Optional[int],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Dispatch input to experts based on gate scores.

        Every token picks its top-``k`` experts. Routes are served in
        token-major order (token 0 / choice 0, token 0 / choice 1, token 1 /
        choice 0, ...); once an expert has received ``capacity`` tokens, later
        routes to it are dropped.

        Args:
            x: Token features of shape [S, H].
            gate_logits: Per-expert scores of shape [S, E]. When a correction
                bias is in use, it only affects which experts are selected;
                the returned weights are the unbiased scores.
            k: Number of experts selected per token.
            capacity: Maximum number of tokens each expert accepts.

        Returns:
            tuple: ``(y, combine_weights, scatter_index, expert_offset, expert_id)``
                - y: [E, capacity, H] buffer holding the dispatched tokens,
                  zero where a slot is unused
                - combine_weights: [S, k] selected scores, 0 for dropped routes
                - scatter_index: [k, S] int32 flat slot ``expert * capacity + slot``,
                  -1 for dropped routes
                - expert_offset: [E] int64 cumulative number of tokens accepted
                  per expert
                - expert_id: [S, k] selected expert ids
        """
        S, H = x.shape
        E = gate_logits.shape[1]
        device = x.device
        if self.use_correction_bias:
            bias = (
                self.moe_statics.e_score_correction_bias[0]
                .detach()
                .to(gate_logits.device)
            )
            _, topk_idx = torch.topk(gate_logits + bias, k, dim=-1)
            topk_prob = torch.gather(gate_logits, dim=1, index=topk_idx)  # [S, k]
        else:
            topk_prob, topk_idx = torch.topk(gate_logits, k, dim=-1)  # [S, k]
        expert_id = topk_idx  # [S, k]

        # Vectorized slot assignment (no per-token host synchronization).
        # A route's slot is the number of earlier routes, in token-major order,
        # that target the same expert. A stable sort groups the routes by
        # expert while keeping that order inside every group.
        flat_expert = expert_id.reshape(-1).to(device)  # [S * k]
        sorted_expert, order = torch.sort(flat_expert, stable=True)
        counts = torch.bincount(flat_expert, minlength=E)  # routes per expert
        group_start = torch.cumsum(counts, 0) - counts
        rank_in_group = (
            torch.arange(flat_expert.numel(), device=device)
            - group_start[sorted_expert]
        )
        slot = torch.empty_like(flat_expert)
        slot[order] = rank_in_group

        keep = slot < capacity  # routes beyond the capacity are dropped
        flat_index = flat_expert * capacity + slot

        combine_weights = topk_prob.masked_fill(
            ~keep.reshape(S, k).to(topk_prob.device), 0.0
        )  # [S, k]
        scatter_index = (
            torch.where(keep, flat_index, torch.full_like(flat_index, -1))
            .to(torch.int32)
            .reshape(S, k)
            .t()
            .contiguous()
        )  # [k, S]

        # Copy every kept route's token into its (expert, slot) row.
        token_of_route = torch.arange(S, device=device).repeat_interleave(k)
        y_flat = x.new_zeros((E * capacity, H))
        y_flat[flat_index[keep]] = x[token_of_route[keep]]
        y = y_flat.reshape(E, capacity, H)  # [E, C, H]

        # Each expert accepts at most `capacity` tokens.
        slot_counter = torch.clamp(counts, max=capacity)
        expert_offset = torch.cumsum(slot_counter, 0, dtype=torch.int64)

        return y, combine_weights, scatter_index, expert_offset, expert_id

    def gate_and_dispatch(self, input, token_type_ids=None, is_multimodel=True):
        """
        Calculate gate and dispatch inputs.

        With a ``TopKGate`` the gate scores are computed here and tokens are
        placed into expert buffers by ``moe_gate_dispatch``. Any other gate is
        expected to return the routing information itself, which is then handed
        to ``dispatching``.

        Args:
            input: Input tensor of shape [seq, dim]
            token_type_ids: Optional token types, flattened and forwarded to a
                ``TopKGate``; ignored for other gates
            is_multimodel: Forwarded to the ``TopKGate`` capacity and scoring calls

        Returns:
            tuple: (dispatched_input, combine_weights, dispatch_mask,
            scatter_index, router_loss, gate_logits, gate_prob)
            ``router_loss`` is always None; ``gate_prob`` is None unless the gate
            is a ``TopKGate``.
        """
        d_model = input.shape[1]
        if isinstance(self.gate, TopKGate):
            capacity = self.gate.get_capacity(
                input.shape[0], is_multimodel=is_multimodel
            )
            if token_type_ids is not None:
                token_type_ids = token_type_ids.reshape([-1])
            gate_logits = self.gate(
                input, token_type_ids=token_type_ids, is_multimodel=is_multimodel
            )
            prob = self.gate.act(gate_logits)
            (
                dispatched_input,
                combine_weights_unnorm,
                scatter_index,
                expert_offset,
                _,
            ) = self.moe_gate_dispatch(input, prob, k=self.k, capacity=capacity)
            # Cumulative offsets -> number of tokens accepted by each expert.
            dispatch_mask = torch.diff(F.pad(expert_offset, (1, 0)))

            scatter_index = scatter_index.transpose(0, 1)  # [k, s] -> [s, k]
            # Renormalise the kept weights; the clamp guards tokens whose routes
            # were all dropped (weight sum of zero).
            combine_weights = combine_weights_unnorm / torch.clamp(
                combine_weights_unnorm.sum(dim=-1, keepdim=True), min=1e-12
            )
            combine_weights = combine_weights.to(dtype=dispatched_input.dtype)
        else:
            # The gate's own router loss (5th element) is not propagated.
            (
                capacity,
                dispatch_mask,
                combine_weights,
                scatter_index,
                _,
                gate_logits,
            ) = self.gate(
                input,
            )
            prob = None
            dispatched_input = dispatching(
                input,
                dispatch_mask,
                scatter_index,
                num_experts=self.world_size * self.num_local_experts,
                capacity=capacity,
            )

        dispatched_input = dispatched_input.reshape(
            [self.world_size * self.num_local_experts, capacity, d_model]
        )

        # Tensor.detach() is not in-place: the result has to be rebound.
        dispatch_mask = dispatch_mask.detach()
        scatter_index = scatter_index.detach()
        return (
            dispatched_input,
            combine_weights,
            dispatch_mask,
            scatter_index,
            None,
            gate_logits,
            prob,
        )

    def combine_expert_output(self, expert_output, combine_weights, scatter_index):
        """
        Merge per-expert outputs back into one row per token.

        Args:
            expert_output: Expert outputs; flattened to [rows, dim] before merging
            combine_weights: Per-token weights passed to ``combining``
            scatter_index: Per-token row indices passed to ``combining``

        Returns:
            Tensor: Result of ``combining``, passed through
            ``self.output_postprocess`` when that hook is set.
        """
        hidden_dim = expert_output.shape[-1]
        flat_output = expert_output.reshape([-1, hidden_dim])  # [e*c, m]

        merged = combining(flat_output, combine_weights, scatter_index)

        postprocess = self.output_postprocess
        if postprocess is None:
            return merged
        return postprocess(merged)

    def forward(
        self,
        input: torch.Tensor,
        token_type_ids=None,
        is_multimodel=True,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Run the MoE layer: gate, dispatch, experts, combine.

        Args:
            input: Input tensor of shape [s, d]; a 3-D input is flattened over
                its leading dimensions and the output is reshaped back
            token_type_ids: Optional 2-D token types; the last column is dropped
                before they are handed to the gate
            is_multimodel: Forwarded to ``gate_and_dispatch``

        Returns:
            tuple: (output, combine_weights, router_loss, gate_logits), where
            ``router_loss`` is always None
        """
        orig_shape = None
        if input.dim() == 3:
            orig_shape = input.shape
            input = input.reshape([-1, input.shape[-1]])
        assert (
            input.dim() == 2
        ), f"input Tensor must have dimensions: (s)equence, (d)im, got:{input.shape}"

        if token_type_ids is not None:
            token_type_ids = token_type_ids.clone()[:, :-1]

        assert self.gate is not None

        dispatch_result = self.gate_and_dispatch(
            input, token_type_ids, is_multimodel=is_multimodel
        )
        # Only four of the seven returned values are needed here.
        dispatched_input, combine_weights, _, scatter_index, _, gate_logits, _ = (
            dispatch_result
        )

        use_shared = self.shared_experts is not None
        if use_shared:
            # Shared experts see every token, independent of the routing.
            shared_out = self.shared_experts(input)

        expert_out = self.forward_experts(dispatched_input)
        combined_output = self.combine_expert_output(
            expert_out, combine_weights, scatter_index
        )

        if use_shared:
            combined_output += shared_out

        if orig_shape is not None:
            restored_shape = orig_shape[:-1] + (combined_output.shape[-1],)
            combined_output = combined_output.clone().reshape(restored_shape)
        return combined_output, combine_weights, None, gate_logits


class MOEAllGatherLayerV2(MOELayer):
    """
    MoE layer with an all-gather implementation.
    """

    def __init__(
        self,
        gate: nn.Module,
        experts: List[nn.Module],
        layer_idx,
        shared_experts: Optional[List[nn.Module]] = None,
        group=None,
        recompute=False,
        k=2,
        enable_reverse_token_drop=False,
        all_to_all_dropout=0,
        group_experts=False,
        use_expert_out_alltoall=True,
        use_expert_alltoall_overlap=False,
        use_padding=True,
        dense_token_type=3,  # considered as dense tokens (no moe)
        moe_statics=None,
        moe_num_experts=None,
    ):
        """Set up the base MoE layer and the state specific to this variant.

        ``gate``, ``experts``, ``layer_idx``, ``shared_experts``, ``group``,
        ``recompute``, ``k``, ``all_to_all_dropout``, ``group_experts``,
        ``moe_statics`` and ``moe_num_experts`` are forwarded positionally to
        the parent constructor.

        Args:
            enable_reverse_token_drop: Stored on the instance.
            use_expert_out_alltoall: Only reported in the start-up log line.
            use_expert_alltoall_overlap: Only reported in the start-up log line.
            use_padding: Stored on the instance.
            dense_token_type: Token type id stored as ``self.dense_token_type``.
        """
        parent_args = (
            gate,
            experts,
            layer_idx,
            shared_experts,
            group,
            recompute,
            k,
            all_to_all_dropout,
            group_experts,
            moe_statics,
            moe_num_experts,
        )
        super().__init__(*parent_args)

        # Behaviour switches.
        self.is_allgather_moe_layer = True
        self.enable_reverse_token_drop = enable_reverse_token_drop
        self.use_padding = use_padding
        self.dense_token_type = dense_token_type
        self.dense_experts = None

        # Populated lazily on the first forward pass.
        self.send_rank = None
        self.local_expert_id = None
        self.capacity_tensor = None

        startup_message = (
            f"uisng MOEAllGatherLayerV2, use_expert_out_alltoall={use_expert_out_alltoall}, "
            f"use_padding={use_padding}, use_expert_alltoall_overlap={use_expert_alltoall_overlap} "
            f"enable_reverse_token_drop={self.enable_reverse_token_drop}"
        )
        logger.info(startup_message)

        # Float32 scalar constants.
        self.two = torch.tensor(2).to(dtype=torch.float32)
        self.zero = torch.tensor(0).to(dtype=torch.float32)

    def forward(
        self,
        input: torch.Tensor,
        token_type_ids=None,
        use_dense_expert=False,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
        """Run the MoE layer: fused gate + dispatch, local experts, combine.

        Args:
            input (Tensor): Input of shape [seq_len, hidden_dim]. A 3-D input is
                flattened over its leading dimensions and the output is
                reshaped back.
            token_type_ids: Optional 2-D token types; the last column is dropped
                and the rest is flattened before being passed to the gate.
            use_dense_expert: When True (and ``token_type_ids`` is given), tokens
                whose type equals ``self.dense_token_type`` are flagged through
                a mask passed to ``fused_gate_and_dispatch``.

        Returns:
            tuple: (
                combined_output: Aggregated expert outputs, same leading shape
                    as ``input``,
                combine_weights: Expert combination coefficients,
                router_loss: Always None,
                gate_logits: Gate logits of the first (text) expert group,
            )

        Raises:
            RuntimeError: If no expert produced an output.
        """
        assert isinstance(self.gate, TopKGate), type(self.gate)

        orig_shape = None
        if input.ndim == 3:
            orig_shape = input.shape
            input = input.reshape([-1, input.shape[-1]])

        assert (
            len(input.shape) == 2
        ), f"input Tensor must have dimensions: (s)equence, (d)im, got:{input.shape}"

        global_dense_expert_mask = None
        if token_type_ids is not None:
            token_type_ids = token_type_ids[:, :-1].reshape([-1])
            if use_dense_expert:
                global_dense_expert_mask = token_type_ids == self.dense_token_type

        # Only the values needed below are named; the rest is ignored.
        (
            dispatched_input,
            _,  # global_hidden_states
            local_combine_weights,
            _,  # expert_num_global_no_token_drop
            _,  # expert_num_global
            _,  # expert_num_global_list
            local_scatter_index,
            _,  # scatter_index_rev
            _,  # router_loss
            (gate_logits, _),
            _,  # (gate_logits_mm, gate_prob_mm)
            _,  # expert_num_local
        ) = self.fused_gate_and_dispatch(
            input, token_type_ids, global_dense_expert_mask
        )

        # Nothing in this method reads these cached tensors; they are still
        # populated on the first call so the instance state is unchanged.
        if self.send_rank is None:
            capacity = self.gate.get_capacity(input.shape[0])
            self.send_rank = (
                torch.arange(1)
                .repeat_interleave(capacity * self.num_local_experts)
                .to(torch.int32)
            )
            self.local_expert_id = (
                torch.arange(self.num_local_experts)
                .repeat_interleave(capacity)
                .repeat(1)
                .to(self.send_rank.dtype)
            )

        expert_outs = self.forward_experts(*dispatched_input)
        active_outs = [out for out in expert_outs if out is not None]
        if not active_outs:
            raise RuntimeError("MoE forward produced no expert output to combine")
        # Experts may live on different devices; gather on the first one's device.
        target_device = active_outs[0].device
        expert_outs = torch.cat(
            [out.to(target_device) for out in active_outs], dim=0
        )  # [e*c,m]

        combined_output = self.combine_expert_output(
            expert_outs, local_combine_weights, local_scatter_index
        )

        if self.shared_experts is not None:
            shared_out = self.shared_experts(input).to(combined_output.device)
            combined_output += shared_out

        if orig_shape is not None:
            combined_output = combined_output.reshape(
                *orig_shape[:-1], combined_output.shape[-1]
            )

        return combined_output, local_combine_weights, None, gate_logits

    def _expand_modality_expert_id(
        self,
        expert_id: torch.Tensor,  # (seqlen, k)
        seqlen: int,
        k: int,
        num_expert_per_modality: int,
        group_size: int,
        modality_offset: int,
        is_group_expert: bool,
    ) -> torch.Tensor:
        """
        expert_id: tensor of shape (seqlen, k), containing expert ids
        Returns: tensor of same shape, with updated expert ids
        """
        device = expert_id.device
        expert_id = expert_id.clone()

        if is_group_expert:
            # idx % k * group_size
            offsets = (torch.arange(k, device=device) * group_size).view(
                1, k
            )  # shape (1, k)
            expert_id += offsets

        if num_expert_per_modality <= 0:
            return expert_id

        # Compute rank and local expert id
        rank = expert_id // num_expert_per_modality
        expert_id_in_rank = expert_id % num_expert_per_modality

        # Compute new expert id with modality-aware adjustment
        expert_id_out = (
            rank * (num_expert_per_modality * 2)  # 2 modalities assumed
            + expert_id_in_rank
            + modality_offset * num_expert_per_modality
        )

        return expert_id_out

    def expand_modality_expert_id(
        self,
        expert_id,
        num_expert_per_modality,
        group_size,
        modality_offset,
        is_group_expert,
    ):
        """expand expert id for modality aware moe layer"""
        seq_len, k = expert_id.shape

        return self._expand_modality_expert_id(
            expert_id,
            seq_len,
            k,
            num_expert_per_modality,
            group_size,
            modality_offset,
            is_group_expert,
        )

    def fused_gate_logits_process_fused(
        self, gate_logits_lm, gate_logits_mm=None, token_type_ids=None
    ):
        """Process gating logits for expert selection in Mixture-of-Experts (MoE) layers.

        Core Functionality:
        - Transforms raw gating logits into expert selection weights and IDs
        - Supports both grouped and standard expert selection modes
        - Handles bias correction for improved expert load balancing

        Args:
            gate_logits_lm (Tensor): Raw gating scores of shape [batch_size, total_experts]

        Returns:
            tuple: (
                lm_weight_and_expert_id: Combined tensor containing selection weights
                       and expert IDs [batch_size, 2*top_k],
                prob_flat: Flattened expert probabilities [batch_size, total_experts]
            )
        """
        top_k = self.k
        num_expert_per_rank_per_modality = gate_logits_lm.shape[-1]
        group_size = gate_logits_lm.shape[-1] // top_k
        if self.group_experts:
            assert not self.use_correction_bias
            gate_logits_lm = gate_logits_lm.reshape(
                [gate_logits_lm.shape[0], top_k, -1]
            )
            prob_lm = self.gate.act(gate_logits_lm)
            prob_lm_ = prob_lm
            weight_lm, expert_id_lm = prob_lm_.topk(k=1, dim=-1)
            weight_lm = weight_lm.reshape([gate_logits_lm.shape[0], -1])
            group_size = gate_logits_lm.shape[-1]
            expert_id_lm = expert_id_lm.squeeze(-1)
        else:
            prob_lm = self.gate.act(gate_logits_lm)
            if self.use_correction_bias:
                prob_lm_ = prob_lm + self.moe_statics.e_score_correction_bias[
                    0
                ].detach().to(prob_lm.device)
            else:
                prob_lm_ = prob_lm
            weight_lm, expert_id_lm = prob_lm_.topk(k=top_k, dim=-1)

        if self.use_correction_bias:
            batch_idx = (
                torch.arange(prob_lm_.shape[0]).unsqueeze(-1).expand_as(expert_id_lm)
            )
            weight_lm = prob_lm[batch_idx, expert_id_lm]  # use correct bias

        expert_id_lm = self.expand_modality_expert_id(
            expert_id_lm,
            num_expert_per_modality=(
                num_expert_per_rank_per_modality if token_type_ids is not None else 0
            ),
            group_size=group_size,
            modality_offset=0,
            is_group_expert=self.group_experts,
        )
        expert_id_lm = expert_id_lm.reshape(weight_lm.shape)
        lm_weight_and_expert_id = torch.cat(
            [weight_lm, expert_id_lm.to(torch.float32)], -1
        )

        if token_type_ids is None or gate_logits_mm is None:
            return (
                lm_weight_and_expert_id,
                prob_lm.reshape([prob_lm.shape[0], -1]),
                None,
            )

        prob_mm = self.gate.act(gate_logits_mm)
        if self.use_correction_bias:
            prob_mm_ = prob_mm + self.moe_statics.e_score_correction_bias[
                1
            ].detach().to(prob_lm.device)
        else:
            prob_mm_ = prob_mm
        weight_mm, expert_id_mm = prob_mm_.topk(k=top_k, dim=-1)
        if self.use_correction_bias:
            batch_idx = (
                torch.arange(prob_lm_.shape[0]).unsqueeze(-1).expand_as(expert_id_lm)
            )
            weight_mm = prob_mm[batch_idx, expert_id_mm]  # use correct bias

        expert_id_mm = self.expand_modality_expert_id(
            expert_id_mm,
            num_expert_per_modality=num_expert_per_rank_per_modality,
            group_size=group_size,
            modality_offset=1,
            is_group_expert=False,
        )
        expert_id_mm = expert_id_mm.reshape(weight_mm.shape)
        mm_weight_and_expert_id = torch.cat(
            [weight_mm, expert_id_mm.to(torch.float32)], -1
        )
        weight_and_expert = torch.where(
            (token_type_ids == 0).unsqueeze(-1),
            lm_weight_and_expert_id.to(token_type_ids.device),
            mm_weight_and_expert_id.to(token_type_ids.device),
        )
        return weight_and_expert, prob_lm.reshape([prob_lm.shape[0], -1]), prob_mm

    def moe_gate_dispatch_partial_nosoftmaxtopk(
        self,
        x,
        combine_weights,
        expert_id,
        k,
        num_experts,
    ):
        """Group token activations by their selected expert (no capacity limit).

        Every (token, route) pair becomes one output row; rows are ordered by
        expert id and, within an expert, by token-major route order.

        Args:
            x: Token activations [num_rows, hidden_size]
            combine_weights: Per-route weights [num_rows, k]; returned as a copy
            expert_id: Selected expert ids [num_rows, k]
            k: Ignored; the route count is taken from ``expert_id.shape[1]``
            num_experts: Number of expert ids to count (ids ``0..num_experts-1``)

        Returns:
            tuple: (
                y: activations grouped by expert [num_rows * k, hidden_size],
                combine_weights_out: copy of ``combine_weights`` [num_rows, k],
                scatter_index: int32 [k, num_rows], row of ``y`` holding each
                    route (``-1`` where no row was assigned),
                scatter_index_rev: source token of every row of ``y``,
                expert_nums_local: int64 [num_experts] routes per expert,
                expert_nums_local: the same tensor, returned a second time,
            )
        """
        device = x.device
        num_rows, _ = x.shape
        k = expert_id.shape[1]
        num_routes = num_rows * k

        # Stable sort: routes of one expert keep their token-major order.
        sorted_expert_ids, sorted_indices = torch.sort(
            expert_id.reshape(-1), stable=True
        )
        sorted_indices = sorted_indices.to(device)

        # Per-expert counts from the group boundaries in the sorted ids
        # (bounds[e] is the first position whose id is >= e).
        expert_range = torch.arange(
            num_experts + 1,
            dtype=sorted_expert_ids.dtype,
            device=sorted_expert_ids.device,
        )
        bounds = torch.searchsorted(sorted_expert_ids, expert_range)
        expert_nums_local = (bounds[1:] - bounds[:-1]).to(
            device=device, dtype=torch.int64
        )

        # Row i of y is the activation of the token that owns sorted route i.
        token_of_row = sorted_indices // k
        y = x[token_of_row]  # [num_rows * k, hidden_size]

        # Inverse mapping: route -> row of y. Only the first `total` sorted
        # positions (routes with an id in [0, num_experts)) receive a row.
        positions = torch.arange(num_routes, device=device)
        total_dispatched = expert_nums_local.sum()
        row_of_position = torch.where(
            positions < total_dispatched, positions, torch.full_like(positions, -1)
        )
        scatter_flat = torch.full(
            (num_routes,), -1, dtype=torch.int32, device=device
        )
        scatter_flat[sorted_indices] = row_of_position.to(torch.int32)
        # Flat route index is token * k + route -> view as [num_rows, k], then [k, num_rows].
        scatter_index = scatter_flat.reshape(num_rows, k).t().contiguous()

        return (
            y,  # [num_rows * k, hidden_size]
            combine_weights.clone(),  # [num_rows, k]
            scatter_index,  # [k, num_rows]
            token_of_row,  # [num_rows * k]
            expert_nums_local,  # [num_experts]
            expert_nums_local,  # [num_experts]
        )

    def fused_gate_and_dispatch(
        self, input, token_type_ids=None, global_dense_expert_mask=None
    ):
        """Run the gate, select experts per token and group the tokens by expert.

        Steps performed here:
          - call ``self.gate`` for the raw logits and turn them into combine
            weights and expert ids via ``fused_gate_logits_process_fused``
          - optionally redirect tokens flagged by ``global_dense_expert_mask``
            to one extra expert id with zero weight
          - sort tokens by expert with ``moe_gate_dispatch_partial_nosoftmaxtopk``
          - when ``self.use_correction_bias`` is set, accumulate the per-expert
            token counts into ``self.moe_statics.expert_usage``

        Args:
            input (Tensor): Input tensor of shape [seq_len, hidden_dim]
            token_type_ids: Optional token types; flattened and forwarded to the gate
            global_dense_expert_mask: Optional boolean mask over tokens that
                should bypass the routed experts

        Returns:
            tuple: (
                dispatched_input: per-expert inputs; with ``use_padding`` the
                    tuple produced by ``unbind(0)``, otherwise a list with one
                    tensor (or None) per local expert,
                global_hidden_states: the unmodified ``input``,
                local_combine_weights: combine weights cast to the dispatched
                    dtype (normalised when ``self.gate.norm_gate_logits``),
                expert_num_global_notrunc: per-expert token counts before the
                    capacity clamp,
                expert_num_global: per-expert token counts clamped to capacity,
                expert_num_global_list: ``expert_num_global`` as a Python list,
                local_scatter_index: token -> dispatched-row indices [seq_len, k],
                scatter_index_rev: dispatched-row -> token indices,
                router_loss: constant zero tensor of shape [1],
                (gate_logits_lm, gate_prob_lm): text gate logits and probabilities,
                (gate_logits_mm, gate_prob_mm): multimodal gate logits and
                    probabilities (may be None),
                expert_num_local: per-expert token counts (a Python list
                    restricted to the local experts when ``use_padding`` is off)
            )
        """
        seqlen, d_model = input.shape
        args = ()
        if token_type_ids is not None:
            token_type_ids = token_type_ids.reshape([-1])
            args = (token_type_ids,)

        # No auxiliary loss is computed in this method; a zero placeholder is returned.
        router_loss = torch.zeros([1], dtype=torch.float32)
        top_k = self.k

        def build_weights_and_expert_id(input):
            # Runs the gate and converts its logits to weights + expert ids.
            nonlocal token_type_ids, args
            logits = self.gate(input, *args, transform_weight=False)
            if self.config.multimodel_experts:
                # The gate emits text and multimodal logits side by side.
                gate_logits_lm, gate_logits_mm = logits.chunk(2, dim=-1)
            else:
                gate_logits_lm, gate_logits_mm = logits, None

            # Token types are withheld when a dense mask is supplied, which makes
            # fused_gate_logits_process_fused return the text-only result.
            weigth_and_expert, gate_prob_lm, gate_prob_mm = (
                self.fused_gate_logits_process_fused(
                    gate_logits_lm,
                    gate_logits_mm,
                    token_type_ids if global_dense_expert_mask is None else None,
                )
            )
            return (
                weigth_and_expert,
                gate_logits_lm,
                gate_logits_mm,
                gate_prob_lm,
                gate_prob_mm,
            )

        capacity = self.gate.get_capacity(input.shape[0]) * self.world_size
        global_hidden_states = input
        (
            combine_weights_and_expert_id,
            gate_logits_lm,
            gate_logits_mm,
            gate_prob_lm,
            gate_prob_mm,
        ) = build_weights_and_expert_id(input)

        # First half of the last dim holds the weights, second half the
        # expert ids stored as floats.
        combine_weights_unnorm, expert_id = combine_weights_and_expert_id.chunk(
            2, dim=-1
        )
        expert_id = expert_id.to(torch.int32)
        num_experts = (
            sum(self.config.moe_num_experts)
            if isinstance(self.config.moe_num_experts, (tuple, list))
            else self.config.moe_num_experts
        )
        if global_dense_expert_mask is not None:
            # Masked tokens get zero weight and are sent to one extra expert id
            # appended after the real experts.
            combine_weights_unnorm[global_dense_expert_mask] = 0.0
            expert_id[global_dense_expert_mask] = num_experts
            num_experts += 1

        (
            dispatched_input,
            combine_weights_unnorm,
            scatter_index,  # input -> dispatched_input
            scatter_index_rev,  # dispatch-input -> input
            expert_num_global,
            expert_num_local,
        ) = self.moe_gate_dispatch_partial_nosoftmaxtopk(
            global_hidden_states,
            combine_weights_unnorm,
            expert_id,
            top_k,
            num_experts,
        )

        if self.use_correction_bias:
            # Accumulate how many tokens each expert received into the shared
            # usage statistics.
            if self.gate.config.multimodel_experts:
                # MLLM
                for i in range(len(self.moe_statics.expert_usage)):
                    self.moe_statics.expert_usage[i] += (
                        expert_num_local[self.gate.experts_type_mask[i]]
                        .detach()
                        .to(self.moe_statics.expert_usage.device)
                    )
            else:
                # LLM
                self.moe_statics.expert_usage[0] += expert_num_local.detach().to(
                    self.moe_statics.expert_usage.device
                )

        # Without padding the dispatch step may hand back a 0-dim
        # `scatter_index_rev`; normalise it to an empty 1-D tensor.
        if scatter_index_rev.ndim == 0:
            assert not self.use_padding
            scatter_index_rev = torch.empty([0], dtype=scatter_index_rev.dtype)

        expert_num_global_notrunc = expert_num_global
        # Clamp the per-expert counts to the capacity; the capacity tensor is
        # also kept on the instance.
        self.capacity_tensor = torch.tensor(capacity).to(dtype=expert_num_global.dtype)
        expert_num_global = torch.minimum(expert_num_global, self.capacity_tensor)

        if global_dense_expert_mask is not None:
            # Drop the counter of the extra expert id added for masked tokens.
            expert_num_global = expert_num_global[:-1]
            expert_num_local = expert_num_local[:-1]
            expert_num_global_notrunc = expert_num_global_notrunc[:-1]

        scatter_index = scatter_index.transpose(1, 0)  # [k,s] ->[s,k]
        scatter_index = scatter_index.to(combine_weights_unnorm.device)

        # `last_local_expert` is fixed at 0, so `offset` below is always 0 and
        # `local_scatter_index` carries the same values as `scatter_index`.
        last_local_expert = 0
        # NOTE(review): `expert_offset_global` is not used again in this method.
        expert_offset_global = expert_num_global.cumsum(-1)

        expert_num_global_list = expert_num_global
        if self.use_padding:
            offset = last_local_expert * capacity
        else:
            offset = 0
        local_combine_weights_unnorm = combine_weights_unnorm.contiguous()
        local_scatter_index = torch.where(
            combine_weights_unnorm > 0.0,
            scatter_index + offset,
            scatter_index,
        )
        if self.gate.norm_gate_logits:
            # Normalise the weights of each token; the clip avoids a division
            # by zero when all of a token's weights are zero.
            local_combine_weights = local_combine_weights_unnorm / torch.clip(
                local_combine_weights_unnorm.sum(-1, keepdim=True), min=1e-12
            )
        else:
            local_combine_weights = local_combine_weights_unnorm
        local_combine_weights = local_combine_weights.to(dispatched_input.dtype)
        if self.use_padding:
            # NOTE(review): this reshape needs the number of dispatched rows to
            # be divisible by num_local_experts — confirm with the caller.
            dispatched_input = dispatched_input.reshape(
                [self.num_local_experts, -1, d_model]
            )
            dispatched_input = dispatched_input.unbind(0)
        else:
            # Split the sorted rows into one chunk per local expert; experts
            # that received no token get None instead of an empty tensor.
            s = 0
            e = self.num_local_experts
            expert_num_local = expert_num_local.tolist()[s:e]
            expert_num_local_valid = [i for i in expert_num_local if i > 0]
            valid_pos = [j for j, i in enumerate(expert_num_local) if i > 0]
            if expert_num_local_valid:
                dispatched_input_list = dispatched_input.split(expert_num_local_valid)
                dispatched_input = [None] * len(expert_num_local)
                for p, t in zip(valid_pos, dispatched_input_list):
                    dispatched_input[p] = t
            else:
                # No local expert received a token: pass the tensor to the
                # first expert so that at least one entry is not None.
                dispatched_input = [dispatched_input] + (
                    [None] * (len(expert_num_local) - 1)
                )

        expert_num_global_list = expert_num_global_list.tolist()

        return (
            dispatched_input,
            global_hidden_states,
            local_combine_weights,
            expert_num_global_notrunc,  # for auxloss calculation.
            expert_num_global,
            expert_num_global_list,
            local_scatter_index,
            scatter_index_rev,
            router_loss,
            (gate_logits_lm, gate_prob_lm),
            (gate_logits_mm, gate_prob_mm),
            expert_num_local,
        )

    def forward_experts(self, *dispatched_input):
        """Run every dispatched chunk through its local expert.

        The experts owned by this rank are
        ``self.experts[rank * num_local_experts : (rank + 1) * num_local_experts]``;
        chunk ``i`` is processed by local expert ``i``.

        Args:
            *dispatched_input: One tensor per local expert, or None for an
                expert that received no tokens

        Returns:
            list: One output tensor per chunk, in order, with None kept in
            place for every None chunk
        """
        assert isinstance(self.experts, nn.ModuleList), type(self.experts)

        first_local = self.rank * self.num_local_experts
        local_experts = self.experts[first_local : first_local + self.num_local_experts]

        return [
            None if chunk is None else local_experts[index](chunk.contiguous())
            for index, chunk in enumerate(dispatched_input)
        ]


class Ernie4_5_DecoderLayer(nn.Module):
    """A single transformer decoder layer of the ERNIE-MoE model.

    The layer applies a layer norm, self-attention and a residual add, then a
    second layer norm, a feed-forward block and a second residual add. The
    feed-forward block is a MoE layer on the layers selected by the config and
    a dense ``Ernie4_5_MLP`` on all others.
    """

    _keep_in_fp32_modules = ["mlp.gate", "e_score_correction_bias"]

    def __init__(self, config, layer_idx):
        """Initialize the decoder layer.

        Args:
            config (Ernie4_5_MoEConfig): Model configuration.
            layer_idx (int): Index of this layer in the transformer stack
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx
        self.config = config
        self.use_moe = config.use_moe
        self.self_attn = Ernie4_5_Attention(config, layer_idx)

        # The start/end indices may be given per modality; the layer is a MoE
        # layer if it falls inside the widest [start, end] range.
        moe_layer_start_index = (
            min(config.moe_layer_start_index)
            if isinstance(config.moe_layer_start_index, (tuple, list))
            else config.moe_layer_start_index
        )
        moe_layer_end_index = (
            max(config.moe_layer_end_index)
            if isinstance(config.moe_layer_end_index, (tuple, list))
            else config.moe_layer_end_index
        )

        if (
            self.use_moe
            and ((layer_idx + 1) % config.moe_layer_interval == 0)
            and layer_idx >= moe_layer_start_index  # 3
            and layer_idx <= moe_layer_end_index  # 53
        ):
            gate, experts, lm_gate, lm_experts, moe_statics = (
                self._init_gate_and_experts(layer_idx)
            )
            shared_experts = (
                self._init_shared_experts()
                if hasattr(config, "moe_num_shared_experts")
                else None
            )

            dense_experts = None
            moe_cls = MOELayer
            if config.moe_multimodal_dispatch_use_allgather:  # v2
                logger.info("Enable MOEAllGatherLayerV2!")
                moe_cls = partial(
                    MOEAllGatherLayerV2,
                    use_expert_out_alltoall="alltoall"
                    in config.moe_multimodal_dispatch_use_allgather,  # false
                    use_padding=False,
                    enable_reverse_token_drop=config.moe_reverse_token_drop,  # false
                    dense_token_type=config.moe_dense_experts_token_type_id,  # 3
                )
            else:
                assert (
                    dense_experts is None
                ), "only `MOEAllGatherLayerV2` can process dense experts"

            self.mlp = moe_cls(
                gate=gate,
                experts=experts,
                layer_idx=layer_idx,
                shared_experts=shared_experts,
                group=config.moe_group,
                recompute=False,
                k=config.moe_k,
                all_to_all_dropout=config.moe_all_to_all_dropout,
                group_experts=False,
                moe_statics=moe_statics,
                moe_num_experts=config.moe_num_experts,
            )

            # Text-only MoE path; it is handed the same shared experts and
            # statistics objects as `self.mlp`.
            _mlp_text = MOELayer(
                gate=lm_gate,
                experts=lm_experts,
                layer_idx=layer_idx,
                shared_experts=shared_experts,
                group=config.moe_group,
                recompute=False,
                k=config.moe_k,
                all_to_all_dropout=config.moe_all_to_all_dropout,
                group_experts=False,
                moe_statics=moe_statics,
                moe_num_experts=config.moe_num_experts,
            )
            self.mlp_text = (
                lambda: _mlp_text
            )  # This lambda prevents the text parameter from being scanned into the state-dict
        else:
            self.mlp = Ernie4_5_MLP(config)

        Norm = RMSNorm

        self.input_layernorm = Norm(config)
        self.post_attention_layernorm = Norm(config)

        self.residual_add1 = FusedDropoutImpl(
            config.hidden_dropout_prob, mode="upscale_in_train"
        )
        self.residual_add2 = FusedDropoutImpl(
            config.hidden_dropout_prob, mode="upscale_in_train"
        )

    def _init_shared_experts(self):
        """Build the shared-expert MLP for this layer.

        The MLP is built from a copy of the config whose ``intermediate_size``
        is the per-expert width multiplied by ``moe_num_shared_experts``. The
        per-expert width is ``moe_intermediate_size`` when that is set (its
        first element if it is a tuple/list), otherwise ``intermediate_size``.

        Returns:
            Optional[Ernie4_5_MoeMLP]: The shared-expert MLP, or ``None`` when
            ``moe_num_shared_experts`` is not positive.
        """
        cfg = deepcopy(self.config)
        if cfg.moe_num_shared_experts > 0:
            if cfg.moe_intermediate_size:
                inter_size = (
                    next(iter(cfg.moe_intermediate_size))
                    if isinstance(cfg.moe_intermediate_size, (tuple, list))
                    else cfg.moe_intermediate_size
                )
                cfg.intermediate_size = inter_size * cfg.moe_num_shared_experts
            else:
                cfg.intermediate_size = (
                    cfg.intermediate_size * cfg.moe_num_shared_experts
                )
            cfg.disable_ffn_model_parallel = False  # split shared expert
            shared_experts = Ernie4_5_MoeMLP(cfg, True)
        else:
            shared_experts = None
        return shared_experts

    def _init_gate_and_experts(self, layer_idx):
        """Initialize MoE gate and expert networks.

        When ``moe_intermediate_size`` is a tuple/list, one expert prototype is
        built per entry (one per modality). A modality whose own
        ``[start, end]`` layer range does not contain ``layer_idx`` gets an
        ``nn.Identity`` placeholder instead of a real expert.

        Args:
            layer_idx (int): Current layer index

        Returns:
            Tuple: Contains:
                - gate: MoE routing gate
                - experts: Expert networks
                - lm_gate: Text-only gate (``None`` unless ``multimodel_experts``)
                - lm_experts: Text-only experts (``None`` unless ``multimodel_experts``)
                - moe_statics: ``MoEStatics`` when ``moe_use_aux_free``, else ``None``
        """
        cfg = deepcopy(self.config)
        fc_cls = Ernie4_5_MoeMLP
        if cfg.moe_intermediate_size:
            if isinstance(cfg.moe_intermediate_size, (tuple, list)):
                assert isinstance(cfg.moe_num_experts, (tuple, list)) and len(
                    cfg.moe_num_experts
                ) == len(cfg.moe_intermediate_size)
                fc = []
                for _i, (num_experts, intermediate_size) in enumerate(
                    zip(cfg.moe_num_experts, cfg.moe_intermediate_size)
                ):
                    ex_cfg = deepcopy(cfg)
                    ex_cfg.intermediate_size = intermediate_size
                    cur_modality_start_layer_idx = (
                        cfg.moe_layer_start_index[_i]
                        if isinstance(cfg.moe_layer_start_index, (tuple, list))
                        else cfg.moe_layer_start_index
                    )
                    cur_modality_end_layer_idx = (
                        cfg.moe_layer_end_index[_i]
                        if isinstance(cfg.moe_layer_end_index, (tuple, list))
                        else cfg.moe_layer_end_index
                    )
                    if (
                        layer_idx >= cur_modality_start_layer_idx
                        and layer_idx <= cur_modality_end_layer_idx
                    ):
                        if _i == 1:
                            with UniqueNameGuard(f"mm_expert_{layer_idx}_"):
                                fc.append((num_experts, fc_cls(ex_cfg)))
                        else:
                            fc.append((num_experts, fc_cls(ex_cfg)))
                    else:
                        logger.info(
                            f"moe multimodal experts use Identity layer_idx: {layer_idx}"
                        )
                        fc.append((num_experts, nn.Identity()))
            else:
                cfg.intermediate_size = cfg.moe_intermediate_size
                fc = [(cfg.moe_num_experts, fc_cls(cfg, layer_idx))]
        else:
            fc = [(cfg.moe_num_experts, fc_cls(cfg, layer_idx))]
        if cfg.multimodel_experts:
            gate, experts, lm_gate, lm_experts = get_gate(self.config, fc, layer_idx)
        else:
            gate, experts = get_gate(self.config, fc, layer_idx)
            lm_gate, lm_experts = None, None

        # for AuxLoss Free Router:
        if cfg.moe_use_aux_free:
            moe_statics = MoEStatics(cfg, layer_idx)
        else:
            moe_statics = None
        return gate, experts, lm_gate, lm_experts, moe_statics

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attn_mask_start_row_indices: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        use_cache: Optional[bool] = False,
        output_gate_logits=True,  # PP model should not output gate logits,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Forward pass through the decoder layer.

        Args:
            hidden_states (torch.Tensor): Input tensor [batch_size, seq_len, hidden_size]
            attention_mask (Optional[torch.Tensor]): Attention mask tensor
            attn_mask_start_row_indices (Optional[torch.Tensor]): Indices for variable length attention
            position_ids (Optional[torch.Tensor]): Position indices for rotary embeddings
            token_type_ids (Optional[torch.Tensor]): Per-token type ids. If any
                entry is non-zero the multimodal MoE branch is used; if the
                tensor is ``None`` or all zeros the text-only branch is used.
            output_attentions (Optional[bool]): Whether to return attention weights
            past_key_value (Optional[Tuple[torch.Tensor]]): Cached key/value states
            use_cache (Optional[bool]): Whether to cache key/value states
            output_gate_logits (bool): Whether to return MoE gate logits

        Returns:
            Either the bare hidden-states tensor, or a tuple that starts with
            the hidden states and is followed, in this order, by:
                - attention weights, if ``output_attentions``
                - the present key/value cache, if ``use_cache``
                - gate logits, if ``self.use_moe`` and ``output_gate_logits``
            A tuple that would hold only the hidden states is unwrapped.
        """
        residual = hidden_states

        # A single host-side flag decides which MoE branch runs below.
        if token_type_ids is not None:
            is_multimodel_token_cpu = token_type_ids.any().cpu()
        else:
            is_multimodel_token_cpu = None

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        (hidden_states, self_attn_weights, present_key_value, *router_loss_attn) = (
            self.self_attn(
                hidden_states=hidden_states,
                past_key_value=past_key_value,
                attention_mask=attention_mask,
                attn_mask_start_row_indices=attn_mask_start_row_indices,
                position_ids=position_ids,
                output_attentions=output_attentions,
                use_cache=use_cache,
                token_type_ids=token_type_ids,
            )
        )
        hidden_states = self.residual_add1(hidden_states, residual)

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)

        if isinstance(self.mlp, MOELayer):
            if is_multimodel_token_cpu:
                hidden_states, _, router_loss, gate_logits = self.mlp(
                    hidden_states, token_type_ids
                )
            else:
                hidden_states, _, router_loss, gate_logits = self.mlp_text()(
                    hidden_states, None, is_multimodel=False
                )
        else:
            hidden_states = self.mlp(hidden_states)
            gate_logits, router_loss = None, None

        hidden_states = self.residual_add2(hidden_states, residual)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        if self.use_moe:
            # Non-empty only if `use_moe`
            # NOTE: `router_loss` is combined here but is not part of the
            # returned tuple.
            if router_loss_attn:
                router_loss_attn = router_loss_attn[0]
                router_loss = router_loss + router_loss_attn

            if output_gate_logits:
                outputs += (gate_logits,)

        # remove empty tuple for pipeline parallel
        if type(outputs) is tuple and len(outputs) == 1:
            outputs = outputs[0]

        return outputs


class Ernie4_5_PretrainedModel(PreTrainedModel):
    """Shared base class of the ERNIE 4.5 models defined in this file.

    It adds no behaviour of its own; it only sets the class-level attributes
    that ``PreTrainedModel`` reads.
    """

    # Attribute name of the base model (per the Transformers convention).
    base_model_prefix = "ernie"
    # Configuration class paired with these models.
    config_class = Ernie4_5_MoEConfig
    # Module classes that must be kept whole (per the Transformers convention
    # for device placement).
    _no_split_modules = ["Ernie4_5_DecoderLayer"]


class Ernie4_5_Model(Ernie4_5_PretrainedModel):
    """The core ERNIE transformer model with MoE (Mixture of Experts) support."""

    def __init__(self, config: Ernie4_5_MoEConfig):
        """Initialize the ERNIE model architecture.

        Builds the token embedding, ``config.num_hidden_layers`` decoder layers
        and the final norm.

        Args:
            config (Ernie4_5_MoEConfig): Model configuration.
        """
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.hidden_size = config.hidden_size
        self.config = config

        self.embed_tokens = nn.Embedding(
            self.vocab_size,
            self.hidden_size,
        )

        self.layers = nn.ModuleList(
            [Ernie4_5_DecoderLayer(config, i) for i in range(config.num_hidden_layers)]
        )
        Norm = RMSNorm
        self.norm = Norm(config)

        self.gradient_checkpointing = False

    def get_input_embeddings(self):
        """Get the input embedding layer.

        Returns:
            nn.Embedding: The embedding layer for input tokens
        """
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Set new input embeddings.

        Args:
            value (nn.Embedding): New embedding layer to use
        """
        self.embed_tokens = value

    def forward(
        self,
        input_ids=None,
        position_ids=None,
        token_type_ids=None,
        attention_mask=None,
        attn_mask_start_row_indices=None,
        inputs_embeds=None,
        use_cache=None,
        past_key_values=None,
        output_attentions=False,
        output_hidden_states=None,
        return_dict=False,
    ):
        """Forward pass through the ERNIE model.

        Args:
            input_ids (Optional[torch.Tensor]): Input token IDs, rank 2
            position_ids (Optional[torch.Tensor]): Position indices
            token_type_ids (Optional[torch.Tensor]): Per-token type ids,
                forwarded unchanged to every decoder layer
            attention_mask (Optional[torch.Tensor]): Attention mask
            attn_mask_start_row_indices (Optional[torch.Tensor]): Variable length attention indices
            inputs_embeds (Optional[torch.Tensor]): Precomputed embeddings, rank 3
            use_cache (Optional[bool]): Whether to cache key/value states
            past_key_values (Optional[Tuple[Tuple[torch.Tensor]]]): Cached key/value
                states, one entry per layer
            output_attentions (Optional[bool]): Whether to output attention weights
            output_hidden_states (Optional[bool]): Whether to output all hidden states
            return_dict (Optional[bool]): Whether to return dict or tuple

        Returns:
            Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
                With ``return_dict`` false, a tuple holding the non-``None``
                items of the list below, in this order; otherwise an output
                object with the same fields:
                - last_hidden_state: Final layer hidden states
                - past_key_values: Cached key/value states if use_cache=True
                - hidden_states: All hidden states if output_hidden_states=True
                - attentions: Attention weights if output_attentions=True
                - router_loss: Zero scalar if use_moe=True (nothing is
                  accumulated into it in this method), else ``None``
                - gate_logits: Tuple of per-layer gate logits (empty if not MoE)

        Raises:
            ValueError: If both or neither of ``input_ids`` and
                ``inputs_embeds`` are given, or if the given one does not have
                the expected rank.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Exactly one of input_ids / inputs_embeds must be given. The shape
        # unpacking is kept purely as a rank check.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
            )
        elif input_ids is not None:
            _, _ = input_ids.shape
        elif inputs_embeds is not None:
            _, _, _ = inputs_embeds.shape
        else:
            raise ValueError(
                "You have to specify either decoder_input_ids or decoder_inputs_embeds"
            )

        if past_key_values is None:
            past_key_values = tuple([None] * len(self.layers))

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        inputs_embeds = inputs_embeds.to(self.embed_tokens.weight.dtype)

        hidden_states = inputs_embeds

        # decoder layers
        use_moe = getattr(self.config, "use_moe", False)
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None
        if use_moe:
            all_router_loss = torch.tensor(0.0, device=inputs_embeds.device)
        else:
            all_router_loss = None
        all_gate_logits = ()

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx]
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask,
                attn_mask_start_row_indices,
                position_ids,
                token_type_ids,
                output_attentions,
                past_key_value,
                use_cache,
            )

            # A layer returns a bare tensor when it has nothing else to report.
            if isinstance(layer_outputs, (tuple, list)):
                hidden_states = layer_outputs[0]
            else:
                hidden_states = layer_outputs

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)
            if use_moe:
                # Gate logits are always the last element of a MoE layer's output.
                layer_outputs, gate_logits = layer_outputs[:-1], layer_outputs[-1]
                all_gate_logits = all_gate_logits + (gate_logits,)

            if past_key_value is not None:
                hidden_states = hidden_states[:, -1:, :]

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_cache,
                    all_hidden_states,
                    all_self_attns,
                    all_router_loss,
                    all_gate_logits,
                ]
                if v is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=None,
            router_loss=all_router_loss,
            gate_logits=all_gate_logits,
        )


def parallel_matmul(
    x,
    y,
    bias=None,
    transpose_y=False,
):
    """
    Multiply ``x`` by the weight ``y`` and optionally add a bias.

    Despite the name, this implementation performs a plain local
    ``torch.matmul``; it does no cross-device communication.

    Args:
        x (torch.Tensor): Left operand.
        y (torch.Tensor): Weight matrix.
        bias (Optional[torch.Tensor]): Added in place to the product when given.
        transpose_y (bool): Use ``y.T`` instead of ``y`` as the right operand.

    Returns:
        torch.Tensor: ``x @ y`` (or ``x @ y.T``), plus ``bias`` if provided.
    """
    rhs = y.T if transpose_y else y
    logits = torch.matmul(x, rhs)
    if bias is None:
        return logits
    logits += bias
    return logits


def calc_lm_head_logits(
    config, hidden_states, weight, bias, tensor_parallel_output=None, training=True
):
    """
    Compute language-model head logits from the final hidden states.

    The projection is delegated to ``parallel_matmul``; the weight is
    transposed when ``config.tie_word_embeddings`` is set.

    Args:
        config (Ernie4_5_Config): Model configuration; only
            ``tie_word_embeddings`` is read.
        hidden_states (Tensor): Hidden states from the transformer layers
        weight (Tensor): Weight matrix for the language model head
        bias (Tensor): Bias vector for the language model head (may be ``None``)
        tensor_parallel_output (bool, optional): Accepted for interface
            compatibility; it has no effect on the result.
        training (bool, optional): Accepted for interface compatibility; it
            has no effect on the result.

    Returns:
        Tensor: The computed logits for language modeling.
    """
    return parallel_matmul(
        hidden_states,
        weight,
        bias=bias,
        transpose_y=config.tie_word_embeddings,
    )


def calc_multimodal_logits(
    last_hidden_state: torch.Tensor,
    lm_head_weight: torch.Tensor,
    lm_head_bias: torch.Tensor,
    mm_head_weight: torch.Tensor,
    mm_head_bias: torch.Tensor,
    token_type_ids_shifted: torch.Tensor,
    config: Ernie4_5_VLMoEConfig,
):
    """
    Compute text and image logits, picking the head by label token type.

    Args:
        last_hidden_state: Hidden states of the last layer. Its first two
            dimensions must equal the shape of ``token_type_ids_shifted``.
        lm_head_weight: Weight of the text head.
        lm_head_bias: Bias of the text head (may be ``None``).
        mm_head_weight: Weight of the image head. When ``None`` every position
            goes through the text head.
        mm_head_bias: Bias of the image head (may be ``None``).
        token_type_ids_shifted: Token type at the label position of each token.
            The head is chosen by the type of the token being predicted, not
            the type of the current token, because in an interleaved sequence
            the last text token predicts an image id and vice versa.
        config: Only ``use_recompute_loss_fn`` is read.

    Returns:
        A 3-tuple ``(score_text, score_image, None)``:
            - without an image head: ``score_text`` is the untouched hidden
              states if ``config.use_recompute_loss_fn``, otherwise the text
              logits for all positions; ``score_image`` is ``None``.
            - with an image head: each score covers only the positions of its
              own token type and is ``None`` when there are no such positions.
    """
    # TODO: Pass token-type-ids from reader
    assert last_hidden_state.shape[:2] == token_type_ids_shifted.shape, (
        last_hidden_state.shape,
        token_type_ids_shifted.shape,
    )

    if mm_head_weight is None:
        if config.use_recompute_loss_fn:
            return last_hidden_state, None, None
        return (
            parallel_matmul(last_hidden_state, lm_head_weight, lm_head_bias),
            None,
            None,
        )

    image_mask_shifted = token_type_ids_shifted == TokenType.image
    text_pos_shifted = token_type_ids_shifted == TokenType.text

    score_text = None
    if text_pos_shifted.any().item():
        text_hidden = last_hidden_state[text_pos_shifted]
        score_text = parallel_matmul(text_hidden, lm_head_weight, lm_head_bias)

    score_image = None
    if image_mask_shifted.any().item():
        image_hidden = last_hidden_state[image_mask_shifted]
        score_image = parallel_matmul(image_hidden, mm_head_weight, mm_head_bias)

    return score_text, score_image, None


class Ernie4_5_MoeLMHead(nn.Module):
    """Output projection from hidden states to vocabulary logits."""

    def __init__(self, config):
        """Create the head's weight and, optionally, its bias.

        Args:
            config (Ernie4_5_Config): Model configuration. The fields read are:
                - vocab_size: Size of vocabulary
                - hidden_size: Dimension of hidden states
                - tensor_parallel_degree: When > 1 the local vocab size is
                  ``vocab_size // tensor_parallel_degree``
                - tie_word_embeddings: Selects the weight layout,
                  ``[vocab, hidden]`` if set, else ``[hidden, vocab]``
                - weight_share_add_bias / use_bias: A bias is created only
                  when both are truthy
                - use_recompute_loss_fn: Only triggers an informational log
        """
        super().__init__()
        self.config = config

        tp_degree = config.tensor_parallel_degree
        vocab_size = (
            config.vocab_size // tp_degree if tp_degree > 1 else config.vocab_size
        )

        param_dtype = torch.get_default_dtype()
        if config.tie_word_embeddings:
            weight_shape = (vocab_size, config.hidden_size)
        else:
            weight_shape = (config.hidden_size, vocab_size)
        self.weight = nn.Parameter(torch.empty(*weight_shape, dtype=param_dtype))
        nn.init.xavier_uniform_(self.weight)

        logger.info(
            f"output-weight:{self.weight.shape} tie_word_embeddings:{config.tie_word_embeddings}"
        )

        has_bias = config.weight_share_add_bias and config.use_bias
        if has_bias:
            self.bias = nn.Parameter(torch.zeros(vocab_size, dtype=param_dtype))
        else:
            self.bias = None

        # Must set distributed attr for Tensor Parallel !
        # The parameters count as distributed whenever the local vocab size
        # differs from the full one.
        vocab_is_sharded = vocab_size != config.vocab_size
        self.weight.is_distributed = vocab_is_sharded
        if has_bias:
            self.bias.is_distributed = vocab_is_sharded

        if vocab_is_sharded:
            self.weight.split_axis = 1
            if has_bias:
                self.bias.split_axis = 0

        if self.config.use_recompute_loss_fn:
            logger.info(
                "Using recompute_loss_fn, the calculation of logits will be moved into "
                "loss_fn for memory optimization"
            )

    def forward(self, hidden_states, tensor_parallel_output=None):
        """Project hidden states to vocabulary logits.

        Delegates to ``calc_lm_head_logits`` with this module's config, weight,
        bias and training flag.

        Args:
            hidden_states (torch.Tensor): Input tensor of shape [batch_size, seq_len, hidden_size]
            tensor_parallel_output (Optional[bool]): Forwarded unchanged. Defaults to None.

        Returns:
            The value returned by ``calc_lm_head_logits``.
        """
        return calc_lm_head_logits(
            self.config,
            hidden_states,
            self.weight,
            self.bias,
            tensor_parallel_output=tensor_parallel_output,
            training=self.training,
        )


class Ernie4_5_MoeForCausalLM(Ernie4_5_PretrainedModel, GenerationMixin):
    """ERNIE Mixture of Experts (MoE) model for causal language modeling."""

    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def __init__(self, config):
        """
        Initializes the ERNIE MoE model for causal language modeling.

        Note: ``config.initializer_range`` is overwritten in place with
        ``sqrt(0.3333 / hidden_size)`` before the sub-modules are built.

        Args:
            config (dict): Model configuration.
        """
        super().__init__(config)

        # initialize-trick for big model,
        # see https://github.com/bigscience-workshop/bigscience/blob/master/train/tr11-176B-ml/README.md#std-init
        new_initializer_range = math.sqrt(0.3333 / config.hidden_size)
        logger.info(
            f"change initializer-range from {config.initializer_range} to {new_initializer_range}"
        )
        config.initializer_range = new_initializer_range
        self.config = config
        self.model = Ernie4_5_Model(config)
        self.lm_head = Ernie4_5_MoeLMHead(config)

        self.tie_weights()  # maybe weight share

    def get_input_embeddings(self):
        """Returns the input embeddings layer."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Sets the input embeddings layer."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Returns the output embeddings (LM head)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Sets the output embeddings layer."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Sets the ERNIE decoder model."""
        self.model = decoder

    def get_decoder(self):
        """Get the transformer decoder.

        Returns:
            nn.Module: The decoder module
        """
        return self.model

    def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder=False):
        """
        Updates model kwargs for the next generation step.

        Stores the new key/value cache and extends every per-token input that
        is present by one position: ``token_type_ids`` and ``role_ids`` repeat
        their last value, ``attention_mask`` is extended with ones, and, when
        ``rope_3d`` is enabled, ``position_ids`` gets ``max + 1`` appended.

        Args:
            outputs (Any): Model outputs.
            model_kwargs (dict): Current model kwargs; modified in place.
            is_encoder_decoder (bool): Whether using encoder-decoder architecture.

        Returns:
            dict: Updated model kwargs.
        """
        # update cache
        if isinstance(outputs, tuple) and len(outputs) > 1 and not isinstance(outputs[1], torch.Tensor):
            model_kwargs["past_key_values"] = outputs[1]

        if isinstance(outputs, CausalLMOutputWithCrossAttentions) and "past_key_values" in outputs:
            model_kwargs["past_key_values"] = outputs.past_key_values

        # update token_type_ids with last value
        if "token_type_ids" in model_kwargs and model_kwargs["token_type_ids"] is not None:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1:]], dim=-1)

        if not is_encoder_decoder and model_kwargs.get("attention_mask", None) is not None:
            # update attention mask
            attention_mask = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [
                    attention_mask,
                    torch.ones((attention_mask.shape[0], 1), dtype=torch.int64, device=attention_mask.device),
                ],
                dim=-1,
            )

        # update role_ids
        if "role_ids" in model_kwargs and model_kwargs["role_ids"] is not None:
            role_ids = model_kwargs["role_ids"]
            model_kwargs["role_ids"] = torch.cat([role_ids, role_ids[:, -1:]], dim=-1)

        # Use attribute lookup with a default, like the other optional config
        # reads in this file, rather than a dict-style `.get` on the config.
        if getattr(self.config, "rope_3d", False):
            assert "position_ids" in model_kwargs, "position_ids must be provided if rope_3d is on"
            position_ids = model_kwargs["position_ids"]

            # Max over dim 1, which is kept with size 1 so the result can be
            # concatenated along that dim.
            # NOTE(review): presumably position_ids is [batch, seq, 3] here;
            # confirm against the caller.
            max_position = position_ids.max(dim=1, keepdim=True)[0]
            new_positions = max_position + 1

            model_kwargs["position_ids"] = torch.cat(
                [position_ids, new_positions],
                dim=1
            )

        return model_kwargs


class VisionMlp(nn.Module):
    """Two-layer feed-forward network: linear -> activation -> linear."""

    def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None:
        """
        Args:
            dim (int): size of the input and output features.
            hidden_dim (int): size of the intermediate projection.
            hidden_act (str): key used to look up the activation in ``ACT2FN``.
        """
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = ACT2FN[hidden_act]
        self.fc2 = nn.Linear(hidden_dim, dim)

    def forward(self, x) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): input tensor whose last dimension is ``dim``

        Returns:
            torch.Tensor: tensor with the same last dimension as the input
        """
        expanded = self.fc1(x)
        activated = self.act(expanded)
        return self.fc2(activated)


class PatchEmbed(nn.Module):
    """Bias-free linear projection of flattened image patches to embeddings."""

    def __init__(
        self,
        patch_size: int = 14,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        """
        Args:
            patch_size (int, optional): side length of a square patch. Defaults to 14.
            in_channels (int, optional): number of channels per pixel. Defaults to 3.
            embed_dim (int, optional): size of the produced embedding. Defaults to 1152.
        """
        super().__init__()
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        # One flattened patch holds in_channels * patch_size^2 values.
        flat_patch_dim = in_channels * patch_size * patch_size
        self.proj = nn.Linear(flat_patch_dim, embed_dim, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states (torch.Tensor): flattened patches; the last dimension
                must equal ``in_channels * patch_size * patch_size``

        Returns:
            torch.Tensor: patch embeddings in the dtype of the projection weight
        """
        # Cast to the weight dtype first so the matmul never mixes dtypes.
        weight_dtype = self.proj.weight.dtype
        return self.proj(hidden_states.to(weight_dtype))


class VisionRotaryEmbedding(nn.Module):
    """Builds the table of rotary angles ``position * inverse_frequency``."""

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        """
        Args:
            dim (int): rotary dimension; ``ceil(dim / 2)`` frequencies are generated.
            theta (float, optional): base of the geometric frequency progression.
                Defaults to 10000.0.
        """
        super().__init__()
        # Stored as a plain attribute (not a buffer/parameter), so it is not
        # part of the state dict and is not moved by ``module.to(...)``.
        exponents = torch.arange(start=0, end=dim, step=2, dtype=torch.float32) / dim
        self.inv_freq = 1.0 / theta**exponents

    def forward(self, seqlen: int) -> torch.Tensor:
        """
        Args:
            seqlen (int): number of positions to generate angles for.

        Returns:
            torch.Tensor: ``[seqlen, len(inv_freq)]`` table of rotation angles
        """
        positions = torch.arange(seqlen).to(self.inv_freq.dtype)
        return torch.outer(positions, self.inv_freq)


def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    For ``x = [a, b]`` (split at ``x.shape[-1] // 2``) the result is ``[-b, a]``,
    so the output has the same shape as the input.
    """
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)


def apply_rotary_pos_emb_vision(
    tensor: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
    """Rotate ``tensor`` by the angles in ``freqs`` (rotary position embedding).

    The rotation is computed in float32 and the result is cast back to the
    dtype of the input.

    Args:
        tensor (torch.Tensor): values to rotate.
        freqs (torch.Tensor): rotation angles; their cos/sin tables are tiled
            twice along the last axis to cover both halves of ``tensor``.
    Returns:
        output (torch.Tensor): rotated tensor, same dtype as ``tensor``.
    """
    input_dtype = tensor.dtype
    tensor_fp32 = tensor.type(dtype=torch.float32)

    def _broadcastable(table: torch.Tensor) -> torch.Tensor:
        # Insert a singleton axis at position 1, duplicate the last axis so it
        # spans both rotated halves, then add a leading singleton axis.
        return table.unsqueeze(1).tile(1, 1, 2).unsqueeze(0).type(dtype=torch.float32)

    cos = _broadcastable(freqs.cos())
    sin = _broadcastable(freqs.sin())

    rotated = tensor_fp32 * cos + rotate_half(tensor_fp32) * sin
    return rotated.to(input_dtype)


class VisionAttention(nn.Module):
    """Multi-head self-attention over a packed batch of variable-length sequences.

    All sequences are concatenated along the first axis; ``cu_seqlens`` marks
    the boundaries so that tokens only attend within their own sequence.
    """

    def __init__(self, dim: int, num_heads: int = 16) -> None:
        """
        Args:
            dim (int): model width (input and output feature size).
            num_heads (int, optional): number of attention heads. Defaults to 16.
        """
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)
        # Used for the 1/sqrt(head_dim) scaling of the attention scores.
        self.head_dim = dim // num_heads

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run block-diagonal self-attention.

        Args:
            hidden_states (torch.Tensor): packed tokens of shape ``[seq, dim]``.
            cu_seqlens (torch.Tensor): 1-D cumulative sequence lengths starting
                at 0; consecutive differences give each sequence's length.
            rotary_pos_emb (Optional[torch.Tensor]): rotary angles applied to
                queries and keys. When ``None`` no rotation is applied.

        Returns:
            torch.Tensor: attention output of shape ``[seq, dim]``.
        """
        seq_length = hidden_states.shape[0]
        # [seq, 3 * dim] -> [3, seq, heads, head_dim]
        qkv = (
            self.qkv(hidden_states)
            .reshape([seq_length, 3, self.num_heads, -1])
            .permute(1, 0, 2, 3)
        )
        q, k, v = qkv.unbind(0)

        # The parameter is optional in the signature, so honour ``None``
        # instead of failing inside the rotary helper.
        if rotary_pos_emb is not None:
            q = apply_rotary_pos_emb_vision(q.unsqueeze(dim=0), rotary_pos_emb).squeeze(
                dim=0
            )
            k = apply_rotary_pos_emb_vision(k.unsqueeze(dim=0), rotary_pos_emb).squeeze(
                dim=0
            )

        # [seq, heads, head_dim] -> [heads, seq, head_dim]
        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)

        # Cut the packed sequence axis back into the individual sequences.
        lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        q_parts, k_parts, v_parts = (
            torch.split(tensor, lengths, dim=1) for tensor in (q, k, v)
        )

        scale = math.sqrt(self.head_dim)
        attn_output = []
        for q_seg, k_seg, v_seg in zip(q_parts, k_parts, v_parts):
            attn_weights = torch.matmul(q_seg, k_seg.transpose(1, 2)) / scale
            # Softmax in float32 for numerical stability, then back to the
            # working dtype.
            attn_weights = nn.functional.softmax(
                attn_weights, dim=-1, dtype=torch.float32
            ).to(q_seg.dtype)
            # [heads, len, head_dim] -> [len, heads, head_dim]
            attn_output.append(torch.matmul(attn_weights, v_seg).transpose(0, 1))

        attn_output = torch.cat(attn_output, dim=0)
        attn_output = attn_output.reshape(seq_length, -1).contiguous()
        return self.proj(attn_output)


class DFNRopeVisionBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual."""

    def __init__(self, config, attn_implementation: str = "sdpa") -> None:
        """
        Args:
            config: vision configuration; ``embed_dim``, ``mlp_ratio``,
                ``num_heads`` and ``hidden_act`` are read from it.
            attn_implementation (str, optional): accepted but not used by this
                block. Defaults to "sdpa".
        """
        super().__init__()
        width = config.embed_dim
        self.norm1 = nn.LayerNorm(width, eps=1e-6)
        self.norm2 = nn.LayerNorm(width, eps=1e-6)

        self.attn = VisionAttention(width, num_heads=config.num_heads)
        self.mlp = VisionMlp(
            dim=width,
            hidden_dim=int(width * config.mlp_ratio),
            hidden_act=config.hidden_act,
        )
        self.config = config

    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
        """
        Args:
            hidden_states(torch.Tensor): packed token features
            cu_seqlens (torch.Tensor): cumulative sequence lengths
            rotary_pos_emb: rotary position embedding forwarded to the attention

        Returns:
            torch.Tensor: features after the attention and MLP residual branches
        """
        attn_branch = self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
        )
        hidden_states = hidden_states + attn_branch

        mlp_branch = self.mlp(self.norm2(hidden_states))
        return hidden_states + mlp_branch


class DFNRopeVisionTransformerPreTrainedModel(PreTrainedModel):
    """Vision transformer with 2-D rotary position embeddings over packed patches."""

    config_class = DFNRopeVisionTransformerConfig
    _tp_plan = {}

    def __init__(self, config) -> None:
        """
        Args:
            config: vision configuration; reads ``spatial_merge_size``,
                ``patch_size``, ``in_channels``, ``embed_dim``, ``num_heads``,
                ``depth`` and ``hidden_size``.
        """
        super().__init__(config)
        self.spatial_merge_size = config.spatial_merge_size

        self.patch_embed = PatchEmbed(
            patch_size=config.patch_size,
            in_channels=config.in_channels,
            embed_dim=config.embed_dim,
        )

        # Half of each head's dimension is rotated by the row index and the
        # other half by the column index, hence ``// 2``.
        per_head_dim = config.embed_dim // config.num_heads
        self.rotary_pos_emb = VisionRotaryEmbedding(per_head_dim // 2)

        layers = [DFNRopeVisionBlock(config) for _ in range(config.depth)]
        self.blocks = nn.ModuleList(layers)

        assert (
            config.hidden_size == config.embed_dim
        ), "in DFNRope, vit's config.hidden must be equal to config.embed_dim"
        self.ln = nn.LayerNorm(config.hidden_size, eps=1e-6)

    def rot_pos_emb(self, grid_thw, num_pad=0):
        """Build the rotary angles for every patch described by ``grid_thw``.

        Args:
            grid_thw (torch.Tensor): rows of ``(t, h, w)`` patch-grid sizes
            num_pad (int): number of trailing padding tokens; they receive
                position ``(0, 0)``

        Returns:
            torch.Tensor: rotary position embedding, one row per token
        """
        merge = self.spatial_merge_size
        grid_np = np.array(grid_thw.cpu(), dtype=np.int64)

        def _window_major(grid, h, w):
            # Reorder a row-major (h, w) grid so that the patches of each
            # merge x merge window become contiguous.
            windows = grid.reshape(h // merge, merge, w // merge, merge)
            return np.transpose(windows, (0, 2, 1, 3)).flatten()

        per_item_ids = []
        for t, h, w in grid_np:
            row_idx, col_idx = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
            frame_ids = np.stack(
                [_window_major(row_idx, h, w), _window_major(col_idx, h, w)],
                axis=-1,
            )
            # Every frame of the item shares the same spatial positions.
            per_item_ids.append(np.tile(frame_ids, (t, 1)))

        pos_ids = np.concatenate(per_item_ids, axis=0)
        if num_pad > 0:
            padding_ids = np.zeros((num_pad, 2), dtype=pos_ids.dtype)
            pos_ids = np.concatenate([pos_ids, padding_ids])

        # One angle table long enough for the largest height/width, looked up
        # by (row, col) and flattened into a single vector per token.
        longest_side = np.amax(grid_np[:, 1:])
        angle_table = self.rotary_pos_emb(longest_side)
        return angle_table[pos_ids].flatten(start_dim=1)

    def forward(
        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, num_pad=0
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (torch.Tensor): flattened patches fed to the patch embedding
            grid_thw (torch.Tensor): rows of ``(t, h, w)`` patch-grid sizes
            num_pad (int): number of padding tokens appended after the real ones

        Returns:
            torch.Tensor: layer-normalised features, one row per token
        """
        hidden_states = self.patch_embed(hidden_states)

        rotary_pos_emb = self.rot_pos_emb(grid_thw, num_pad=num_pad).to(
            hidden_states.device
        )

        # Each frame (h * w tokens, repeated t times) is its own attention span.
        tokens_per_frame = grid_thw[:, 1] * grid_thw[:, 2]
        cu_seqlens = torch.repeat_interleave(tokens_per_frame, grid_thw[:, 0]).cumsum(
            dim=0, dtype=torch.int32
        )

        if num_pad > 0:
            # Leading 0 plus one extra trailing span that covers the padding.
            cu_seqlens = F.pad(cu_seqlens, (1, 1), value=0)
            cu_seqlens[-1] = cu_seqlens[-2] + num_pad
        else:
            cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for block in self.blocks:
            hidden_states = block(
                hidden_states,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
            )

        return self.ln(hidden_states)


class VariableResolutionResamplerModel(nn.Module):
    """
    Resampler that maps variable-resolution vision features to ``out_dim`` by
    merging spatial windows and, optionally, pairs of consecutive frames.
    """

    def __init__(self, in_dim, out_dim, spatial_conv_size, temporal_conv_size, config):
        """
        Args:
            in_dim (int): feature size of one incoming vision token.
            out_dim (int): feature size of the produced tokens.
            spatial_conv_size (int): side of the square window merged spatially.
            temporal_conv_size (int): number of frames merged temporally.
            config: model configuration; ``use_temporal_conv`` is read from it
                and a deep copy is handed to ``RMSNorm``.
        """
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.config = config
        self.spatial_conv_size = spatial_conv_size
        self.temporal_conv_size = temporal_conv_size
        self.use_temporal_conv = config.use_temporal_conv

        # Width after concatenating one spatial window of tokens.
        window_tokens = self.spatial_conv_size * self.spatial_conv_size
        self.spatial_dim = self.in_dim * window_tokens
        # Width after additionally concatenating the merged frames.
        self.temporal_dim = self.spatial_dim * self.temporal_conv_size

        # using unique name space start with "mm_resampler_"
        with UniqueNameGuard("mm_resampler_"):
            self.spatial_linear = nn.Sequential(
                nn.Linear(self.spatial_dim, self.spatial_dim),
                nn.GELU(),
                nn.Linear(self.spatial_dim, self.spatial_dim),
                nn.LayerNorm(self.spatial_dim, eps=1e-6),
            )

            if self.use_temporal_conv:
                self.temporal_linear = nn.Sequential(
                    nn.Linear(self.temporal_dim, self.spatial_dim),
                    nn.GELU(),
                    nn.Linear(self.spatial_dim, self.spatial_dim),
                    nn.LayerNorm(self.spatial_dim, eps=1e-6),
                )

            self.mlp = nn.Linear(self.spatial_dim, self.out_dim)

            norm_config = deepcopy(config)
            norm_config.hidden_size = out_dim
            self.after_norm = RMSNorm(norm_config)

    def spatial_conv_reshape(self, x, spatial_conv_size):
        """
        Concatenate every ``spatial_conv_size**2`` consecutive rows of the 2-D
        input into one row, imitating a strided convolution with a linear layer.
        """
        _, channels = x.shape
        return x.reshape([-1, channels * (spatial_conv_size**2)])

    def forward(self, x, image_mask, token_type_ids, image_type_ids, grid_thw):
        """
        x: image_features
        image_mask: [B]
        token_types_ids: [B]
        image_type_ids:  [B_image]
        grid_thw: [B_image, 3]

        ``image_mask`` and ``token_type_ids`` are accepted but not read here.
        """
        assert image_type_ids is not None

        def pair_frames(features, grid):
            """
            features: [S, H], tokens ordered item by item, frame by frame
            grid: [N, 3] with columns [t, h, w]

            Returns the tokens of each even frame concatenated (feature axis)
            with the tokens of the following odd frame; a single-frame item is
            paired with itself.
            """
            grid_np = grid.cpu().numpy()
            frame_counts = grid_np[:, 0]
            window_area = self.spatial_conv_size**2
            tokens_per_frame = grid_np[:, 1:].prod(-1) // window_area

            # Offset of each item's first token inside ``features``.
            tokens_per_item = grid_np.prod(-1) // window_area
            item_starts = np.empty(tokens_per_item.size, dtype=tokens_per_item.dtype)
            item_starts[0] = 0
            item_starts[1:] = tokens_per_item.cumsum()[:-1]

            assert (
                self.temporal_conv_size == 2
            ), f"Hard Code: temporal_conv_size==2, got:{self.temporal_conv_size}"

            # TODO: support any temporal conv size
            def frame_index(first_frame_of):
                # Token indices of every second frame, starting at the frame
                # chosen by ``first_frame_of`` for each item.
                pieces = [
                    np.arange(
                        start + frame * per_frame,
                        start + (frame + 1) * per_frame,
                    )
                    for n_frames, per_frame, start in zip(
                        frame_counts, tokens_per_frame, item_starts
                    )
                    for frame in range(first_frame_of(n_frames), n_frames, 2)
                ]
                return torch.tensor(np.concatenate(pieces, axis=-1)).to(
                    features.device
                )

            even_index = frame_index(lambda n_frames: 0)
            odd_index = frame_index(lambda n_frames: 1 if n_frames > 1 else 0)

            even_tokens = torch.index_select(features, dim=0, index=even_index)
            odd_tokens = torch.index_select(features, dim=0, index=odd_index)
            return torch.cat([even_tokens, odd_tokens], dim=-1)

        # Spatial merge: window concat followed by the spatial MLP.
        x = self.spatial_linear(self.spatial_conv_reshape(x, self.spatial_conv_size))

        # Optional temporal merge of frame pairs.
        if self.use_temporal_conv:
            x = self.temporal_linear(pair_frames(x, grid_thw))

        # Final projection to the language-model width plus normalisation.
        return self.after_norm(self.mlp(x))


class Ernie4_5_MoeVLHead(Ernie4_5_MoeLMHead):
    """LM head with an optional second head for the multimodal vocabulary."""

    def __init__(self, config):
        """
        Args:
            config: model configuration. When ``config.mm_vocab_size > 0`` a
                second ``Ernie4_5_MoeLMHead`` sized to that vocabulary is
                created as ``mm_head``; otherwise ``mm_head`` is ``None``.
        """
        super().__init__(config)
        self.config = config

        has_mm_vocab = config.mm_vocab_size > 0
        if not has_mm_vocab:
            self.mm_head = None
            return

        # The multimodal head reuses the config with only the vocab size swapped.
        mm_vocab_config = deepcopy(config)
        mm_vocab_config.vocab_size = config.mm_vocab_size
        assert mm_vocab_config.vocab_size > 0, mm_vocab_config
        assert (
            mm_vocab_config.im_patch_id >= mm_vocab_config.max_text_id
        ), mm_vocab_config
        self.mm_head = Ernie4_5_MoeLMHead(mm_vocab_config)

    def forward(self, hidden_state, token_type_ids_labels, use_cache=False):
        """
        Args:
            hidden_state(torch.Tensor): hidden state
            token_type_ids_labels(torch.Tensor): token type ids of the labels
            use_cache(bool): whether decoding with a cache, default is False

        Returns:
            logits_text(torch.Tensor): text logits
            logits_image(torch.Tensor): image logits (``None`` when ``use_cache``)
            None: third slot of the returned tuple is always ``None``
        """
        if use_cache:
            # TODO,support lm_head decode only
            # Only the last position is projected, with the text head.
            last_step_logits = parallel_matmul(
                hidden_state[:, -1:, :],
                self.weight,
                self.bias,
                transpose_y=self.config.tie_word_embeddings,
            )
            return last_step_logits, None, None

        if self.mm_head is not None:
            mm_head_weight = self.mm_head.weight
            mm_head_bias = self.mm_head.bias
        else:
            mm_head_weight = None
            mm_head_bias = None

        logits_text, logits_image, _ = calc_multimodal_logits(
            hidden_state,
            self.weight,
            self.bias,
            mm_head_weight,
            mm_head_bias,
            token_type_ids_labels,
            self.config,
        )
        return logits_text, logits_image, None


class Ernie4_5_VLMoeForConditionalGeneration(Ernie4_5_MoeForCausalLM):
    """Vision-language model: vision encoder + resampler feeding the MoE LM."""

    config_class = Ernie4_5_VLMoEConfig
    main_input_name = "pixel_values"
    _keep_in_fp16_modules = ["vision_model"]
    _tp_plan = {}

    def __init__(
        self, config: Ernie4_5_VLMoEConfig, vision_model=None, resampler_model=None
    ):
        """
        initialize Ernie4_5_VLMoeForConditionalGeneration

        Args:
            config(Ernie4_5_VLMoEConfig): Model configuration.
            vision_model(nn.Module): accepted but not used; the vision model is
                always built from ``config.vision_config``.
            resampler_model(nn.Module): accepted but not used; the resampler is
                always built from ``config``.
        """
        super().__init__(config)

        self.vision_model = DFNRopeVisionTransformerPreTrainedModel(
            config.vision_config
        )

        self.model.resampler_model = VariableResolutionResamplerModel(
            config.pixel_hidden_size,
            config.hidden_size,
            config.spatial_conv_size,
            config.temporal_conv_size,
            config=config,
        )

        self.image_preprocess = None
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def add_image_preprocess(self, processor):
        """Attach ``processor.image_processor`` for on-device normalisation.

        The image processor is modified in place: per-channel mean/std are
        stored as tensors repeated ``patch_size**2`` times (one value per
        element of a flattened patch) and ``rescale_factor`` becomes a tensor.
        """
        logger.info("image preprocess is set")

        image_preprocess = processor.image_processor
        image_preprocess.image_mean_tensor = torch.tensor(
            image_preprocess.image_mean, dtype=torch.float32
        ).reshape([1, 3, 1, 1])
        image_preprocess.image_std_tensor = torch.tensor(
            image_preprocess.image_std, dtype=torch.float32
        ).reshape([1, 3, 1, 1])
        image_preprocess.rescale_factor = torch.tensor(
            image_preprocess.rescale_factor, dtype=torch.float32
        )
        # [1, 3, 1, 1] -> [1, 3] -> [1, 3 * patch_size**2]
        image_preprocess.image_mean_tensor = image_preprocess.image_mean_tensor.squeeze(
            [-2, -1]
        ).repeat_interleave(self.config.vision_config.patch_size**2 * 1, -1)
        image_preprocess.image_std_tensor = image_preprocess.image_std_tensor.squeeze(
            [-2, -1]
        ).repeat_interleave(self.config.vision_config.patch_size**2 * 1, -1)

        self.image_preprocess = image_preprocess

    def vision_forward(
        self,
        images,
        image_position_ids,
        image_attention_mask,
        grid_thw,
    ):
        """Normalise the patches (if a preprocessor is attached) and encode them.

        Args:
            images (torch.Tensor): ``uint8`` patches when a preprocessor was
                attached via ``add_image_preprocess``, otherwise ``bfloat16``.
            image_position_ids: accepted but not used here.
            image_attention_mask: accepted but not used here.
            grid_thw (torch.Tensor): ``(t, h, w)`` grid sizes; may be ``None``.

        Returns:
            torch.Tensor: output of the vision model.
        """
        if self.image_preprocess is not None:
            assert images.dtype == torch.uint8, images.dtype
            current_device = images.device
            self.image_preprocess.image_mean_tensor = (
                self.image_preprocess.image_mean_tensor.to(current_device)
            )
            self.image_preprocess.image_std_tensor = (
                self.image_preprocess.image_std_tensor.to(current_device)
            )
            images = self.image_preprocess.rescale_factor * images.to(torch.float32)
            images = (
                images - self.image_preprocess.image_mean_tensor
            ) / self.image_preprocess.image_std_tensor
            images = images.to(torch.bfloat16)
        else:
            assert images.dtype == torch.bfloat16, images.dtype

        if grid_thw is not None:
            # Drop non-positive (padding) entries, then expand every (t, h, w)
            # row into t rows of (1, h, w) so each frame is encoded on its own.
            grid_thw = grid_thw[grid_thw > 0].reshape([-1, 3])
            grid_thw = F.pad(
                torch.repeat_interleave(grid_thw[:, 1:], grid_thw[:, 0], 0),
                [1, 0, 0, 0],
                value=1,
            )
        image_features = self.vision_model(images, grid_thw)
        return image_features

    def vision_mapping_forward(
        self,
        token_type_ids,
        token_type_ids_w_video,
        input_ids,
        mm_input_ids,
        image_features,
        inputs_embeds,
        image_type_ids,
        grid_thw,
    ):
        """Resample the vision features and write them into ``inputs_embeds``.

        Positions where ``input_ids == config.im_patch_id`` are overwritten in
        place with the resampled features. ``token_type_ids`` and
        ``mm_input_ids`` are accepted but not read here.

        Returns:
            torch.Tensor: ``inputs_embeds`` with the image positions filled.
        """
        image_mask = input_ids == self.config.im_patch_id
        image_features = self.model.resampler_model(
            image_features,
            image_mask,
            token_type_ids_w_video,
            image_type_ids,
            grid_thw,
        )

        # ``dim`` is a method: the previous ``image_features.dim == 2`` compared
        # the bound method with an int and was therefore never true. The
        # three-way unpack below requires a rank-3 tensor, so test for that.
        if image_features.dim() == 3:
            B, N, C = image_features.shape
            image_features = image_features.reshape([B * N, C])

        # Masked assignment needs matching dtype and device.
        image_features = image_features.to(
            device=inputs_embeds.device, dtype=inputs_embeds.dtype
        )
        # Will overwrite the part of `ids==im_patch_id` in `mm_ids_features`
        inputs_embeds[image_mask.to(inputs_embeds.device)] = image_features
        return inputs_embeds

    def prepare_inputs_for_generation(
        self,
        input_ids,
        images=None,
        use_cache=False,
        past_key_values=None,
        inputs_embeds=None,
        image_position_ids=None,
        image_attention_mask=None,
        token_type_ids=None,
        image_type_ids=None,
        grid_thw=None,
        **kwargs,
    ):
        """
        Prepare inputs for the decoder that can be used for generation.

        Args:
            input_ids (torch.Tensor): Input ids.
            images (torch.Tensor): Images. Default to None.
            use_cache (bool): Accepted for API compatibility; the returned
                inputs always set ``use_cache=True``.
            past_key_values (list): Past key values. Default to None.
            inputs_embeds (torch.Tensor): Input embeddings. Default to None.
            image_position_ids (torch.Tensor): Image position ids. Default to None.
            image_attention_mask (torch.Tensor): Image attention mask. Default to None.
            token_type_ids (torch.Tensor): Token type ids. Default to None, in
                which case ``None`` is forwarded and ``forward`` derives them.
            image_type_ids (torch.Tensor): Image type ids. Default to None.
            grid_thw (torch.Tensor): Grid thw. Default to None.

        Returns:
            dict: keyword arguments for ``forward``.
        """
        if past_key_values:
            # With a cache only the newest position has to be fed.
            input_ids = input_ids[:, -1:]
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1:]
            image_type_ids = (
                image_type_ids[:, -1:] if image_type_ids is not None else None
            )

        if self.config.use_flash_attention:
            attention_mask = None
        else:
            attention_mask = kwargs.get("attention_mask", None)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        if token_type_ids is not None:
            # ``forward`` expects one more token type than input ids (the type
            # of the token being predicted); append a text (0) slot.
            next_token_type = torch.zeros(
                [len(token_type_ids), 1],
                dtype=token_type_ids.dtype,
                device=token_type_ids.device,
            )
            padded_token_type_ids = torch.cat(
                [token_type_ids, next_token_type], dim=-1
            )
        else:
            padded_token_type_ids = None

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": True,
                "attention_mask": attention_mask,
                "images": images,
                "image_position_ids": image_position_ids,
                "image_attention_mask": image_attention_mask,
                "image_type_ids": image_type_ids,
                "token_type_ids": padded_token_type_ids,
                "grid_thw": grid_thw,
            }
        )
        if self.config.rope_3d:
            model_inputs.update({"position_ids": kwargs["position_ids"]})

        return model_inputs

    def _post_init(self, original_init, *args, **kwargs):
        """
        Label all multimodal parameters in the model, only head and Embedding
        Experts parameters are already labeled
        """
        # ``super()`` is already bound; passing ``self`` again shifted every
        # argument by one position.
        super()._post_init(original_init, *args, **kwargs)
        # ``lm_head`` is a plain ``nn.Linear`` in ``__init__`` and then has no
        # ``mm_head`` attribute, so look it up defensively.
        mm_head = getattr(self.lm_head, "mm_head", None)
        if mm_head is not None:
            mm_head.weight.expert_type = "expert_type_1"
        if getattr(mm_head, "bias", None) is not None:
            mm_head.bias.expert_type = "expert_type_1"

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        ignored_index: Optional[int] = 0,
        return_dict: Optional[bool] = None,
        image_position_ids: Optional[torch.Tensor] = None,
        image_attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        image_type_ids: Optional[torch.Tensor] = None,
        grid_thw: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """
        Forward for Ernie4_5_VLMoeForConditionalGeneration

        No loss is computed here: ``labels`` and ``ignored_index`` are accepted
        but unused and the returned ``loss`` is always ``None``. The result is
        always a ``CausalLMOutputWithCrossAttentions`` regardless of
        ``return_dict``.

        Args:
            input_ids (torch.Tensor): Input ids.
            position_ids (Optional[torch.Tensor], optional): Position ids. Defaults to None.
            attention_mask (Optional[torch.Tensor], optional): Attention mask. Defaults to None.
            past_key_values (Optional[List[torch.Tensor]], optional): Past key values. Defaults to None.
            use_cache (Optional[bool], optional): Use cache. When truthy only
                the logits of the last position are returned. Defaults to None.
            output_attentions (Optional[bool], optional): Output attentions. Defaults to None.
            output_hidden_states (Optional[bool], optional): Output hidden states. Defaults to None.
            labels (Optional[torch.Tensor], optional): Unused. Defaults to None.
            images (Optional[torch.Tensor]): Images. Defaults to None.
            ignored_index (Optional[int], optional): Unused. Defaults to 0.
            return_dict (Optional[bool], optional): Unused. Defaults to None.
            image_position_ids (Optional[torch.Tensor], optional): Image position ids. Defaults to None.
            image_attention_mask (Optional[torch.Tensor], optional): Image attention mask. Defaults to None.
            token_type_ids (Optional[torch.Tensor], optional): Token type ids with
                one more position than ``input_ids``. Video entries are
                rewritten to image in place. Defaults to None, in which case
                they are derived from the image-patch positions.
            image_type_ids (Optional[torch.Tensor], optional): Image type ids. Defaults to None.
            grid_thw (Optional[torch.Tensor], optional): Grid thw. Defaults to None.
        """
        if grid_thw is not None:
            # Drop non-positive (padding) entries and restore the [N, 3] layout.
            grid_thw = grid_thw[grid_thw > 0].reshape([-1, 3])

        image_mask = input_ids == self.config.im_patch_id

        # The vision tower only runs on the first (prefill) step; later steps
        # reuse the cached keys/values.
        if past_key_values is None:
            if images is not None:
                assert (image_mask).any().item(), (
                    image_mask.detach().cpu().numpy().tolist(),
                    input_ids.detach().cpu().numpy().tolist(),
                    self.config.im_patch_id,
                    images.shape,
                )
                image_features = self.vision_forward(
                    images,
                    image_position_ids,
                    image_attention_mask,
                    grid_thw,
                )
            else:
                image_features = None  # no more faking
        else:
            image_features = None

        if token_type_ids is None:
            token_type_ids = image_mask.to(torch.int64)
            token_type_ids_labels = torch.cat(
                [token_type_ids[:, 1:], token_type_ids[:, -1:]], 1
            )
        else:
            assert (
                token_type_ids.shape[1] == input_ids.shape[1] + 1
            ), f"token_type:{token_type_ids.shape}, ids:{input_ids.shape}"
            token_type_ids_labels = token_type_ids[..., 1:]

        mm_input_ids = input_ids.clone()

        inputs_embeds = self.model.embed_tokens(input_ids)
        # Keep a copy that still distinguishes video from image before the
        # in-place collapse of video -> image below.
        token_type_ids_w_video = token_type_ids[..., :-1].clone()
        token_type_ids[token_type_ids == TokenType.video] = TokenType.image

        if images is not None and image_features is not None:
            inputs_embeds = self.vision_mapping_forward(
                token_type_ids[..., :-1],
                token_type_ids_w_video,
                input_ids,
                mm_input_ids,
                image_features,
                inputs_embeds,
                image_type_ids,
                grid_thw,
            )

        outputs = self.model(
            position_ids=position_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        if not use_cache:
            assert outputs.last_hidden_state.shape[:2] == token_type_ids_labels.shape, (
                outputs.last_hidden_state.shape,
                token_type_ids_labels.shape,
            )
            if self.config.use_recompute_loss_fn:
                # The hidden states are returned in place of logits here.
                logits = outputs.last_hidden_state
            else:
                logits = self.lm_head(outputs.last_hidden_state)
        else:
            # Decoding: only the last position is projected.
            logits = self.lm_head(outputs.last_hidden_state[:, -1:, :])

        # aka Generate Decoding
        loss = None
        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_loss=outputs.router_loss,
        )

    @staticmethod
    def _resolve_prefix_keys(state_keys_base, state_keys_real, ignore_error=False):
        """Map each base key to a matching key of the loaded state dict.

        A base key matches a real key that ends with it; keys containing
        ``"mm_embed_tokens"`` only match each other. Every real key is used at
        most once. Both inputs are converted to sets, so when several real
        keys could match, which one is chosen depends on set iteration order.

        Args:
            state_keys_base: iterable of expected key names.
            state_keys_real: iterable of key names found in the state dict.
            ignore_error (bool): when False, an error is logged for every base
                key without a match.

        Returns:
            dict: base key -> real key, for the keys that were matched.
        """
        # state_keys_map base to real
        state_keys_map = {}

        state_keys_base = set(state_keys_base)
        state_keys_real = set(state_keys_real)

        for key in state_keys_base:
            for x in state_keys_real:
                if "mm_embed_tokens" in x:
                    if "mm_embed_tokens" in key:
                        state_keys_map[key] = x
                        break
                elif x.endswith(key):
                    state_keys_map[key] = x
                    break
            if key not in state_keys_map:
                if not ignore_error:
                    logger.error(f"could not find name {key} in loaded state dict!")
            else:
                # Consume the real key so it cannot be matched a second time.
                state_keys_real.remove(state_keys_map[key])

        return state_keys_map


@dataclass
class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
    """
    Output container for a base model forward pass: hidden states, cached
    key/values and attention outputs, plus router fields for
    mixture-of-experts models.

    Every field is optional and defaults to ``None``.

    Args:
        last_hidden_state (`torch.Tensor`, *optional*):
            Hidden states produced by the last layer.
            NOTE(review): presumably `(batch_size, sequence_length, hidden_size)` —
            confirm against the producing model.
        past_key_values (`tuple(tuple(torch.Tensor))`, *optional*):
            Cached key/value tensors.
            NOTE(review): exact nesting and shapes are defined by the attention
            layers — confirm.
        hidden_states (`tuple(torch.Tensor)`, *optional*):
            Per-layer hidden states.
            NOTE(review): presumably only populated when hidden-state output is
            requested — confirm.
        attentions (`tuple(torch.Tensor)`, *optional*):
            Per-layer attention outputs.
            NOTE(review): presumably only populated when attention output is
            requested — confirm.
        cross_attentions (`tuple(torch.Tensor)`, *optional*):
            Cross-attention outputs.
            NOTE(review): verify whether any producer actually sets this field.
        router_loss (`torch.Tensor`, *optional*):
            Router loss for mixture-of-experts models.
        gate_logits (`tuple(torch.Tensor)`, *optional*):
            Gate logits for mixture-of-experts models.
            NOTE(review): presumably one entry per MoE layer — confirm.
    """

    last_hidden_state: Optional[torch.Tensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None
    hidden_states: Optional[Tuple[torch.Tensor]] = None
    attentions: Optional[Tuple[torch.Tensor]] = None
    cross_attentions: Optional[Tuple[torch.Tensor]] = None
    router_loss: Optional[torch.Tensor] = None
    gate_logits: Optional[Tuple[torch.Tensor]] = None


@dataclass
class CausalLMOutputWithCrossAttentions(ModelOutput):
    """
    Output container for causal language model (autoregressive) forward passes.

    Every field is optional and defaults to ``None``.

    Args:
        loss (`torch.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
            NOTE(review): a producer may store only the last position, or un-projected hidden states,
            in this field — confirm the shape against the calling code.
        past_key_values (`tuple(tuple(torch.Tensor))`, *optional*):
            Cached key/value tensors passed through from the underlying model.
            NOTE(review): exact nesting and shapes are defined by the attention layers — confirm.
        hidden_states (`tuple(torch.Tensor)`, *optional*, returned when `output_hidden_states=True`
            is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.Tensor)`, *optional*, returned when `output_attentions=True` is passed or
            when `config.output_attentions=True`):
            Tuple of `torch.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads.
        router_loss (`torch.Tensor`, *optional*):
            The routing loss computed by the gating network in mixture-of-experts models.
            This is typically the load balancing loss that encourages equal expert utilization.
            None when not using mixture-of-experts routing.
    """

    loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None
    hidden_states: Optional[Tuple[torch.Tensor]] = None
    attentions: Optional[Tuple[torch.Tensor]] = None
    router_loss: Optional[torch.Tensor] = None