from importlib.metadata import version as get_version
import torch
import torch_npu
from torch_npu import npu_rotary_mul as apply_rotary_emb
from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
from transformers.utils import logging
from transformers.activations import ACT2FN
import verl.third_party.vllm as vllm_sleep_level
from torch import nn
if get_version("transformers") > "4.57.1":
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from transformers.models.qwen3 import modeling_qwen3
from transformers.models.qwen3_moe import modeling_qwen3_moe
from transformers.models.qwen3_vl import modeling_qwen3_vl
from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextExperts, \
Qwen3VLMoeTextSparseMoeBlock
else:
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
class GmmFunction_vl(torch.autograd.Function):
    """Differentiable grouped matmul for MoE experts on Ascend NPU.

    Wraps ``torch_npu.npu_grouped_matmul``. The rows of ``x`` are split into
    consecutive groups whose sizes are given by ``group_list``
    (``group_list_type=1`` = per-group counts), and group ``g`` is multiplied by
    ``weight[g]``.

    Expected shapes (TODO confirm against callers): ``x`` is
    ``(num_tokens, in_features)``, ``weight`` is ``(num_groups, in_features,
    out_features)`` and the output is ``(num_tokens, out_features)``.
    """

    @staticmethod
    def forward(ctx, x, weight, group_list):
        ctx.save_for_backward(x, weight)
        # group_list is never differentiated, so it is kept on ctx rather than saved.
        ctx.group_list = group_list
        return torch_npu.npu_grouped_matmul(
            [x], [weight], bias=None, group_list=group_list,
            split_item=2, group_type=0, group_list_type=1,
        )[0]

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        group_list = ctx.group_list
        needs_x_grad, needs_weight_grad = ctx.needs_input_grad[:2]
        grad_input = grad_weight = None
        # Skip the GEMMs whose result autograd would discard (e.g. frozen expert weights).
        if needs_x_grad:
            # dL/dx for group g is grad_output[g] @ weight[g]^T, grouped like the forward.
            grad_input = torch_npu.npu_grouped_matmul(
                [grad_output], [torch.transpose(weight, 1, 2)], bias=None, group_list=group_list,
                split_item=2, group_type=0, group_list_type=1,
            )[0]
        if needs_weight_grad:
            # dL/dW[g] is x[g]^T @ grad_output[g]; group_type=2 groups along the shared
            # (token) K dimension per the Ascend op docs.
            grad_weight = torch_npu.npu_grouped_matmul(
                [x.T], [grad_output], bias=None, group_list=group_list,
                split_item=3, group_type=2, group_list_type=1,
            )[0]
        return grad_input, grad_weight, None
def npu_group_gemm(x, weight, group_list):
    """Run the differentiable NPU grouped GEMM (``GmmFunction_vl``) on ``x``.

    ``group_list`` holds the per-group row counts forwarded unchanged to the
    autograd function.
    """
    return GmmFunction_vl.apply(x, weight, group_list)
class Qwen3VLMoeTextExperts_npu(nn.Module):
    """Qwen3-VL-MoE expert bank with an Ascend NPU grouped-GEMM training path.

    Parameters (allocated with ``torch.empty``; values are expected to come from
    weight init / checkpoint loading elsewhere - TODO confirm):
      * ``gate_up_proj``: ``(num_experts, hidden_size, 2 * moe_intermediate_size)``;
        the first half of the last dim is the gate, the second half the up projection.
      * ``down_proj``: ``(num_experts, moe_intermediate_size, hidden_size)``.
    """

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_experts
        self.intermediate_size = config.moe_intermediate_size
        self.hidden_size = config.hidden_size
        self.expert_dim = self.intermediate_size
        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
        self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size)))
        self.act_fn = ACT2FN[config.hidden_act]
        # npu_swiglu hard-codes SiLU(gate) * up, so it is only valid for a SiLU activation;
        # any other activation takes the explicit chunk + act_fn path in training.
        self._fused_swiglu = config.hidden_act in ("silu", "swish")

    def forward(
        self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
    ) -> torch.Tensor:
        """Send tokens through their experts and combine the weighted expert outputs.

        Training uses NPU token permutation plus grouped GEMMs, so only the selected
        experts run. Inference runs every expert on every token (batched ``bmm``) and
        weights the results with a dense routing matrix.

        Args:
            hidden_states (torch.Tensor): typically ``(batch_size, seq_len, hidden_size)``;
                it is flattened to ``(num_tokens, hidden_size)`` and its leading dim is
                kept as the leading dim of the output.
            routing_weights (torch.Tensor): in training, the top-k weights
                ``(num_tokens, top_k)``; in eval, dense weights
                ``(num_tokens, num_experts)`` (zero for unselected experts).
            router_indices (torch.Tensor): ``(num_tokens, top_k)`` selected expert ids.
        Returns:
            torch.Tensor: ``(batch_size, -1, hidden_size)``.
        """
        batch_size = hidden_states.shape[0]
        hidden_states = hidden_states.reshape(-1, self.hidden_size)
        if self.training:
            return self._forward_grouped(hidden_states, routing_weights, router_indices, batch_size)
        return self._forward_dense(hidden_states, routing_weights, batch_size)

    def _forward_grouped(self, hidden_states, routing_weights, router_indices, batch_size):
        """Training path: permute tokens by expert, grouped GEMMs, weighted unpermute."""
        permuted, row_ids_map = torch_npu.npu_moe_token_permute(hidden_states, router_indices.to(torch.int32))
        # Per-expert token counts: num_experts bins over [0, num_experts] -> bin e is [e, e + 1).
        tokens_per_expert = torch.histc(router_indices, bins=self.num_experts, min=0, max=self.num_experts)
        gate_up = npu_group_gemm(permuted, self.gate_up_proj, tokens_per_expert)
        if self._fused_swiglu:
            activated = torch_npu.npu_swiglu(gate_up, dim=-1)
        else:
            gate, up = gate_up.chunk(2, dim=-1)
            activated = up * self.act_fn(gate)
        expert_out = npu_group_gemm(activated, self.down_proj, tokens_per_expert)
        combined = torch_npu.npu_moe_token_unpermute(expert_out, row_ids_map, probs=routing_weights)
        return combined.view(batch_size, -1, self.hidden_size)

    def _forward_dense(self, hidden_states, routing_weights, batch_size):
        """Inference path: every expert processes every token; weight and sum over experts."""
        stacked = hidden_states.repeat(self.num_experts, 1).view(self.num_experts, -1, self.hidden_size)
        gate_up = torch.bmm(stacked, self.gate_up_proj)
        gate, up = gate_up.chunk(2, dim=-1)
        next_states = torch.bmm((up * self.act_fn(gate)), self.down_proj)
        next_states = next_states.reshape(self.num_experts, batch_size, -1, self.hidden_size)
        # (num_tokens, num_experts) -> (num_experts, batch_size, seq, 1) to broadcast over hidden.
        weights = routing_weights.transpose(0, 1).view(self.num_experts, batch_size, -1)[..., None]
        return (next_states * weights).sum(dim=0)
class Qwen3VLMoeTextSparseMoeBlock(nn.Module):
    """Top-k softmax router in front of ``Qwen3VLMoeTextExperts_npu``.

    ``apply_npu_plugin`` installs this class in place of the transformers block
    of the same name. The expert bank receives the top-k weights
    ``(num_tokens, top_k)`` in training mode and a dense
    ``(num_tokens, num_experts)`` matrix (zeros outside the top-k) otherwise.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        # Router: one logit per expert for every token.
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.experts = Qwen3VLMoeTextExperts_npu(config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route each token to its top-k experts and return the combined expert output.

        Only the expert output is returned (router logits are not).
        """
        leading_dim = hidden_states.shape[0]
        tokens = hidden_states.reshape(-1, self.hidden_size)
        logits = self.gate(tokens)
        # Softmax in float32, keep the top-k, renormalise them to sum to 1, cast back.
        probs = torch.nn.functional.softmax(logits, dim=-1, dtype=torch.float)
        topk_probs, topk_ids = probs.topk(self.top_k, dim=-1)
        topk_probs = (topk_probs / topk_probs.sum(dim=-1, keepdim=True)).to(logits.dtype)
        if self.training:
            expert_weights = topk_probs
        else:
            # The dense inference path expects one weight per expert.
            expert_weights = torch.zeros_like(logits).scatter_(1, topk_ids, topk_probs)
        batched = tokens.reshape(leading_dim, -1, self.hidden_size)
        return self.experts(batched, expert_weights, topk_ids)
def rms_norm_forward(self, x):
    """NPU replacement for an RMSNorm ``forward``, built on ``torch_npu.npu_rms_norm``.

    ``x`` is first cast to the dtype of ``self.weight``, so the result presumably
    follows the weight dtype rather than the incoming activation dtype.
    ``self.variance_epsilon`` is passed as the op's epsilon. Only element 0 of
    the op's output tuple (the normalised tensor, per the Ascend docs) is returned.
    """
    weight = self.weight
    normalized_input = x if x.dtype == weight.dtype else x.to(weight.dtype)
    fused_outputs = torch_npu.npu_rms_norm(normalized_input, weight, epsilon=self.variance_epsilon)
    return fused_outputs[0]
def apply_rotary_pos_emb_qwen3_npu(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embeddings to ``q`` and ``k`` with ``torch_npu.npu_rotary_mul``.

    ``cos``/``sin`` gain a singleton axis at ``unsqueeze_dim`` so they broadcast
    against ``q``/``k``. ``position_ids`` is accepted only for signature
    compatibility and is ignored. Each result is cast back to its input's dtype.

    Returns:
        tuple: ``(q_embed, k_embed)``.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    return tuple(torch_npu.npu_rotary_mul(t, cos_b, sin_b).to(t.dtype) for t in (q, k))
def silu_forward(self, hidden_state):
    """NPU replacement for a gated MLP ``forward``.

    Computes ``down_proj(SiLU(gate_proj(x)) * up_proj(x))`` by concatenating the
    gate and up projections and applying the fused ``torch_npu.npu_swiglu``. The
    module's own activation is not consulted, so this assumes a SiLU MLP.
    """
    gate = self.gate_proj(hidden_state)
    up = self.up_proj(hidden_state)
    fused = torch.cat([gate, up], dim=-1)
    activated = torch_npu.npu_swiglu(fused, dim=-1)
    return self.down_proj(activated)
def apply_npu_plugin():
    """Install this module's NPU monkey-patches.

    Steps, in order:
      1. set ``verl.third_party.vllm.VLLM_SLEEP_LEVEL`` to 1;
      2. replace ``TensorDictBase._sync_all`` with a version that synchronizes
         whatever device ``torch._utils._get_available_device_type`` reports.
         CUDA keeps a guard: it only synchronizes when not compiling and CUDA is
         already initialized;
      3. when transformers is newer than 4.57.1, swap in the NPU MoE block,
         RMSNorm, rotary embedding and MLP implementations for Qwen3-VL and
         Qwen3-VL-MoE.

    Rebinding ``Qwen3VLMoeTextSparseMoeBlock`` on the transformers module
    presumably only affects models constructed after this call.
    """
    vllm_sleep_level.VLLM_SLEEP_LEVEL = 1

    from tensordict.base import TensorDictBase

    def _sync_all_patch(self):
        from torch._utils import _get_available_device_type, _get_device_module
        try:
            from torch.compiler import is_compiling
        except ImportError:  # presumably older torch only exposes it under torch._dynamo
            from torch._dynamo import is_compiling

        backend = _get_available_device_type()
        if backend is None:
            return
        if backend != "cuda":
            _get_device_module(backend).synchronize()
            return
        # CUDA: skip while compiling, and when CUDA has not been initialized yet.
        if is_compiling() or not torch.cuda.is_initialized():
            return
        torch.cuda.synchronize()

    TensorDictBase._sync_all = _sync_all_patch

    # NOTE(review): this is a plain string comparison, not a version comparison
    # (e.g. "4.100.0" > "4.57.1" is False). It must stay identical to the guard on
    # the conditional imports at the top of the module, otherwise the modeling
    # modules below may be undefined; fix both together (e.g. packaging Version).
    if not (get_version("transformers") > "4.57.1"):
        return
    modeling_qwen3_vl_moe.Qwen3VLMoeTextSparseMoeBlock = Qwen3VLMoeTextSparseMoeBlock
    modeling_qwen3_vl_moe.Qwen3VLMoeTextRMSNorm.forward = rms_norm_forward
    modeling_qwen3_vl_moe.apply_rotary_pos_emb = apply_rotary_pos_emb_qwen3_npu
    modeling_qwen3_vl.Qwen3VLTextRMSNorm.forward = rms_norm_forward
    modeling_qwen3_vl.Qwen3VLTextMLP.forward = silu_forward