#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨

#           This file was automatically generated from src/transformers/models/gpt_oss/modular_gpt_oss.py.

#               Do NOT edit this file manually as any edits will be overwritten by the generation of

#             the file from the modular. If any change should be done, please apply the change to the

#                          modular_gpt_oss.py file directly. One of our CI enforces this.

#                🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨

# Copyright 2025 The HuggingFace Team. All rights reserved.

#

# Licensed under the Apache License, Version 2.0 (the "License");

# you may not use this file except in compliance with the License.

# You may obtain a copy of the License at

#

#     http://www.apache.org/licenses/LICENSE-2.0

#

# Unless required by applicable law or agreed to in writing, software

# distributed under the License is distributed on an "AS IS" BASIS,

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# See the License for the specific language governing permissions and

# limitations under the License.

# Copyright (c) 2025, HUAWEI CORPORATION.  All rights reserved.



import os

from typing import Callable, Optional, Union



import torch

try:

    import torch_npu

except ImportError:

    pass

from torch import nn

from torch.nn import functional as F



from transformers.cache_utils import Cache, DynamicCache

from transformers.generation import GenerationMixin

from transformers.integrations.hub_kernels import use_kernel_forward_from_hub

from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask

from transformers.modeling_layers import (

    GenericForSequenceClassification,

    GenericForTokenClassification,

    GradientCheckpointingLayer,

)

from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast

from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update

from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel

from transformers.processing_utils import Unpack

from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple

from transformers.utils.deprecation import deprecate_kwarg

try:

    from transformers.utils.generic import OutputRecorder, check_model_inputs

except ImportError:

    # adapt for transformers 5.x

    from transformers.utils.output_capturing import OutputRecorder, capture_outputs

    check_model_inputs = capture_outputs

from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig



from mindspeed.ops.grouped_matmul import eager_grouped_matmul, fused_grouped_matmul

from mindspeed_llm.fsdp2.utils.global_vars import get_args



try:

    from einops import rearrange

except ImportError:

    rearrange = None





@use_kernel_forward_from_hub("RMSNorm")
class GptOssRMSNorm(nn.Module):
    """Root-mean-square layer normalization (equivalent to T5LayerNorm).

    Args:
        hidden_size: size of the normalized last dimension; also the length of the learnable scale.
        eps: constant added to the mean square before taking the reciprocal square root.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalize ``hidden_states`` over its last dimension and apply the learned scale.

        With ``get_args().use_fused_rmsnorm`` the work is delegated to ``torch_npu.npu_rms_norm``
        (its first output is returned). Otherwise the statistics are computed in float32 and
        the result is cast back to the input dtype.
        """
        if get_args().use_fused_rmsnorm:
            return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]

        original_dtype = hidden_states.dtype
        x_fp32 = hidden_states.to(torch.float32)
        mean_square = x_fp32.pow(2).mean(-1, keepdim=True)
        normalized = x_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return (self.weight * normalized).to(original_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"





class GmmFunction(torch.autograd.Function):
    """Autograd wrapper around ``torch_npu.npu_grouped_matmul`` for per-group (per-expert) GEMMs.

    ``group_list`` is passed with ``group_list_type=1``. Presumably this means it holds per-group
    row counts rather than cumulative offsets; confirm against the torch_npu docs. ``group_list``
    receives no gradient.
    """

    @staticmethod
    def forward(ctx, x, weight, group_list):
        # Keep the operands for the two backward GEMMs; group_list is not a tensor we differentiate.
        ctx.save_for_backward(x, weight)
        ctx.group_list = group_list

        results = torch_npu.npu_grouped_matmul(
            [x],
            [weight],
            bias=None,
            group_list=group_list,
            split_item=2,
            group_type=0,
            group_list_type=1,
        )
        return results[0]

    @staticmethod
    def backward(ctx, grad_output):
        saved_x, saved_weight = ctx.saved_tensors
        groups = ctx.group_list

        # dL/dx: grouped GEMM of the upstream gradient with each group's transposed weight.
        transposed_weight = torch.transpose(saved_weight, 1, 2)
        grad_x = torch_npu.npu_grouped_matmul(
            [grad_output],
            [transposed_weight],
            bias=None,
            group_list=groups,
            split_item=2,
            group_type=0,
            group_list_type=1,
        )[0]

        # dL/dW: x^T @ grad_output per group (group_type=2 groups along the reduction axis).
        grad_w = torch_npu.npu_grouped_matmul(
            [saved_x.t()],
            [grad_output],
            bias=None,
            group_list=groups,
            split_item=3,
            group_type=2,
            group_list_type=1,
        )[0]

        return grad_x, grad_w, None





def npu_group_gemm(x, weight, group_list):
    """Differentiable grouped GEMM; functional entry point for ``GmmFunction``.

    Args:
        x: input rows, assumed grouped contiguously per expert (confirm with the caller).
        weight: stacked per-group weights.
        group_list: per-group sizes forwarded unchanged to ``GmmFunction``.

    Returns:
        The tensor produced by ``GmmFunction.apply``.
    """
    grouped_output = GmmFunction.apply(x, weight, group_list)
    return grouped_output





class GptOssFusedExperts(nn.Module):
    """GPT-OSS mixture-of-experts layer computed with grouped GEMMs instead of a per-expert loop.

    Holds the same parameters as ``GptOssExperts``. ``forward`` permutes tokens by expert with
    ``torch_npu`` and runs ``npu_group_gemm``. ``ep_forward`` is the expert-parallel entry point:
    it takes rows already grouped by local expert and uses ``fused_grouped_matmul``.
    """

    def __init__(self, config):
        super().__init__()
        self.intermediate_size = config.intermediate_size
        self.num_experts = config.num_local_experts
        self.hidden_size = config.hidden_size
        self.expert_dim = self.intermediate_size

        # Stacked per-expert weights/biases, allocated uninitialized (torch.empty); values must be
        # loaded from a checkpoint or set by explicit initialization.
        # The last dim of gate_up_proj holds gate and up columns interleaved (see forward).
        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
        self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim))
        self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
        self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size))

        # Clamped-SwiGLU constants: sigmoid slope (alpha) and clamp bound (limit).
        self.alpha = 1.702
        self.limit = 7.0

    def forward(self, hidden_states, router_indices, routing_weights):
        """Apply the routed experts to every token.

        Args:
            hidden_states: tensor unpacked as (batch_size, seq_len, hidden_size).
            router_indices: selected expert ids per token. Presumably (num_tokens, top_k) from
                GptOssTopKRouter; confirm.
            routing_weights: per-selection weights consumed by ``npu_moe_token_unpermute``.
                Presumably the (num_tokens, top_k) softmaxed top-k scores; confirm.

        Returns:
            Expert outputs combined per token, viewed as (batch_size, -1, hidden_size).
        """
        batch_size, seq_len, _ = hidden_states.shape

        hidden_states = hidden_states.view(-1, self.hidden_size)

        # Reorder token copies so that rows routed to the same expert are contiguous (presumably
        # sorted by expert id, matching the tokens_per_expert grouping below); row_ids_map is
        # kept to undo the permutation afterwards.
        permuted_hidden_states, row_ids_map = torch_npu.npu_moe_token_permute(
            hidden_states,
            router_indices.to(torch.int32)
        )

        # Rows per expert. histc needs a float input; with bins=num_experts spanning
        # [0, num_experts - 1], each integer expert id lands in its own bin.
        tokens_per_expert = torch.histc(
            router_indices.float(),
            bins=self.num_experts,
            min=0,
            max=self.num_experts - 1
        ).long()

        intermediate_gemm = npu_group_gemm(
            permuted_hidden_states,
            self.gate_up_proj,
            tokens_per_expert
        )

        # Give each row its expert's bias by repeating bias rows by the per-expert row count.
        expanded_gate_up_bias = self.gate_up_proj_bias.repeat_interleave(tokens_per_expert, dim=0)

        gate_up = intermediate_gemm + expanded_gate_up_bias

        # Gate/up columns are interleaved (gate, up, gate, up, ...); viewing the last dim as
        # (expert_dim, 2) separates them, like gate_up[..., ::2] / [..., 1::2] in GptOssExperts.
        gate_up_view = gate_up.view(gate_up.shape[0], self.expert_dim, 2)

        gate = gate_up_view[:, :, 0]
        up = gate_up_view[:, :, 1]

        # Clamped SwiGLU: gate is clamped from above only, up on both sides.
        gate = gate.clamp(min=None, max=self.limit)
        up = up.clamp(min=-self.limit, max=self.limit)

        glu = gate * torch.sigmoid(gate * self.alpha)

        intermediate_activations = (up + 1.0) * glu

        output_gemm = npu_group_gemm(
            intermediate_activations,
            self.down_proj,
            tokens_per_expert
        )

        expanded_down_bias = self.down_proj_bias.repeat_interleave(tokens_per_expert, dim=0)
        output = output_gemm + expanded_down_bias

        # Restore the original token order and combine each token's expert outputs with routing_weights.
        next_states = torch_npu.npu_moe_token_unpermute(output, row_ids_map, routing_weights)

        return next_states.view(batch_size, -1, self.hidden_size)

    def ep_forward(self, hidden_states, split_list):
        """Expert-parallel path: apply the local experts to rows already grouped by expert.

        Args:
            hidden_states: 2-D rows, assumed contiguous per local expert (confirm with the dispatcher).
            split_list: number of rows per local expert, as a Python list or a tensor.

        Returns:
            Per-row expert outputs. No routing weights are applied here.
        """

        # Unwrap sharded parameters (anything exposing ``to_local``, e.g. DTensor) to local tensors.
        def as_local(t):
            return t.to_local() if hasattr(t, "to_local") else t

        gate_up_weights = as_local(self.gate_up_proj)

        down_weights = as_local(self.down_proj)

        gate_up_bias = as_local(self.gate_up_proj_bias)

        down_bias = as_local(self.down_proj_bias)

        # repeat_interleave needs a tensor of repeats. The biases are always registered
        # parameters in __init__, so this branch effectively always runs.
        split_list_tensor = None

        if gate_up_bias is not None or down_bias is not None:

            if isinstance(split_list, list):

                split_list_tensor = torch.tensor(split_list, device=hidden_states.device)

            else:

                split_list_tensor = split_list

        gate_up_out = fused_grouped_matmul(hidden_states, split_list, gate_up_weights)

        # Add Bias

        if gate_up_bias is not None:

            expanded_bias = gate_up_bias.repeat_interleave(split_list_tensor, dim=0)

            gate_up_out = gate_up_out + expanded_bias

        # --- Activation (clamped SwiGLU; gate/up columns interleaved) ---

        gate_up_view = gate_up_out.view(gate_up_out.shape[0], self.expert_dim, 2)

        gate = gate_up_view[:, :, 0]

        up = gate_up_view[:, :, 1]

        gate = gate.clamp(min=None, max=self.limit)

        up = up.clamp(min=-self.limit, max=self.limit)

        glu = gate * torch.sigmoid(gate * self.alpha)

        intermediate_activations = (up + 1.0) * glu

        output = fused_grouped_matmul(intermediate_activations, split_list, down_weights)

        # Add Bias

        if down_bias is not None:

            expanded_down_bias = down_bias.repeat_interleave(split_list_tensor, dim=0)

            output = output + expanded_down_bias

        return output





class GptOssExperts(nn.Module):
    """Reference (non-fused) GPT-OSS mixture-of-experts layer with clamped-SwiGLU experts.

    Parameters are stacked per expert and allocated uninitialized (torch.empty). ``forward`` uses
    dense routing weights; ``ep_forward`` is the expert-parallel path using ``eager_grouped_matmul``.
    """

    def __init__(self, config):
        super().__init__()
        self.intermediate_size = config.intermediate_size
        self.num_experts = config.num_local_experts
        self.hidden_size = config.hidden_size
        self.expert_dim = self.intermediate_size
        # Last dim of gate_up_proj holds gate and up columns interleaved (split with ::2 / 1::2).
        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
        self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.expert_dim))
        self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
        self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size))
        # Clamped-SwiGLU constants: sigmoid slope (alpha) and clamp bound (limit).
        self.alpha = 1.702
        self.limit = 7.0

    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
        """Apply the experts to every token and combine them with dense routing weights.

        On CPU or in training mode, the layer loops only over experts that received at least one
        token, because evaluating every expert on every token would use too much memory. Otherwise
        (inference on an accelerator), every expert runs on every token at once via a batched
        matmul over repeated inputs, trading memory for speed.

        Args:
            hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size)
            router_indices (torch.Tensor): (batch_size * seq_len, top_k) selected expert ids;
                an id equal to num_experts is treated as a mask value and skipped.
            routing_weights (torch.Tensor): (batch_size * seq_len, num_experts) dense weights,
                zero for unselected experts.

        Returns:
            torch.Tensor: (batch_size, seq_len, hidden_size)
        """
        batch_size = hidden_states.shape[0]
        hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)
        num_experts = routing_weights.shape[1]
        if hidden_states.device.type == "cpu" or self.training:
            next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
            with torch.no_grad():
                expert_mask = torch.nn.functional.one_hot(
                    router_indices, num_classes=num_experts + 1
                )  # masking is also a class
                # -> (num_experts + 1, top_k, num_tokens)
                expert_mask = expert_mask.permute(2, 1, 0)
                # we sum on the top_k and on the sequence length to get which experts
                # are hit this time around
                expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
            for expert_idx in expert_hit[:]:
                # expert_idx has exactly one element, so index it out as a scalar for fast indexing
                expert_idx = expert_idx[0]
                # skip the masking class
                if expert_idx == num_experts:
                    continue
                with torch.no_grad():
                    # token rows that selected this expert (in any top-k slot)
                    _, token_idx = torch.where(expert_mask[expert_idx])
                current_state = hidden_states[token_idx]
                gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
                gate, up = gate_up[..., ::2], gate_up[..., 1::2]
                gate = gate.clamp(min=None, max=self.limit)
                up = up.clamp(min=-self.limit, max=self.limit)
                glu = gate * torch.sigmoid(gate * self.alpha)
                gated_output = (up + 1) * glu
                out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
                weighted_output = out * routing_weights[token_idx, expert_idx, None]
                # scatter-add this expert's weighted contribution back into its tokens' rows
                next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
            next_states = next_states.view(batch_size, -1, self.hidden_size)
        else:
            # Run all experts on all tokens: (num_experts, num_tokens, hidden_size)
            hidden_states = hidden_states.repeat(num_experts, 1)
            hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
            gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
            gate = gate.clamp(min=None, max=self.limit)
            up = up.clamp(min=-self.limit, max=self.limit)
            glu = gate * torch.sigmoid(gate * self.alpha)
            next_states = torch.bmm(((up + 1) * glu), self.down_proj)
            next_states = next_states + self.down_proj_bias[..., None, :]
            next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
            # Weight each expert's output by its (dense, possibly zero) routing weight, then sum experts.
            next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
            next_states = next_states.sum(dim=0)
        return next_states

    def ep_forward(self, hidden_states, split_list):
        """Expert-parallel path: apply the local experts to rows already grouped by expert.

        Args:
            hidden_states: 2-D rows, assumed contiguous per local expert (confirm with the dispatcher).
            split_list: number of rows per local expert, as a Python list or a tensor.

        Returns:
            Per-row expert outputs. No routing weights are applied here.
        """

        # Unwrap sharded parameters (anything exposing ``to_local``, e.g. DTensor) to local tensors.
        def as_local(t):
            return t.to_local() if hasattr(t, "to_local") else t

        gate_up_weights = as_local(self.gate_up_proj)

        down_weights = as_local(self.down_proj)

        gate_up_bias = as_local(self.gate_up_proj_bias)

        down_bias = as_local(self.down_proj_bias)

        # repeat_interleave needs a tensor of repeats. The biases are always registered
        # parameters in __init__, so this branch effectively always runs.
        split_list_tensor = None

        if gate_up_bias is not None or down_bias is not None:

            if isinstance(split_list, list):

                split_list_tensor = torch.tensor(split_list, device=hidden_states.device)

            else:

                split_list_tensor = split_list

        gate_up_out = eager_grouped_matmul(hidden_states, split_list, gate_up_weights)

        # Add Bias

        if gate_up_bias is not None:

            expanded_bias = gate_up_bias.repeat_interleave(split_list_tensor, dim=0)

            gate_up_out = gate_up_out + expanded_bias

        # --- Activation (clamped SwiGLU; gate/up columns interleaved) ---

        gate_up_view = gate_up_out.view(gate_up_out.shape[0], self.expert_dim, 2)

        gate = gate_up_view[:, :, 0]

        up = gate_up_view[:, :, 1]

        gate = gate.clamp(min=None, max=self.limit)

        up = up.clamp(min=-self.limit, max=self.limit)

        glu = gate * torch.sigmoid(gate * self.alpha)

        intermediate_activations = (up + 1.0) * glu

        output = eager_grouped_matmul(intermediate_activations, split_list, down_weights)

        # Add Bias

        if down_bias is not None:

            expanded_down_bias = down_bias.repeat_interleave(split_list_tensor, dim=0)

            output = output + expanded_down_bias

        return output





class GptOssTopKRouter(nn.Module):
    """Linear router that selects the top-k experts per token and softmaxes their logits.

    ``forward`` returns ``(router_scores, router_indices)``. With ``get_args().moe_grouped_gemm``,
    ``router_scores`` is the compact (num_tokens, top_k) softmax over the selected logits.
    Otherwise it is a dense (num_tokens, num_experts) tensor that is zero outside the selected experts.
    """

    def __init__(self, config):
        super().__init__()
        self.top_k = config.num_experts_per_tok
        self.num_experts = config.num_local_experts
        self.hidden_dim = config.hidden_size
        self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim))
        self.bias = nn.Parameter(torch.empty(self.num_experts))

    def forward(self, hidden_states):
        flat_tokens = hidden_states.reshape(-1, self.hidden_dim)
        logits = F.linear(flat_tokens, self.weight, self.bias)  # (num_tokens, num_experts)
        top_logits, router_indices = torch.topk(logits, self.top_k, dim=-1)  # (num_tokens, top_k)
        top_probs = torch.nn.functional.softmax(top_logits, dim=1, dtype=top_logits.dtype)

        if get_args().moe_grouped_gemm:
            # Grouped-GEMM experts consume the compact top-k weights directly.
            return top_probs, router_indices

        # Scatter the top-k probabilities into a dense per-expert score matrix.
        dense_scores = torch.zeros_like(logits).scatter_(1, router_indices, top_probs)
        return dense_scores, router_indices





@use_kernel_forward_from_hub("MegaBlocksMoeMLP")
class GptOssMLP(nn.Module):
    """Mixture-of-experts feed-forward block: a top-k router followed by the expert bank.

    The expert implementation is chosen once at construction. ``GptOssFusedExperts`` is used when
    ``moe_grouped_gemm`` is set or ``ep_dispatcher == 'fused'``; ``GptOssExperts`` is used otherwise.
    """

    def __init__(self, config):
        super().__init__()
        self.router = GptOssTopKRouter(config)
        args = get_args()
        wants_fused = args.moe_grouped_gemm or args.ep_dispatcher == 'fused'
        self.experts = GptOssFusedExperts(config) if wants_fused else GptOssExperts(config)

    def forward(self, hidden_states):
        """Return ``(expert_output, router_scores)``.

        ``router_scores`` is whatever the router produced, with one row per token. It holds either
        compact top-k weights or a dense per-expert matrix, depending on ``get_args().moe_grouped_gemm``.
        """
        scores, indices = self.router(hidden_states)
        expert_output = self.experts(hidden_states, indices, scores)
        return expert_output, scores





class GptOssRotaryEmbedding(nn.Module):
    """Compute RoPE cos/sin tables for the given position ids.

    The table width depends on ``get_args().use_fused_rotary_pos_emb``. When set, the frequencies
    are duplicated along the last dimension (the layout consumed by ``torch_npu.npu_rotary_mul``).
    Otherwise they are used once, for the split-halves PyTorch rotation.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: GptOssConfig, device=None):
        super().__init__()
        # BC: older configs store the RoPE variant under "type" instead of "rope_type".
        rope_scaling = config.rope_scaling if hasattr(config, "rope_scaling") else None
        if isinstance(rope_scaling, dict):
            self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` in ``x.dtype``, scaled by ``attention_scaling``."""
        batch = position_ids.shape[0]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        positions = position_ids[:, None, :].float()

        # autocast is disabled so the angle computation stays in float32; MPS/unknown devices use "cpu".
        supports_autocast_type = isinstance(x.device.type, str) and x.device.type != "mps"
        autocast_device = x.device.type if supports_autocast_type else "cpu"
        with torch.autocast(device_type=autocast_device, enabled=False):  # Force float32
            angles = (inv_freq_expanded.float() @ positions.float()).transpose(1, 2)
            if get_args().use_fused_rotary_pos_emb:
                angles = torch.cat((angles, angles), dim=-1)
            cos = angles.cos() * self.attention_scaling
            sin = angles.sin() * self.attention_scaling

        return cos.to(x.dtype), sin.to(x.dtype)





def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along the head axis.

    Equivalent to ``torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep)``. Maps
    (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). The input is returned as-is when
    ``n_rep == 1``.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # expand creates a broadcast view; reshape then materializes the repeated heads.
    widened = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return widened.reshape(batch, kv_heads * n_rep, seq_len, head_dim)





def _apply_rotary_emb(

        x: torch.Tensor,

        cos: torch.Tensor,

        sin: torch.Tensor,

) -> torch.Tensor:

    first_half, second_half = torch.chunk(x, 2, dim=-1)

    first_ = first_half * cos - second_half * sin

    second_ = second_half * cos + first_half * sin

    return torch.cat((first_, second_), dim=-1)





def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embeddings to query and key tensors.

    Args:
        q: query tensor.
        k: key tensor.
        cos: cosine table from ``GptOssRotaryEmbedding``.
        sin: sine table from ``GptOssRotaryEmbedding``.
        position_ids: unused; kept for signature compatibility.
        unsqueeze_dim: axis at which a singleton dimension is inserted into ``cos``/``sin`` so
            they broadcast against ``q``/``k``.

    Returns:
        ``(q_embed, k_embed)``. The rotation uses ``torch_npu.npu_rotary_mul`` when
        ``get_args().use_fused_rotary_pos_emb`` is set, otherwise ``_apply_rotary_emb``.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    rotate = torch_npu.npu_rotary_mul if get_args().use_fused_rotary_pos_emb else _apply_rotary_emb
    return rotate(q, cos, sin), rotate(k, cos, sin)





def flash_attention_forward(

        module: nn.Module,

        query: torch.Tensor,

        key: torch.Tensor,

        value: torch.Tensor,

        attention_mask: Optional[torch.Tensor],

        scaling: float,

        dropout: float = 0.0,

        sliding_window: int = None,

        **kwargs,

):



    pre_tokens = 1048576

    next_tokens = 0

    bsz, n_head, seq_length, head_dim = (

        query.shape[0], query.shape[1], query.shape[2], query.shape[3])



    sparse_mode = 4

    shape_order = "BNSD"



    if sliding_window:

        pre_tokens = sliding_window



    # When sparse_mode is 2 or 4, a compressed mask of [2048, 2048] should be passed.

    new_mask = torch.ones((2048, 2048), device=torch.accelerator.current_device(), dtype=torch.bool)

    atten_mask = torch.triu(new_mask, diagonal=1)



    attn_output = torch_npu.npu_fusion_attention_v2(

        query, key, value,

        n_head,

        shape_order,

        pse=None,

        sparse_mode=sparse_mode,

        sink=module.sinks.float(),

        atten_mask=atten_mask,

        scale=scaling,

        pre_tokens=pre_tokens,

        next_tokens=next_tokens,

        keep_prob=1 - dropout,

    )[0]



    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, None





def eager_attention_forward(
        module: nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        scaling: float,
        dropout: float = 0.0,
        **kwargs,
):
    """Reference (non-fused) attention with per-head learned sinks.

    One sink logit per head (``module.sinks``) is appended as an extra key column before the
    softmax and dropped afterwards. It absorbs probability mass without contributing a value.

    Returns:
        ``(attn_output, attn_weights)``. ``attn_output`` is transposed to
        (batch, seq, heads, head_dim); ``attn_weights`` are the post-dropout probabilities over
        the real keys.
    """
    n_groups = module.num_key_value_groups
    keys = repeat_kv(key, n_groups)
    values = repeat_kv(value, n_groups)

    logits = torch.matmul(query, keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        logits = logits + attention_mask[:, :, :, : keys.shape[-2]]

    batch, q_len = query.shape[0], query.shape[-2]
    sink_column = module.sinks.reshape(1, -1, 1, 1).expand(batch, -1, q_len, -1)
    combined = torch.cat([logits, sink_column], dim=-1)

    # Not part of the original implementation; it may slightly affect results. Subtracting the
    # row max prevents overflow in BF16/FP16 when training with bsz > 1.
    combined = combined - combined.max(dim=-1, keepdim=True).values
    probs = F.softmax(combined, dim=-1, dtype=combined.dtype)

    # Drop the sink column before dropout and the value projection.
    weights = nn.functional.dropout(probs[..., :-1], p=dropout, training=module.training)
    output = torch.matmul(weights, values).transpose(1, 2).contiguous()
    return output, weights





class GptOssAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper, with GPT-OSS extras.

    Extras: grouped key/value heads, an optional per-layer sliding window (from
    ``config.layer_types``), and a learned per-head ``sinks`` logit.
    """

    def __init__(self, config: GptOssConfig, layer_idx: int):
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim ** -0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        q_width = config.num_attention_heads * self.head_dim
        kv_width = config.num_key_value_heads * self.head_dim
        use_bias = config.attention_bias
        self.q_proj = nn.Linear(config.hidden_size, q_width, bias=use_bias)
        self.k_proj = nn.Linear(config.hidden_size, kv_width, bias=use_bias)
        self.v_proj = nn.Linear(config.hidden_size, kv_width, bias=use_bias)
        self.o_proj = nn.Linear(q_width, config.hidden_size, bias=use_bias)

        is_sliding_layer = config.layer_types[layer_idx] == "sliding_attention"
        self.sliding_window = config.sliding_window if is_sliding_layer else None
        self.sinks = nn.Parameter(torch.empty(config.num_attention_heads))

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
            self,
            hidden_states: torch.Tensor,
            position_embeddings: tuple[torch.Tensor, torch.Tensor],
            attention_mask: Optional[torch.Tensor],
            past_key_values: Optional[Cache] = None,
            cache_position: Optional[torch.LongTensor] = None,
            **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Project to q/k/v, apply RoPE, update the cache and run attention.

        The NPU fused kernel (``flash_attention_forward``) is used when ``get_args().use_flash_attn``
        is set. Otherwise the implementation named by ``config._attn_implementation`` is used
        (``eager_attention_forward`` for "eager").

        Returns:
            ``(attn_output, attn_weights)`` after the output projection.
        """
        input_shape = hidden_states.shape[:-1]
        per_head_shape = (*input_shape, -1, self.head_dim)

        def project(linear):
            # (..., seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            return linear(hidden_states).view(per_head_shape).transpose(1, 2)

        query_states = project(self.q_proj)
        key_states = project(self.k_proj)
        value_states = project(self.v_proj)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            key_states, value_states = past_key_values.update(
                key_states, value_states, self.layer_idx, {"cache_position": cache_position}
            )

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        if get_args().use_flash_attn:
            # The NPU fused kernel takes precedence over the configured implementation.
            attention_interface = flash_attention_forward

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            s_aux=self.sinks,  # diff with Llama
            **kwargs,
        )
        merged_heads = attn_output.reshape(*input_shape, -1).contiguous()
        return self.o_proj(merged_heads), attn_weights





class GptOssDecoderLayer(GradientCheckpointingLayer):
    """Pre-norm transformer block.

    The block runs RMSNorm -> self-attention -> residual add, then RMSNorm -> MoE MLP -> residual add.
    """

    def __init__(self, config: GptOssConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = GptOssAttention(config=config, layer_idx=layer_idx)
        self.mlp = GptOssMLP(config)
        self.input_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attention_type = config.layer_types[layer_idx]

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[Cache] = None,
            use_cache: Optional[bool] = False,
            cache_position: Optional[torch.LongTensor] = None,
            position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
            **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Return the updated hidden states.

        Attention weights and the MLP's router scores are discarded.
        """
        attn_residual = hidden_states
        attn_out, _ = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = attn_residual + attn_out

        mlp_residual = hidden_states
        mlp_out, _ = self.mlp(self.post_attention_layernorm(hidden_states))  # diff with llama: router scores
        return mlp_residual + mlp_out





@auto_docstring
class GptOssPreTrainedModel(PreTrainedModel):
    config: GptOssConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["GptOssDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = False
    _supports_flex_attn = True

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # NOTE(review): index=0 records the first element returned by GptOssTopKRouter.forward, which
    # is the softmaxed router_scores, not the raw logits, despite the "router_logits" key.
    _can_record_outputs = {
        "router_logits": OutputRecorder(GptOssTopKRouter, index=0),
        "hidden_states": GptOssDecoderLayer,
        "attentions": GptOssAttention,
    }
    _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"]
    # NOTE(review): these names differ from `_supports_flash_attn` / `_supports_flex_attn` above
    # and contradict their values; confirm which flags transformers actually reads.
    _supports_flash_attention = False
    _supports_flex_attention = False

    def _init_weights(self, module):
        """Initialize the parameters of one submodule from N(0, config.initializer_range).

        - Linear, Embedding, expert and router weights and the attention sinks are drawn from
          the normal distribution.
        - Router biases are also drawn from the normal distribution.
        - Linear and expert biases are zeroed.
        - An Embedding's padding row is zeroed.
        - RMSNorm scales are set to 1.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Parameter):
            # NOTE(review): module-wise init (presumably nn.Module.apply) never passes a Parameter
            # here; kept only for callers that pass one directly.
            module.data.normal_(mean=0.0, std=std)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            # Keep the padding row at zero, as nn.Embedding does at construction.
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, GptOssRMSNorm):
            module.weight.data.fill_(1.0)
        elif isinstance(module, (GptOssExperts, GptOssFusedExperts)):
            # Both expert implementations allocate their parameters with torch.empty.
            # GptOssFusedExperts is not a subclass of GptOssExperts, so it must be listed
            # explicitly or its weights stay uninitialized.
            module.gate_up_proj.data.normal_(mean=0.0, std=std)
            module.gate_up_proj_bias.data.zero_()
            module.down_proj.data.normal_(mean=0.0, std=std)
            module.down_proj_bias.data.zero_()
        elif isinstance(module, GptOssAttention):
            module.sinks.data.normal_(mean=0.0, std=std)
        elif isinstance(module, GptOssTopKRouter):
            module.weight.data.normal_(mean=0.0, std=std)
            module.bias.data.normal_(mean=0.0, std=std)





@auto_docstring
class GptOssModel(GptOssPreTrainedModel):
    # Bare GPT-OSS decoder stack: token embeddings -> decoder layers -> final RMSNorm.
    # Returns the last hidden state plus the (possibly newly created) KV cache; no LM head.
    _no_split_modules = ["GptOssDecoderLayer"]

    def __init__(self, config: GptOssConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Submodules are registered in a fixed order: embeddings, layers, norm, rotary embedding.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx)
        self.layers = nn.ModuleList(
            GptOssDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)
        )
        self.norm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = GptOssRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Weight initialisation and final processing are delegated to the base class.
        self.post_init()

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeModelOutputWithPast:
        # Exactly one input source is accepted: "neither" and "both" are both errors.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        # A cache is created lazily only when caching is requested and none was handed in.
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        inputs_embeds = self.embed_tokens(input_ids) if inputs_embeds is None else inputs_embeds

        # Absolute positions of the new tokens continue from whatever the cache already holds.
        if cache_position is None:
            offset = 0 if past_key_values is None else past_key_values.get_seq_length()
            cache_position = torch.arange(
                offset, offset + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Callers such as `generate` may already pass a {attention_type: mask} mapping.
        causal_mask_mapping = attention_mask
        if not isinstance(causal_mask_mapping, dict):
            mask_kwargs = dict(
                config=self.config,
                input_embeds=inputs_embeds,
                attention_mask=attention_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
            )
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
            }

        hidden_states = inputs_embeds
        # Rotary embeddings are computed once and shared by every layer.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for layer in self.layers:
            # Each layer picks the mask matching its own attention type (full vs sliding).
            layer_mask = causal_mask_mapping[layer.attention_type]
            hidden_states = layer(
                hidden_states,
                attention_mask=layer_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        return MoeModelOutputWithPast(
            last_hidden_state=self.norm(hidden_states),
            past_key_values=past_key_values,
        )





def load_balancing_loss_func(
    gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
    num_experts: Optional[int] = None,
    top_k: int = 2,
    attention_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, int]:
    r"""
    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.

    See Switch Transformer (https://huggingface.co/papers/2101.03961) for more details. This function implements the loss
    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
    experts is too unbalanced.

    Args:
        gate_logits:
            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
            shape [batch_size X sequence_length, num_experts]. Anything other than a non-empty tuple
            (``None``, a bare tensor, ``()``) yields a loss of ``0``.
        num_experts:
            Number of experts. When ``None`` it is inferred from the last dimension of the gate logits.
        top_k:
            The number of experts to route per-token, can be also interpreted as the `top-k` routing
            parameter.
        attention_mask (`torch.Tensor`, *optional*):
            The attention_mask used in forward function
            shape [batch_size X sequence_length] if not None. Positions where the mask is 0 are
            excluded from both the token-fraction and the router-probability averages.

    Returns:
        The auxiliary loss as a scalar tensor, or the integer ``0`` when there is nothing to score.
    """
    # Only a non-empty tuple of per-layer logits can be scored (an empty tuple used to raise IndexError).
    if not isinstance(gate_logits, tuple) or len(gate_logits) == 0:
        return 0

    # Stack every layer's logits on the first layer's device along dim 0.
    compute_device = gate_logits[0].device
    concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)

    # `num_experts` defaults to None; one_hot needs a concrete class count, so infer it.
    if num_experts is None:
        num_experts = concatenated_gate_logits.shape[-1]

    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)

    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

    # Shape [tokens, top_k, num_experts]: 1 where a token was routed to an expert.
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

    if attention_mask is None:
        # Compute the percentage of tokens routed to each experts
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.mean(routing_weights, dim=0)
    else:
        batch_size, sequence_length = attention_mask.shape
        # Logits of all layers were concatenated, so the row count is a multiple of batch*seq.
        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)

        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
        expert_attention_mask = (
            attention_mask[None, :, :, None, None]
            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
            .reshape(-1, top_k, num_experts)
            .to(compute_device)
        )

        # Compute the percentage of tokens routed to each experts
        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
            expert_attention_mask, dim=0
        )

        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
        router_per_expert_attention_mask = (
            attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(compute_device)
        )

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
            router_per_expert_attention_mask, dim=0
        )

    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return overall_loss * num_experts





@auto_docstring
class GptOssForCausalLM(GptOssPreTrainedModel, GenerationMixin):
    # Causal-LM head on top of `GptOssModel`: projects hidden states to vocabulary logits and,
    # when router logits are requested, computes the router load-balancing auxiliary loss
    # (added to the LM loss when labels are given).
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = GptOssModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Settings consumed by the auxiliary load-balancing loss in `forward`.
        self.router_aux_loss_coef = config.router_aux_loss_coef
        self.num_experts = config.num_local_experts
        self.num_experts_per_tok = config.num_experts_per_tok

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> MoeCausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, GptOssForCausalLM

        >>> model = GptOssForCausalLM.from_pretrained("openai/gpt-oss-20b")
        >>> tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> text = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        ```"""

        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )

        # Run the backbone; `output_router_logits` is forwarded so router logits are collected when requested.
        outputs: MoeModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_router_logits=output_router_logits,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        aux_loss = None
        if output_router_logits:
            aux_loss = load_balancing_loss_func(
                outputs.router_logits,
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
            # `load_balancing_loss_func` returns the plain int 0 when there are no router logits to
            # score; only a tensor supports `.to(device)` (and adding 0 would be a no-op anyway).
            if labels is not None and isinstance(aux_loss, torch.Tensor):
                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )





class GptOssForSequenceClassification(GenericForSequenceClassification, GptOssPreTrainedModel):
    # Sequence-classification model: all head logic comes from GenericForSequenceClassification;
    # GptOssPreTrainedModel supplies the GPT-OSS config type, flags and weight initialisation.
    pass





class GptOssForTokenClassification(GenericForTokenClassification, GptOssPreTrainedModel):
    # Token-classification model: all head logic comes from GenericForTokenClassification;
    # GptOssPreTrainedModel supplies the GPT-OSS config type, flags and weight initialisation.
    pass





# Public API of this module: the names exported by `from ... import *`.
__all__ = [
    "GptOssForCausalLM",
    "GptOssForSequenceClassification",
    "GptOssForTokenClassification",
    "GptOssModel",
    "GptOssPreTrainedModel",
]