from typing import Callable, Optional, Union

import torch
import torch_npu
import torch.nn as nn
from torch.distributed.tensor import DTensor
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache
from transformers.generation import GenerationMixin
from transformers.modeling_utils import PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring
from transformers.utils.generic import OutputRecorder, check_model_inputs
from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig
from megatron.training import get_args
from megatron.core import mpu

from mindspeed_mm.models.common.gmm import npu_group_gemm
from mindspeed_mm.models.common.communications import split_forward_gather_backward_with_cp
from mindspeed_mm.models.common.fused_moe import fused_ep_forward
from .output import Qwen3VLMoeCausalLMOutputWithPast
from .modules import (
    Qwen3VLTextAttention,
    Qwen3VLTextRMSNorm,
    Qwen3VLTextMLP,
    Qwen3VLTextRotaryEmbedding,
    Qwen3VLLMHead,
    Qwen3VLEmptyModule,
)
from .modeling_qwen3_vl import (
    Qwen3VLTextModel,
    Qwen3VLModel,
    Qwen3VLVisionModel,
    Qwen3VLForConditionalGeneration,
)
class Qwen3VLMoeTextExperts(nn.Module):
def __init__(self, config):
super().__init__()
self.num_experts = config.num_experts
self.intermediate_size = config.moe_intermediate_size
self.hidden_size = config.hidden_size
self.expert_dim = self.intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts * self.hidden_size, 2 * self.expert_dim))
self.down_proj = nn.Parameter(torch.empty((self.num_experts * self.expert_dim, self.hidden_size)))
self.act_fn = ACT2FN[config.hidden_act]
def _view_experts_weight(self):
gate_up_proj = self.gate_up_proj.to_local() if isinstance(self.gate_up_proj, DTensor) else self.gate_up_proj
gate_up_proj = gate_up_proj.view(-1, self.hidden_size, 2 * self.expert_dim)
down_proj = self.down_proj.to_local() if isinstance(self.down_proj, DTensor) else self.down_proj
down_proj = down_proj.view(-1, self.expert_dim, self.hidden_size)
return gate_up_proj, down_proj
def forward(
self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor, router_logits: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
When training it is more efficient to just loop over the experts and compute the output for each expert
as otherwise the memory would explode.
For inference we can sacrifice some memory and compute the output for all experts at once. By repeating the inputs.
Args:
hidden_states (torch.Tensor): (batch_size * token_num, hidden_size)
routing_weights (torch.Tensor): (batch_size * token_num, num_experts)
router_indices (torch.Tensor): (batch_size * token_num, top_k)
Returns:
torch.Tensor
"""
gate_up_proj, down_proj = self._view_experts_weight()
if router_logits is not None:
routing_weights = torch.zeros_like(router_logits).scatter_(1, router_indices, routing_weights)
batch_size = hidden_states.shape[0]
hidden_states = hidden_states.reshape(-1, self.hidden_size)
if self.training:
next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit[:]:
with torch.no_grad():
_, token_idx = torch.where(expert_mask[expert_idx[0]])
current_state = hidden_states[token_idx]
gate_up = current_state @ gate_up_proj[expert_idx]
gate, up = gate_up.chunk(2, dim=-1)
gated_output = up * self.act_fn(gate)
out = gated_output @ down_proj[expert_idx]
weighted_output = out[0] * routing_weights[token_idx, expert_idx, None]
next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
next_states = next_states.view(batch_size, -1, self.hidden_size)
else:
hidden_states = hidden_states.repeat(self.num_experts, 1)
hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
gate_up = torch.bmm(hidden_states, gate_up_proj)
gate, up = gate_up.chunk(2, dim=-1)
next_states = torch.bmm((up * self.act_fn(gate)), down_proj)
next_states = next_states.reshape(self.num_experts, batch_size, -1, self.hidden_size)
next_states = (
next_states * routing_weights.transpose(0, 1).view(self.num_experts, batch_size, -1)[..., None]
)
next_states = next_states.sum(dim=0)
return next_states
@staticmethod
def ep_forward(ep_group, self, hidden_states, routing_weights, router_indices, *args, **kwargs):
raise NotImplementedError("must set `use_npu_fused_moe=True` when enable expert parallelism.")
class Qwen3VLNpuFusedMoETextExperts(Qwen3VLMoeTextExperts):
    """Expert bank that runs the MoE computation with fused Ascend NPU kernels.

    Shares parameter layout and ``_view_experts_weight`` with the parent class; only the
    computation path differs. Also provides ``ep_forward`` for expert parallelism.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        routing_weights: torch.Tensor,
        router_indices: torch.Tensor,
        router_logits: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Route tokens through the experts with NPU permute / grouped-GEMM / unpermute kernels.

        ``routing_weights`` and ``router_indices`` are the per-slot top-k values; ``router_logits``
        is accepted for interface compatibility and is not used. The result is viewed as
        ``(batch_size, -1, hidden_size)`` where ``batch_size = hidden_states.shape[0]``.
        """
        fc1_weight, fc2_weight = self._view_experts_weight()
        batch_size = hidden_states.shape[0]
        flat_tokens = hidden_states.reshape(-1, self.hidden_size)

        # NOTE(review): presumably reorders token copies so each expert's tokens are contiguous.
        permuted_tokens, row_ids_map = torch_npu.npu_moe_token_permute(
            flat_tokens, router_indices.to(torch.int32)
        )
        # Unit-width bins over [0, num_experts]: bin e counts the slots routed to expert e.
        tokens_per_expert = torch.histc(router_indices, bins=self.num_experts, min=0, max=self.num_experts)

        fc1_out = npu_group_gemm(permuted_tokens, fc1_weight, tokens_per_expert)
        activated = torch_npu.npu_swiglu(fc1_out, dim=-1)
        fc2_out = npu_group_gemm(activated, fc2_weight, tokens_per_expert)

        combined = torch_npu.npu_moe_token_unpermute(fc2_out, row_ids_map, probs=routing_weights)
        return combined.view(batch_size, -1, self.hidden_size)

    @staticmethod
    def ep_forward(ep_group, self, hidden_states, routing_weights, router_indices, *args, **kwargs):
        # Expert-parallel path: ``self`` is passed explicitly because this is a staticmethod.
        fc1_weight, fc2_weight = self._view_experts_weight()
        batch_size = hidden_states.shape[0]
        flat_tokens = hidden_states.reshape(-1, self.hidden_size)
        expert_out = fused_ep_forward(
            self.num_experts,
            routing_weights,
            router_indices,
            flat_tokens,
            fc1_weight=fc1_weight,
            fc2_weight=fc2_weight,
            ep_group=ep_group,
        )
        return expert_out.view(batch_size, -1, self.hidden_size)
class Qwen3VLMoeTextSparseMoeBlock(nn.Module):
    """Top-k softmax router followed by an expert bank.

    The expert implementation is the fused NPU variant unless the
    ``use_npu_fused_moe`` option of ``get_args().mm.model.text_decoder`` is set to False
    (the option defaults to True when absent).
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.use_npu_fused_moe = getattr(get_args().mm.model.text_decoder, "use_npu_fused_moe", True)
        experts_cls = Qwen3VLNpuFusedMoETextExperts if self.use_npu_fused_moe else Qwen3VLMoeTextExperts
        self.experts = experts_cls(config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route each token to its ``top_k`` experts and return the combined expert output.

        Routing weights are the softmax probabilities (computed in float32) of the selected
        experts, renormalised to sum to 1 per token and cast back to the input dtype.
        """
        batch_size = hidden_states.shape[0]
        flat_states = hidden_states.reshape(-1, self.hidden_size)

        router_logits = self.gate(flat_states)
        probs = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float)
        topk_weights, topk_indices = torch.topk(probs, self.top_k, dim=-1)
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        topk_weights = topk_weights.to(flat_states.dtype)

        batched_states = flat_states.reshape(batch_size, -1, self.hidden_size)
        return self.experts(batched_states, topk_weights, topk_indices, router_logits)
class Qwen3VLMoeTextDecoderLayer(nn.Module):
    """Pre-norm decoder layer: self-attention then an MLP, each wrapped in a residual connection.

    The MLP is a :class:`Qwen3VLMoeTextSparseMoeBlock` when ``layer_idx`` is not in
    ``config.mlp_only_layers``, ``config.num_experts > 0`` and ``layer_idx + 1`` is a multiple of
    ``config.decoder_sparse_step``; otherwise it is a dense :class:`Qwen3VLTextMLP`.
    """

    def __init__(self, config: Qwen3VLMoeTextConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.self_attn = Qwen3VLTextAttention(config, layer_idx)
        if (layer_idx not in config.mlp_only_layers) and (
            config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
        ):
            self.mlp = Qwen3VLMoeTextSparseMoeBlock(config)
        else:
            self.mlp = Qwen3VLTextMLP(config, intermediate_size=config.intermediate_size)
        self.input_layernorm = Qwen3VLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3VLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.hidden_size = config.hidden_size

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Run attention and MLP sub-blocks, each as ``x + f(norm(x))``, and return the new hidden states."""
        # ``synchronize_per_layer`` is a project-specific config flag; treat it as off when the
        # config does not define it instead of failing with AttributeError.
        if getattr(self.config, "synchronize_per_layer", False):
            # Block the host until queued work on the current NPU stream has finished.
            # NOTE(review): presumably a debugging / profiling aid.
            torch.npu.current_stream().synchronize()

        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
@auto_docstring
class Qwen3VLMoePreTrainedModel(PreTrainedModel):
    # Shared PreTrainedModel base for the Qwen3-VL-MoE models in this file: declares HF
    # capability flags and adds initialisation for the fused expert parameters.
    config: Qwen3VLMoeConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen3VLMoeTextDecoderLayer", "Qwen3VLVisionBlock"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = False
    _supports_attention_backend = True
    # Outputs that can be captured by hooks: router logits from every ``mlp.gate`` Linear,
    # hidden states from each decoder layer and attention outputs.
    _can_record_outputs = {
        "router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0),
        "hidden_states": Qwen3VLMoeTextDecoderLayer,
        "attentions": Qwen3VLTextAttention,
    }

    def _init_weights(self, module):
        """Initialize ``module`` via the parent, then draw fused expert weights from N(0, std)."""
        super()._init_weights(module)
        # Prefer a top-level initializer_range; otherwise fall back to the text sub-config (default 0.02).
        std = (
            self.config.initializer_range
            if hasattr(self.config, "initializer_range")
            else getattr(self.config.get_text_config(), "initializer_range", 0.02)
        )
        if isinstance(module, Qwen3VLMoeTextExperts):
            for expert_weight in (module.gate_up_proj, module.down_proj):
                expert_weight.data.normal_(mean=0.0, std=std)
@auto_docstring(
    custom_intro=(
        "Text part of Qwen3VLMoe, "
        "not a pure text-only model, as DeepStack integrates visual features into the early hidden states."
    )
)
class Qwen3VLMoeTextModel(Qwen3VLMoePreTrainedModel, Qwen3VLTextModel):
    # Text decoder stack built from MoE-capable decoder layers. Only the PreTrainedModel
    # initializer is invoked below (Qwen3VLTextModel.__init__ is not called); no forward is
    # defined here, so it is resolved through the MRO (presumably Qwen3VLTextModel.forward).
    config: Qwen3VLMoeTextConfig
    _no_split_modules = ["Qwen3VLMoeTextDecoderLayer"]

    def __init__(self, config: Qwen3VLMoeTextConfig):
        Qwen3VLMoePreTrainedModel.__init__(self, config)
        self.config = config
        self.padding_idx, self.vocab_size = config.pad_token_id, config.vocab_size
        # Submodule assignment order determines parameter / state_dict ordering; keep it stable.
        self.embed_tokens = nn.Embedding(self.vocab_size, config.hidden_size, padding_idx=self.padding_idx)
        # NOTE(review): looks like a placeholder module used as a hook attachment point — confirm.
        self.norm_hook_module = Qwen3VLEmptyModule()
        self.layers = nn.ModuleList(
            Qwen3VLMoeTextDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)
        )
        self.norm = Qwen3VLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3VLTextRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        if config.activation_offload:
            # Separate NPU stream, presumably for overlapping activation swap with compute.
            self.swap_stream = torch.npu.Stream()
        self.post_init()
@auto_docstring
class Qwen3VLMoeModel(Qwen3VLMoePreTrainedModel, Qwen3VLModel):
    # Vision encoder plus MoE text decoder; behaviour beyond construction comes from the
    # base classes (no methods other than __init__ are defined here).
    base_model_prefix = ""
    _checkpoint_conversion_mapping = {}
    accepts_loss_kwargs = False
    config: Qwen3VLMoeConfig
    _no_split_modules = ["Qwen3VLMoeTextDecoderLayer", "Qwen3VLVisionBlock"]

    def __init__(self, config):
        Qwen3VLMoePreTrainedModel.__init__(self, config)
        vision_config, text_config = config.vision_config, config.text_config
        self.visual = Qwen3VLVisionModel._from_config(vision_config)
        self.language_model = Qwen3VLMoeTextModel._from_config(text_config)
        # Starts unset; NOTE(review): presumably filled in by the inherited forward — confirm.
        self.rope_deltas = None
        self.post_init()
def _all_reduce_sum_over_cp(tensor: torch.Tensor) -> torch.Tensor:
    """Sum ``tensor`` in place across the Megatron context-parallel group and return it."""
    torch.distributed.all_reduce(
        tensor,
        op=torch.distributed.ReduceOp.SUM,
        group=mpu.get_context_parallel_group(),
    )
    return tensor


def load_balancing_loss_func(
    gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
    num_experts: Optional[int] = None,
    top_k=2,
    attention_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, int]:
    r"""
    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.

    Args:
        gate_logits:
            Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
            shape [batch_size X sequence_length, num_experts]. Anything that is not a non-empty
            tuple yields a loss of 0.
        num_experts:
            Number of experts. When None it is inferred from the last dimension of the logits.
        top_k:
            The number of experts to route per-token, can be also interpreted as the `top-k` routing
            parameter.
        attention_mask (`torch.Tensor`, *optional*):
            The attention_mask used in forward function
            shape [batch_size X sequence_length] if not None. When given, the token counts used as
            denominators (and the per-expert token fractions) are summed across the
            context-parallel group.

    Returns:
        The auxiliary loss.
    """
    if not isinstance(gate_logits, tuple) or len(gate_logits) == 0:
        return 0

    compute_device = gate_logits[0].device
    concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
    if num_experts is None:
        # one_hot and the final scaling both need a concrete expert count.
        num_experts = concatenated_gate_logits.shape[-1]

    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

    if attention_mask is None:
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
        router_prob_per_expert = torch.mean(routing_weights, dim=0)
    else:
        batch_size, sequence_length = attention_mask.shape
        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)

        # Broadcast the mask to (num_layers * batch * seq, top_k, num_experts).
        expert_attention_mask = (
            attention_mask[None, :, :, None, None]
            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
            .reshape(-1, top_k, num_experts)
            .to(compute_device)
        )
        # Global (context-parallel) count of valid tokens; the numerator stays local and is
        # then summed across ranks, giving the global fraction of tokens per expert.
        sum_expert_attention_mask = _all_reduce_sum_over_cp(torch.sum(expert_attention_mask, dim=0))
        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / sum_expert_attention_mask
        tokens_per_expert = _all_reduce_sum_over_cp(tokens_per_expert)

        # Broadcast the mask to (num_layers * batch * seq, num_experts).
        router_per_expert_attention_mask = (
            attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(compute_device)
        )
        sum_router_per_expert_attention_mask = _all_reduce_sum_over_cp(
            torch.sum(router_per_expert_attention_mask, dim=0)
        )
        # Deliberately NOT all-reduced: local numerator over the global denominator, so this rank's
        # loss is a partial term. NOTE(review): presumably because all_reduce is not autograd-aware
        # and gradients must flow through the local routing weights — confirm downstream summation.
        router_prob_per_expert = (
            torch.sum(routing_weights * router_per_expert_attention_mask, dim=0)
            / sum_router_per_expert_attention_mask
        )

    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return overall_loss * num_experts
class Qwen3VLMoeForConditionalGeneration(Qwen3VLMoePreTrainedModel, Qwen3VLForConditionalGeneration):
    """Qwen3-VL-MoE vision-language model with a language-modelling head.

    Wraps :class:`Qwen3VLMoeModel` and a :class:`Qwen3VLLMHead`; when router logits are
    requested it also reports the MoE load-balancing auxiliary loss.
    """

    _checkpoint_conversion_mapping = {}
    _tied_weights_keys = ["lm_head.weight"]
    accepts_loss_kwargs = False
    config: Qwen3VLMoeConfig

    def __init__(self, config):
        # Initialise through the MoE PreTrainedModel and GenerationMixin directly rather than
        # through Qwen3VLForConditionalGeneration.__init__, then build the MoE backbone and head.
        Qwen3VLMoePreTrainedModel.__init__(self, config)
        GenerationMixin.__init__(self)
        self.model = Qwen3VLMoeModel(config)
        self.lm_head = Qwen3VLLMHead(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
        self.post_init()

    @check_model_inputs
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        loss_ctx: Optional[Callable] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, Qwen3VLMoeCausalLMOutputWithPast]:
        """Run the multimodal backbone, project to vocabulary logits and compute losses.

        Args:
            logits_to_keep: int ``k`` keeps the last ``k`` positions (0 keeps all); a tensor is
                used directly as the index along the sequence dimension.
            loss_ctx: optional callable forwarded to ``self.lm_head`` as ``loss_ctx`` when given.
                NOTE(review): its exact contract is defined by ``Qwen3VLLMHead`` — confirm there.
            labels: when given, the loss is recomputed with ``self.loss_function`` and replaces
                whatever loss the head returned.
            **kwargs: forwarded to the backbone; ``output_router_logits=True`` additionally
                computes the load-balancing ``aux_loss``.

        Returns:
            Qwen3VLMoeCausalLMOutputWithPast with ``loss``, ``aux_loss``, ``logits``,
            ``past_key_values`` and ``rope_deltas``. ``aux_loss`` is returned separately and is
            not added to ``loss`` here.
        """
        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = outputs[0]
        # logits_to_keep == 0 gives slice(0, None), i.e. every position is kept.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        kept_hidden_states = hidden_states[:, slice_indices, :]
        # The head returns (logits, loss); only forward loss_ctx when one was actually supplied.
        if loss_ctx is not None:
            logits, loss = self.lm_head(kept_hidden_states, loss_ctx=loss_ctx)
        else:
            logits, loss = self.lm_head(kept_hidden_states)
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)

        aux_loss = None
        if kwargs.get("output_router_logits", False):
            if attention_mask is not None:
                # NOTE(review): presumably the router logits cover only this rank's
                # context-parallel sequence chunk, so the mask is split along dim 1 to match.
                attention_mask = split_forward_gather_backward_with_cp(attention_mask, dim=1)
            aux_loss = load_balancing_loss_func(
                outputs.router_logits,
                self.config.text_config.num_experts,
                self.config.text_config.num_experts_per_tok,
                attention_mask,
            )
        return Qwen3VLMoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            rope_deltas=outputs.rope_deltas,
        )