import torch
from torch import nn
from transformers.activations import ACT2FN
from mindspeed_mm.fsdp.ops.moe_ops.gemm import grouped_matmul
from mindspeed_mm.fsdp.ops.moe_ops.permute import permute
from mindspeed_mm.fsdp.ops.moe_ops.unpermute import unpermute
from mindspeed_mm.fsdp.ops.swiglu import swiglu
from . import modeling_qwen3_omni_moe
class Qwen3OmniMoeThinkerTextExpertsGemm(nn.ModuleList):
"""
ModuleList of experts.
"""
def __init__(self, config):
super().__init__()
self.config = config
self.num_experts = config.num_experts
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size \
if config.moe_intermediate_size is None else config.moe_intermediate_size
self.expert_dim = self.intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.expert_dim, self.hidden_size))
self.act_fn = ACT2FN[config.hidden_act]
def forward(
self, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor
) -> torch.Tensor:
"""
Args:
hidden_states: (batch_size * sequence_length, hidden_dim)
selected_experts: (batch_size * sequence_length, top_k)
routing_weights: (batch_size * sequence_length, top_k)
Returns:
(batch_size * sequence_length, hidden_dim)
"""
batch_size = hidden_states.shape[0]
hidden_states = hidden_states.reshape(-1, self.hidden_size)
permuted_hidden_states, row_ids_map = permute(hidden_states, top_k_index.to(torch.int32), fused=True)
tokens_per_expert = torch.histc(top_k_index, bins=self.num_experts, min=0, max=self.num_experts)
intermediate_hidden_states = grouped_matmul(permuted_hidden_states, self.gate_up_proj, tokens_per_expert, fused=True)
intermediate_activations = swiglu(intermediate_hidden_states, dim=-1, fused=True)
output = grouped_matmul(intermediate_activations, self.down_proj, tokens_per_expert, fused=True)
next_states = unpermute(output, row_ids_map, probs=top_k_weights, fused=True)
next_states = next_states.view(batch_size, -1, self.hidden_size)
return next_states
def apply_qwen3_omni_moe_npu_patch():
import torch.nn.functional as F
from mindspeed_mm.fsdp.ops.npu_patch import npu_fused_operator
modeling_qwen3_omni_moe.apply_rotary_pos_emb_vision = npu_fused_operator.apply_transformers_vision_rope_half_npu
modeling_qwen3_omni_moe.apply_rotary_pos_emb = npu_fused_operator.apply_transformers_rope_half_npu
modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextRMSNorm.forward = npu_fused_operator.rms_norm_forward_npu
F.gelu = npu_fused_operator.apply_gelu_npu