from collections.abc import Callable
from typing import Optional
import torch
import torch_npu
from torch import nn
import torch.nn.functional as F
from torch.distributed.tensor import DTensor
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.processing_utils import Unpack
from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import Qwen3OmniMoeVisionEncoderConfig
from megatron.core import mpu
from megatron.training import get_args
from ..attention_utils import ALL_ATTENTION_FUNCTIONS, pad_out
from ..cp_utils import get_seq_len, gather_seq_scatter_heads_qkv, gather_heads_scatter_seq
class Qwen3OmniMoeAudioAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config):
super().__init__()
self.embed_dim = config.d_model
self.num_heads = config.encoder_attention_heads
self.dropout = config.attention_dropout
self.head_dim = self.embed_dim // self.num_heads
self.num_key_value_groups = 1
self.config = config
if (self.head_dim * self.num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {self.num_heads})."
)
self.scaling = self.head_dim**-0.5
self.attention_dropout = 0.0
self.is_decoder = False
self.is_causal = False
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
seq_length = hidden_states.size(0)
total_audio_seqlen = int(cu_seqlens[-1])
query_states = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
attention_kwargs = {
"scale": self.scaling,
"dropout": 0.0 if not self.training else self.attention_dropout,
"is_causal": self.is_causal,
"layout": self.config.attn_layout,
}
seq_dim, head_dim = 0, 1
if self.config._attn_implementation == "flash_attention_2" and self.config.attn_layout == "TND":
attention_kwargs["actual_seq_qlen"] = cu_seqlens
attention_kwargs["actual_seq_kvlen"] = cu_seqlens
elif (
self.config._attn_implementation in ["eager", "sdpa", "flash_attention_2"]
and self.config.attn_layout == "BNSD"
):
query_states = query_states.transpose(0, 1).unsqueeze(0)
key_states = key_states.transpose(0, 1).unsqueeze(0)
value_states = value_states.transpose(0, 1).unsqueeze(0)
attention_kwargs["attention_mask"] = attention_mask
seq_dim, head_dim = 2, 1
else:
raise NotImplementedError(
f"Unsupported Attention: {self.config._attn_implementation}, or layout: {self.config.attn_layout}"
"Qwen3OmniMoeAudioAttention only support ['eager', 'sdpa', 'flash_attention_2'], layout TND and BNSD")
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
if mpu.get_context_parallel_world_size() > 1 and mpu.get_context_parallel_world_size() < total_audio_seqlen:
megatron_args = get_args()
if megatron_args.context_parallel_algo == "ulysses_cp_algo":
query_states, key_states, value_states = gather_seq_scatter_heads_qkv(
query_states,
key_states,
value_states,
seq_dim=seq_dim,
head_dim=head_dim,
gather_size=total_audio_seqlen
)
else:
raise NotImplementedError(f"Only support `ulysses_cp_algo`, but got {megatron_args.context_parallel_algo}")
attn_output = attention_interface(
query_states,
key_states,
value_states,
**attention_kwargs,
)
if mpu.get_context_parallel_world_size() > 1 and mpu.get_context_parallel_world_size() < total_audio_seqlen:
attn_output = gather_heads_scatter_seq(
attn_output,
seq_dim=seq_dim,
head_dim=head_dim,
gather_size=self.num_heads
)
if self.config.attn_layout == "BNSD":
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(seq_length, -1).contiguous()
attn_output = self.out_proj(attn_output)
return attn_output
def apply_rotary_pos_emb_vision(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
orig_q_dtype = q.dtype
orig_k_dtype = k.dtype
q, k = q.float(), k.float()
cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
cos = cos.unsqueeze(0)
sin = sin.unsqueeze(0)
q = q.unsqueeze(0)
k = k.unsqueeze(0)
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
q_embed = q_embed.squeeze(0)
k_embed = k_embed.squeeze(0)
q_embed = q_embed.to(orig_q_dtype)
k_embed = k_embed.to(orig_k_dtype)
return q_embed, k_embed
class Qwen3OmniMoeVisionAttention(nn.Module):
def __init__(self, config: Qwen3OmniMoeVisionEncoderConfig) -> None:
super().__init__()
self.dim = config.hidden_size
self.num_heads = config.num_heads
self.head_dim = self.dim // self.num_heads
self.num_key_value_groups = 1
self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
self.proj = nn.Linear(self.dim, self.dim)
self.scaling = self.head_dim**-0.5
self.config = config
self.attention_dropout = 0.0
self.is_causal = False
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> torch.Tensor:
seq_length = hidden_states.shape[0]
total_visual_seqlen = int(cu_seqlens[-1])
query_states, key_states, value_states = (
self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
seq_dim, head_dim = None, None
attention_kwargs = {
"scale": self.scaling,
"dropout": 0.0 if not self.training else self.attention_dropout,
"is_causal": self.is_causal,
"attention_mask": None,
}
if self.config._attn_implementation == "flash_attention_2" and self.config.attn_layout == "TND":
seq_dim, head_dim = 0, 1
attention_kwargs["actual_seq_qlen"] = cu_seqlens
attention_kwargs["actual_seq_kvlen"] = cu_seqlens
attention_kwargs["layout"] = "TND"
elif self.config._attn_implementation in ["eager", "sdpa", "flash_attention_2"] and self.config.attn_layout == "BNSD":
query_states = query_states.transpose(0, 1).unsqueeze(0)
key_states = key_states.transpose(0, 1).unsqueeze(0)
value_states = value_states.transpose(0, 1).unsqueeze(0)
seq_dim, head_dim = 2, 1
attention_kwargs["layout"] = "BNSD"
else:
raise NotImplementedError(
f"Unsupported Attention: {self.config._attn_implementation}, or layout: {self.config.attn_layout}"
"Qwen3OmniMoeVisionAttention only support ['eager', 'sdpa', 'flash_attention_2'], layout TND and BNSD")
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
if mpu.get_context_parallel_world_size() > 1:
megatron_args = get_args()
if megatron_args.context_parallel_algo == "ulysses_cp_algo":
query_states, key_states, value_states = gather_seq_scatter_heads_qkv(
query_states,
key_states,
value_states,
seq_dim=seq_dim,
head_dim=head_dim,
gather_size=total_visual_seqlen
)
else:
raise NotImplementedError(f"Only support `ulysses_cp_algo`, but got {megatron_args.context_parallel_algo}")
if self.config.attn_layout == "TND":
attn_output = attention_interface(
query_states,
key_states,
value_states,
**attention_kwargs
)
else:
lengths = [cu_seqlens[0]] + [post_len - seqlen for seqlen, post_len in zip(cu_seqlens, cu_seqlens[1:])]
splits = [
torch.split(tensor, lengths, dim=seq_dim)
for tensor in (query_states, key_states, value_states)
]
attn_outputs = [
attention_interface(
q,
k,
v,
**attention_kwargs,
)
for q, k, v in zip(*splits)
]
attn_output = torch.cat(attn_outputs, dim=seq_dim)
if mpu.get_context_parallel_world_size() > 1:
attn_output = gather_heads_scatter_seq(
attn_output,
seq_dim=seq_dim,
head_dim=head_dim,
gather_size=self.num_heads
)
if self.config.attn_layout == "BNSD":
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(seq_length, -1).contiguous()
attn_output = self.proj(attn_output)
return attn_output
class Qwen3OmniMoeThinkerTextRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Qwen3OmniMoeThinkerTextRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
return q_embed, k_embed
def repeat_kv(hidden_states: torch.Tensor, n_rep: int, layout: str) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). Adapt to different attention layouts:
insert expansion dim after num_key_value_heads, merge to num_attention_heads, keep other dims unchanged.
"""
if n_rep == 1:
return hidden_states
if layout == "BNSD":
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
elif layout == "BSND":
batch, slen, num_key_value_heads, head_dim = hidden_states.shape
hidden_states = hidden_states[:, :, :, None, :].expand(batch, slen, num_key_value_heads, n_rep, head_dim)
return hidden_states.reshape(batch, slen, num_key_value_heads * n_rep, head_dim)
elif layout == "TND":
token, num_key_value_heads, head_dim = hidden_states.shape
hidden_states = hidden_states[:, :, None, :].expand(token, num_key_value_heads, n_rep, head_dim)
return hidden_states.reshape(token, num_key_value_heads * n_rep, head_dim)
else:
raise NotImplementedError(
f"Unsupported Attention layout: {layout}, "
"Qwen3OmniMoeThinkerTextAttention only support ['BNSD', 'BSND', 'TND'] now.")
class Qwen3OmniMoeThinkerTextAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config, layer_idx):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True
self.q_proj = nn.Linear(
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
)
self.k_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.v_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
self.q_norm = Qwen3OmniMoeThinkerTextRMSNorm(
self.head_dim, eps=config.rms_norm_eps
)
self.k_norm = Qwen3OmniMoeThinkerTextRMSNorm(
self.head_dim, eps=config.rms_norm_eps
)
self.sliding_window = None
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
batch_size, seqlen = hidden_states.shape[:-1]
hidden_shape = (batch_size, seqlen, -1, self.head_dim)
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape))
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape))
value_states = self.v_proj(hidden_states).view(hidden_shape)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=2)
if past_key_values is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
dropout = 0.0 if not self.training else self.attention_dropout
attention_kwargs = {
"scale": self.scaling,
"dropout": dropout,
"is_causal": self.is_causal,
"layout": self.config.attn_layout,
"enable_gqa": True
}
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
total_seq_len = get_seq_len("total")
if mpu.get_context_parallel_world_size() > 1:
megatron_args = get_args()
seq_dim, head_dim = 1, 2
if megatron_args.context_parallel_algo == "ulysses_cp_algo":
if mpu.get_context_parallel_world_size() > self.config.num_key_value_heads:
key_states = repeat_kv(key_states, self.num_key_value_groups, "BSND")
value_states = repeat_kv(value_states, self.num_key_value_groups, "BSND")
attention_kwargs["enable_gqa"] = False
query_states, key_states, value_states = gather_seq_scatter_heads_qkv(
query_states,
key_states,
value_states,
seq_dim=seq_dim,
head_dim=head_dim,
gather_size=total_seq_len
)
else:
raise NotImplementedError(f"Only support `ulysses_cp_algo`, but got {megatron_args.context_parallel_algo}")
if self.config.attn_layout == "BNSD":
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
attention_kwargs["attention_mask"] = attention_mask
elif self.config.attn_layout == "BSND":
attention_kwargs["attention_mask"] = attention_mask
elif self.config.attn_layout == "TND":
attention_kwargs["actual_seq_qlen"] = kwargs["cu_seqlens"]
attention_kwargs["actual_seq_kvlen"] = kwargs["cu_seqlens"]
indices = kwargs["indices"]
query_states = query_states.view(-1, *query_states.shape[2:])[indices]
key_states = key_states.view(-1, *key_states.shape[2:])[indices]
value_states = value_states.view(-1, *value_states.shape[2:])[indices]
else:
raise NotImplementedError(
f"Unsupported Attention layout: {self.config.attn_layout}, "
"Qwen3OmniMoeThinkerTextAttention only support ['BNSD', 'BSND', 'TND'] now.")
attn_output = attention_interface(
query_states,
key_states,
value_states,
**attention_kwargs,
)
if self.config.attn_layout == "BNSD":
attn_output = attn_output.transpose(1, 2)
if self.config.attn_layout == "TND":
attn_output = pad_out(attn_output, indices, batch_size, total_seq_len)
attn_output = attn_output.view(batch_size, total_seq_len, *attn_output.shape[1:])
if mpu.get_context_parallel_world_size() > 1:
attn_output = gather_heads_scatter_seq(
attn_output,
seq_dim=seq_dim,
head_dim=head_dim,
gather_size=self.config.num_attention_heads
)
attn_output = attn_output.reshape(batch_size, seqlen, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output
class Qwen3OmniLMHead(nn.Linear):
def forward(self, hidden_states: torch.Tensor, loss_ctx: callable = None):
if isinstance(self.weight, DTensor):
w = self.weight.to_local()
if self.bias is not None:
if not isinstance(self.bias, DTensor):
raise TypeError(
f"Expected bias to be a DTensor when weight is a DTensor, "
f"but got bias of type {type(self.bias)}."
)
b = self.bias.to_local()
else:
b = None
else:
w = self.weight
b = self.bias
if loss_ctx is None:
logits = F.linear(hidden_states, w, b)
return logits, None
else:
return None, loss_ctx(hidden_states, w, b)