from collections.abc import Callable
from typing import Optional
import torch
import torch_npu
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.tensor import DTensor
from einops import rearrange
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.processing_utils import Unpack
from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLTextConfig, Qwen3VLVisionConfig
from megatron.core import mpu
from megatron.training import get_args
from mindspeed.core.context_parallel.model_parallel_utils import (
get_context_parallel_group_for_hybrid_ulysses,
get_context_parallel_group_for_hybrid_ring,
get_context_parallel_for_hybrid_ring_world_size,
get_context_parallel_for_hybrid_ulysses_world_size,
get_context_parallel_for_hybrid_ring_global_ranks,
get_context_parallel_for_hybrid_ring_rank
)
from mindspeed.core.context_parallel.ring_context_parallel.ring_context_parallel import ringattn_context_parallel_tnd_general, ringattn_context_parallel
from mindspeed.utils import get_actual_seq_len
from mindspeed_mm.models.common.communications import cal_split_sizes, cal_split_sizes_multi, split_forward_gather_backward
from mindspeed_mm.utils.utils import get_packed_seq_params, get_packed_seq_len
from ..cp_utils import get_seq_len, gather_seq_scatter_heads_qkv, gather_heads_scatter_seq, gather_visual_seqs_with_cp
from ..attention_utils import ALL_ATTENTION_FUNCTIONS, pad_out
class Qwen3VLEmptyModule(nn.Module):
    """
    Identity placeholder module.

    The module owns no parameters and keeps no internal state; its forward pass hands the
    input straight back. It exists only as a structural registration point: when the
    `align_fsdp_param_groups` feature is enabled, FSDP2 hooks for normalization
    (e.g., LayerNorm, RMSNorm) and gate-related parameters are attached to it so those
    parameters are identified and included in FSDP2's parameter grouping and communication
    logic, without taking part in forward/backward computation.
    """

    def __init__(self):
        super().__init__()

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Return `hidden_state` unchanged (the very same tensor object)."""
        return hidden_state
class Qwen3VLVisionMLP(nn.Module):
    """Feed-forward block of the vision tower: `linear_fc1` -> activation -> `linear_fc2`."""

    def __init__(self, config):
        """
        Args:
            config: object exposing `hidden_size`, `intermediate_size` and `hidden_act`
                (a key of `ACT2FN`).
        """
        super().__init__()
        hidden, inner = config.hidden_size, config.intermediate_size
        self.hidden_size = hidden
        self.intermediate_size = inner
        self.linear_fc1 = nn.Linear(hidden, inner, bias=True)
        self.linear_fc2 = nn.Linear(inner, hidden, bias=True)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        """Project up, apply the activation, then project back down."""
        expanded = self.linear_fc1(hidden_state)
        activated = self.act_fn(expanded)
        return self.linear_fc2(activated)
class Qwen3VLVisionPatchEmbed(nn.Module):
    """Embeds flattened visual patches with a non-overlapping 3D convolution."""

    def __init__(self, config) -> None:
        """
        Args:
            config: object exposing `patch_size`, `temporal_patch_size`, `in_channels`
                and `hidden_size`.
        """
        super().__init__()
        self.patch_size = config.patch_size
        self.temporal_patch_size = config.temporal_patch_size
        self.in_channels = config.in_channels
        self.embed_dim = config.hidden_size

        # Kernel and stride are identical, so every patch is projected exactly once.
        patch_shape = [self.temporal_patch_size, self.patch_size, self.patch_size]
        self.proj = nn.Conv3d(
            self.in_channels,
            self.embed_dim,
            kernel_size=patch_shape,
            stride=patch_shape,
            bias=True,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Reshape the input into individual patches and project each to `embed_dim`.

        Returns:
            Tensor viewed as `(-1, embed_dim)`.
        """
        patches = hidden_states.view(
            -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
        )
        # Run the convolution in the dtype of its own weights.
        patches = patches.to(dtype=self.proj.weight.dtype)
        return self.proj(patches).view(-1, self.embed_dim)
class Qwen3VLVisionRotaryEmbedding(nn.Module):
    """Produces the rotary frequency table (position x inverse frequency) for the vision tower."""

    inv_freq: torch.Tensor

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        """
        Args:
            dim: rotary dimension; one inverse frequency is created per even index in `[0, dim)`.
            theta: base of the geometric frequency progression.
        """
        super().__init__()
        exponents = torch.arange(0, dim, 2, dtype=torch.float) / dim
        # Non-persistent: the buffer follows the module across devices but is not saved in state_dict.
        self.register_buffer("inv_freq", 1.0 / (theta ** exponents), persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return the outer product of positions `0..seqlen-1` with `inv_freq`."""
        positions = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        return torch.outer(positions, self.inv_freq)
class Qwen3VLVisionPatchMerger(nn.Module):
    """Merges `spatial_merge_size**2` visual tokens into one and projects to `out_hidden_size`."""

    def __init__(self, config: Qwen3VLVisionConfig, use_postshuffle_norm=False) -> None:
        """
        Args:
            config: vision config providing `hidden_size`, `spatial_merge_size` and `out_hidden_size`.
            use_postshuffle_norm: when True the LayerNorm runs on the merged width
                (after the view to `self.hidden_size`); otherwise on the per-token width.
        """
        super().__init__()
        merged_width = config.hidden_size * (config.spatial_merge_size**2)
        self.hidden_size = merged_width
        self.spatial_merge_size = config.spatial_merge_size
        self.use_postshuffle_norm = use_postshuffle_norm
        norm_width = merged_width if use_postshuffle_norm else config.hidden_size
        self.norm = nn.LayerNorm(norm_width, eps=1e-6)
        self.linear_fc1 = nn.Linear(merged_width, merged_width)
        self.act_fn = nn.GELU()
        self.linear_fc2 = nn.Linear(merged_width, config.out_hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize, merge tokens and apply the two-layer projection.

        With context parallelism enabled the visual sequence is first gathered across
        the CP group, merged, and then split again along dim 0 before the MLP.
        """
        cp_size = mpu.get_context_parallel_world_size()

        if cp_size <= 1:
            # Single-rank path: the norm is applied either after or before the merge view.
            if self.use_postshuffle_norm:
                x = x.view(-1, self.hidden_size)
            x = self.norm(x).view(-1, self.hidden_size)
            return self.linear_fc2(self.act_fn(self.linear_fc1(x)))

        # Pre-shuffle norm works on un-merged tokens, so it must run before the gather.
        if not self.use_postshuffle_norm:
            x = self.norm(x)

        x = gather_visual_seqs_with_cp(x, dim=0)
        x = x.view(-1, self.hidden_size)
        split_sizes = cal_split_sizes(x.shape[0], cp_size)
        x = split_forward_gather_backward(
            x, mpu.get_context_parallel_group(), dim=0, grad_scale="down", split_sizes=split_sizes
        )

        # Post-shuffle norm works on merged tokens, after the re-split.
        if self.use_postshuffle_norm:
            x = self.norm(x)

        return self.linear_fc2(self.act_fn(self.linear_fc1(x)))
def apply_rotary_pos_emb_vision(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embedding to vision query/key tensors via `torch_npu.npu_rotary_mul`.

    The rotation is computed in float32; each result is cast back to the dtype of its
    own input. `cos`/`sin` get an extra axis at position -2 and every operand gets a
    temporary leading axis for the kernel call, which is squeezed away again afterwards.

    Returns:
        `(q_embed, k_embed)` with the same dtypes as `q` and `k` respectively.
    """
    q_dtype, k_dtype = q.dtype, k.dtype

    cos_f = cos.unsqueeze(-2).float().unsqueeze(0)
    sin_f = sin.unsqueeze(-2).float().unsqueeze(0)
    q_f = q.float().unsqueeze(0)
    k_f = k.float().unsqueeze(0)

    q_rotated = torch_npu.npu_rotary_mul(q_f, cos_f, sin_f)
    k_rotated = torch_npu.npu_rotary_mul(k_f, cos_f, sin_f)

    return q_rotated.squeeze(0).to(q_dtype), k_rotated.squeeze(0).to(k_dtype)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int, layout: str) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head axis.

    Equivalent to `torch.repeat_interleave(hidden_states, repeats=n_rep, dim=head_axis)`,
    where the head axis depends on the attention layout:

    - "BNSD": (batch, num_key_value_heads, seq_len, head_dim) -> head axis 1
    - "BSND": (batch, seq_len, num_key_value_heads, head_dim) -> head axis 2
    - "TND":  (tokens, num_key_value_heads, head_dim)         -> head axis 1

    Args:
        hidden_states: key or value tensor in the given layout.
        n_rep: number of copies per key/value head. `1` returns the input untouched
            (without inspecting `layout`).
        layout: one of "BNSD", "BSND", "TND".

    Returns:
        Tensor whose head axis has size `num_key_value_heads * n_rep`; other axes unchanged.

    Raises:
        NotImplementedError: if `layout` is not one of the supported layouts.
        ValueError: if the rank of `hidden_states` does not match `layout`.
    """
    if n_rep == 1:
        return hidden_states

    # layout -> (index of the head axis, expected tensor rank)
    layout_spec = {"BNSD": (1, 4), "BSND": (2, 4), "TND": (1, 3)}.get(layout)
    if layout_spec is None:
        raise NotImplementedError(
            f"Unsupported Attention layout: {layout}, "
            "repeat_kv only support ['BNSD', 'BSND', 'TND'] now.")

    head_axis, expected_rank = layout_spec
    if hidden_states.dim() != expected_rank:
        raise ValueError(
            f"Layout {layout} expects a {expected_rank}-D tensor, "
            f"but got shape {tuple(hidden_states.shape)}.")

    shape = list(hidden_states.shape)
    # Insert a repeat axis right after the head axis, broadcast it, then fold it into the heads.
    expanded_shape = shape[:head_axis + 1] + [n_rep] + shape[head_axis + 1:]
    merged_shape = shape[:head_axis] + [shape[head_axis] * n_rep] + shape[head_axis + 1:]
    return hidden_states.unsqueeze(head_axis + 1).expand(expanded_shape).reshape(merged_shape)
def do_vit_ring_context_parallel(q, k, v, head_num, softmax_scale, attn_mask=None, dropout_p=0., pse=None, pse_type=None, shapes=None):
    """Run non-causal ring attention for the vision tower over the context-parallel group.

    The ring communication group is the hybrid-ring group when one has been initialized,
    otherwise the regular Megatron context-parallel group. All arguments are forwarded to
    `ringattn_context_parallel_tnd_general` together with the assembled `cp_para` dict.

    Returns:
        Whatever `ringattn_context_parallel_tnd_general` returns (the attention output).
    """
    args = get_args()

    if get_context_parallel_group_for_hybrid_ring(check_initialized=False) is not None:
        # Hybrid CP: the ring part runs on its own dedicated sub-group.
        ring_group = get_context_parallel_group_for_hybrid_ring()
        ring_size = get_context_parallel_for_hybrid_ring_world_size()
        ring_rank = get_context_parallel_for_hybrid_ring_rank()
        ring_global_ranks = get_context_parallel_for_hybrid_ring_global_ranks()
    else:
        ring_group = mpu.get_context_parallel_group()
        ring_size = mpu.get_context_parallel_world_size()
        ring_rank = mpu.get_context_parallel_rank()
        ring_global_ranks = mpu.get_context_parallel_global_ranks()

    if args.use_cp_send_recv_overlap:
        overlap_group = mpu.get_context_parallel_group_for_send_recv_overlap()
    else:
        overlap_group = None

    cp_para = {
        'causal': False,
        'cp_group': ring_group,
        'cp_size': ring_size,
        'rank': ring_rank,
        'cp_global_ranks': ring_global_ranks,
        'cp_group_for_send_recv_overlap': overlap_group,
        'pse': pse,
        'pse_type': pse_type,
    }
    return ringattn_context_parallel_tnd_general(
        q, k, v, head_num, cp_para, softmax_scale, attn_mask, dropout_p, shapes=shapes
    )
class Qwen3VLVisionAttention(nn.Module):
    """Non-causal self-attention of the vision tower with optional context parallelism.

    Supported combinations of `config._attn_implementation` and `config.attn_layout`:
    `flash_attention_2` with layout `TND`, or `eager` / `sdpa` / `flash_attention_2`
    with layout `BNSD`. With context parallelism enabled, the algorithm is selected by
    the Megatron argument `context_parallel_algo` (`ulysses_cp_algo`, `megatron_cp_algo`
    or `hybrid_cp_algo`); the ring-based algorithms require layout `TND`.
    """

    def __init__(self, config: Qwen3VLVisionConfig) -> None:
        super().__init__()
        self.dim = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = self.dim // self.num_heads
        self.num_key_value_groups = 1
        # Fused projection producing query, key and value in one matmul.
        self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
        self.proj = nn.Linear(self.dim, self.dim)
        self.scaling = self.head_dim**-0.5
        self.config = config
        self.attention_dropout = 0.0
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: visual tokens; dim 0 is the (local) sequence dimension.
            cu_seqlens: cumulative sequence lengths. Its last element is used as the total
                visual sequence length for the Ulysses gather, and the per-sequence lengths
                for the BNSD path are derived from it.
            rotary_pos_emb: accepted for interface compatibility; not read here.
            position_embeddings: `(cos, sin)` pair. Although the default is `None`,
                it is unpacked unconditionally and therefore required.
            **kwargs: ignored.

        Returns:
            Attention output after the output projection, with `seq_length` rows.

        Raises:
            NotImplementedError: unsupported attention implementation / layout, or
                unsupported context-parallel algorithm.
            ValueError: ring or hybrid context parallelism requested with a non-TND layout.
        """
        seq_length = hidden_states.shape[0]
        query_states, key_states, value_states = (
            self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        )
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)

        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        layout = self.config.attn_layout.upper()
        seq_dim, head_dim = None, None
        attention_kwargs = {"scale": self.scaling, "dropout": self.attention_dropout, "is_causal": self.is_causal, "attention_mask": None}

        if self.config._attn_implementation == "flash_attention_2" and layout == "TND":
            seq_dim, head_dim = 0, 1
            attention_kwargs["actual_seq_qlen"] = cu_seqlens
            attention_kwargs["actual_seq_kvlen"] = cu_seqlens
            attention_kwargs["layout"] = "TND"
        elif self.config._attn_implementation in ["eager", "sdpa", "flash_attention_2"] and layout == "BNSD":
            # (seq, heads, dim) -> (1, heads, seq, dim)
            query_states = query_states.transpose(0, 1).unsqueeze(0)
            key_states = key_states.transpose(0, 1).unsqueeze(0)
            value_states = value_states.transpose(0, 1).unsqueeze(0)
            seq_dim, head_dim = 2, 1
            attention_kwargs["layout"] = "BNSD"
        else:
            raise NotImplementedError(
                f"Unsupported Attention: {self.config._attn_implementation}, or layout: {layout}. "
                "Qwen3VLVisionAttention only support ['eager', 'sdpa', 'flash_attention_2'], layout TND and BNSD")

        if mpu.get_context_parallel_world_size() > 1:
            megatron_args = get_args()
            if megatron_args.context_parallel_algo == "ulysses_cp_algo":
                # Only materialize the Python int where it is needed: converting a tensor
                # element to int copies it to the host.
                total_visual_seqlen = int(cu_seqlens[-1])
                query_states, key_states, value_states = gather_seq_scatter_heads_qkv(
                    query_states,
                    key_states,
                    value_states,
                    seq_dim=seq_dim,
                    head_dim=head_dim,
                    gather_size=total_visual_seqlen
                )
                # Falls through to the regular attention call below, followed by the
                # inverse gather-heads / scatter-seq.
            elif megatron_args.context_parallel_algo == "megatron_cp_algo":
                if layout != "TND":
                    raise ValueError(f"Vision Attention only support layout `TND` when using Ring Attention.")
                all_split_sizes_tensor = cal_split_sizes_multi(get_seq_len("per_visual"), mpu.get_context_parallel_world_size())
                attn_output = do_vit_ring_context_parallel(
                    query_states,
                    key_states,
                    value_states,
                    self.num_heads,
                    self.scaling,
                    attn_mask=None,
                    dropout_p=0.,
                    pse=None,
                    pse_type=None,
                    shapes=all_split_sizes_tensor
                )
                attn_output = attn_output.reshape(seq_length, -1).contiguous()
                attn_output = self.proj(attn_output)
                return attn_output
            elif megatron_args.context_parallel_algo == "hybrid_cp_algo":
                if layout != "TND":
                    raise ValueError(f"Vision Attention only support layout `TND` when using Hybrid Attention.")
                total_visual_seqlen = int(cu_seqlens[-1])
                ulysses_process_group = get_context_parallel_group_for_hybrid_ulysses()
                # Ulysses step inside the hybrid scheme: gather sequence, scatter heads.
                query_states, key_states, value_states = gather_seq_scatter_heads_qkv(query_states, key_states, value_states, seq_dim=0, head_dim=1, gather_size=total_visual_seqlen, group=ulysses_process_group)
                all_split_sizes_tensor = cal_split_sizes_multi(get_seq_len("per_visual"), get_context_parallel_for_hybrid_ring_world_size())
                # Ring step: each rank only holds num_heads / ulysses_world_size heads now.
                attn_output = do_vit_ring_context_parallel(
                    query_states,
                    key_states,
                    value_states,
                    self.num_heads // get_context_parallel_for_hybrid_ulysses_world_size(),
                    self.scaling,
                    attn_mask=None,
                    dropout_p=0.,
                    pse=None,
                    pse_type=None,
                    shapes=all_split_sizes_tensor
                )
                attn_output = gather_heads_scatter_seq(attn_output, seq_dim=0, head_dim=1, gather_size=self.num_heads, group=get_context_parallel_group_for_hybrid_ulysses())
                attn_output = attn_output.reshape(seq_length, -1).contiguous()
                attn_output = self.proj(attn_output)
                return attn_output
            else:
                raise NotImplementedError(f"Only support `ulysses_cp_algo`,`megatron_cp_algo`,`hybrid_cp_algo`, but got {megatron_args.context_parallel_algo}")

        if layout == "TND":
            # Variable-length sequences are handled by the kernel through actual_seq_*len.
            attn_output = attention_interface(
                query_states,
                key_states,
                value_states,
                **attention_kwargs
            )
        else:
            # BNSD: attend within each sequence separately, then stitch the results together.
            # Lengths are the first cumulative entry followed by successive differences.
            lengths = [cu_seqlens[0]] + [post_len - seqlen for seqlen, post_len in zip(cu_seqlens, cu_seqlens[1:])]
            splits = [
                torch.split(tensor, lengths, dim=seq_dim)
                for tensor in (query_states, key_states, value_states)
            ]
            attn_outputs = [
                attention_interface(
                    q,
                    k,
                    v,
                    **attention_kwargs,
                )
                for q, k, v in zip(*splits)
            ]
            attn_output = torch.cat(attn_outputs, dim=seq_dim)

        if mpu.get_context_parallel_world_size() > 1:
            # Ulysses path only (the ring-based branches returned above): undo the head scatter.
            attn_output = gather_heads_scatter_seq(
                attn_output,
                seq_dim=seq_dim,
                head_dim=head_dim,
                gather_size=self.num_heads
            )
        if layout == "BNSD":
            attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(seq_length, -1).contiguous()
        attn_output = self.proj(attn_output)
        return attn_output
class Qwen3VLVisionBlock(nn.Module):
    """Pre-norm transformer block of the vision tower: attention and MLP, each with a residual."""

    def __init__(self, config, attn_implementation: str = "sdpa") -> None:
        """
        Args:
            config: vision config; `hidden_size` sizes both LayerNorms and the config is
                passed on to the attention and MLP sub-modules.
            attn_implementation: accepted for interface compatibility; not read here.
        """
        super().__init__()
        self.config = config
        width = config.hidden_size
        self.norm1 = nn.LayerNorm(width, eps=1e-6)
        self.norm2 = nn.LayerNorm(width, eps=1e-6)
        self.attn = Qwen3VLVisionAttention(config=config)
        self.mlp = Qwen3VLVisionMLP(config=config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Apply `x + attn(norm1(x))` followed by `x + mlp(norm2(x))`."""
        # Optional per-layer device barrier, controlled by the config flag.
        if self.config.synchronize_per_layer:
            torch.npu.synchronize()

        attn_out = self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        mlp_out = self.mlp(self.norm2(hidden_states))
        return hidden_states + mlp_out
class Qwen3VLTextRotaryEmbedding(nn.Module):
    """Multimodal rotary embedding (interleaved MRoPE) for the Qwen3-VL text model."""

    inv_freq: torch.Tensor

    def __init__(self, config: Qwen3VLTextConfig, device=None):
        """
        Args:
            config: text config. The RoPE type is read from `rope_parameters` when that
                attribute exists, otherwise from a non-None `rope_scaling`, otherwise
                it is "default".
            device: device on which the inverse frequencies are created.
        """
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        if hasattr(self.config, "rope_parameters"):
            self.rope_type = self.config.rope_parameters["rope_type"]
        elif hasattr(self.config, "rope_scaling") and self.config.rope_scaling is not None:
            self.rope_type = self.config.rope_scaling["rope_type"]
        else:
            self.rope_type = "default"
        rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

        # `rope_scaling` may be missing or None (the rope-type selection above already
        # tolerates that), so look the section sizes up defensively: prefer `rope_scaling`,
        # fall back to `rope_parameters`, and finally to the built-in default split.
        rope_cfg = getattr(config, "rope_scaling", None)
        if rope_cfg is None:
            rope_cfg = getattr(config, "rope_parameters", None) or {}
        self.mrope_section = rope_cfg.get("mrope_section", [24, 20, 20])

    @staticmethod
    def compute_default_rope_parameters(
        config: Optional[Qwen3VLTextConfig] = None,
        device: Optional["torch.device"] = None,
        seq_len: Optional[int] = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration. The RoPE base is taken from
                `config.rope_parameters["rope_theta"]` when `rope_parameters` exists,
                otherwise from `config.rope_theta`.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        if hasattr(config, "rope_parameters"):
            base = config.rope_parameters["rope_theta"]
        else:
            base = config.rope_theta
        # A missing or falsy `head_dim` falls back to hidden_size // num_attention_heads.
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        attention_factor = 1.0  # Unused in this type of RoPE

        # Compute the inverse frequencies
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_factor

    def apply_interleaved_mrope(self, freqs, mrope_section):
        """Apply interleaved MRoPE to 3D rotary embeddings.
        Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
        interleaved [THTHWHTHW...TT], preserving frequency continuity.
        args:
            freqs: (3, bs, seq_len, head_dim // 2)
            mrope_section: (3,)
        returns:
            freqs_t: (bs, seq_len, head_dim // 2)

        Note: `freqs[0]` is overwritten in place and returned.
        """
        freqs_t = freqs[0]  # start from the first (temporal) plane
        for dim, offset in enumerate((1, 2), start=1):  # planes 1 and 2 (H, W)
            length = mrope_section[dim] * 3
            # Every third slot, starting at `offset`, takes its value from plane `dim`.
            idx = slice(offset, length, 3)
            freqs_t[..., idx] = freqs[dim, ..., idx]
        return freqs_t

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return `(cos, sin)` for `position_ids`, cast to the dtype of `x`.

        A 2-D `position_ids` is expanded to three identical planes.
        """
        if position_ids.ndim == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
        inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
        position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
            freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class Qwen3VLTextRMSNorm(nn.Module):
    """RMS normalization backed by the fused `torch_npu.npu_rms_norm` kernel."""

    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        """
        Qwen3VLTextRMSNorm is equivalent to T5LayerNorm

        Args:
            hidden_size: size of the learnable scale vector (initialized to ones).
            eps: epsilon passed to the kernel.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalize `hidden_states`; only the first element of the kernel's result is returned."""
        kernel_outputs = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)
        return kernel_outputs[0]

    def extra_repr(self):
        """Show the weight shape and epsilon in the module's repr."""
        weight_shape = tuple(self.weight.shape)
        return f"{weight_shape}, eps={self.variance_epsilon}"
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate query and key tensors with Rotary Position Embedding via `torch_npu.npu_rotary_mul`.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis along which `cos` and `sin` are unsqueezed so that they broadcast against
            `q` and `k`. For example, with `cos`/`sin` of shape [batch_size, seq_len, head_dim]:
            use `unsqueeze_dim=1` when q and k are [batch_size, heads, seq_len, head_dim], and
            `unsqueeze_dim=2` when they are [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    rotated_q = torch_npu.npu_rotary_mul(q, cos_b, sin_b)
    rotated_k = torch_npu.npu_rotary_mul(k, cos_b, sin_b)
    return rotated_q, rotated_k
def do_llm_ring_context_parallel(q, k, v, head_num, softmax_scale, attn_mask=None, dropout_p=0., pse=None, pse_type=None, shapes=None, packed_seq_params=None, layout="SBH"):
    """Run causal ring attention for the language model over the context-parallel group.

    The ring communication group is the hybrid-ring group when one has been initialized,
    otherwise the regular Megatron context-parallel group.

    Layout handling:
        - "TND": packed-sequence parameters and `shapes` are recomputed from
          `get_actual_seq_len()`; the values passed in for `packed_seq_params` and
          `shapes` are not used.
        - "BNSD": inputs are rearranged with the pattern "b s n d -> s b (n d)" for the
          kernel, and the output is rearranged back to "b s n d".

    Raises:
        NotImplementedError: for any other layout, including the default "SBH".
    """
    args = get_args()

    if get_context_parallel_group_for_hybrid_ring(check_initialized=False) is not None:
        ring_group = get_context_parallel_group_for_hybrid_ring()
        ring_size = get_context_parallel_for_hybrid_ring_world_size()
        ring_rank = get_context_parallel_for_hybrid_ring_rank()
        ring_global_ranks = get_context_parallel_for_hybrid_ring_global_ranks()
    else:
        ring_group = mpu.get_context_parallel_group()
        ring_size = mpu.get_context_parallel_world_size()
        ring_rank = mpu.get_context_parallel_rank()
        ring_global_ranks = mpu.get_context_parallel_global_ranks()

    if args.use_cp_send_recv_overlap:
        overlap_group = mpu.get_context_parallel_group_for_send_recv_overlap()
    else:
        overlap_group = None

    cp_para = {
        'causal': True,
        'cp_group': ring_group,
        'cp_size': ring_size,
        'rank': ring_rank,
        'cp_global_ranks': ring_global_ranks,
        'cp_group_for_send_recv_overlap': overlap_group,
        'pse': pse,
        'pse_type': pse_type,
        'megatron_cp_in_bnsd': args.megatron_cp_in_bnsd,
    }

    if layout == "TND":
        actual_seq_len = get_actual_seq_len()
        packed_seq_params, shapes = get_packed_seq_params(actual_seq_len, cp_size=ring_size)
        return ringattn_context_parallel(
            q, k, v, head_num, cp_para, softmax_scale, attn_mask, dropout_p,
            packed_seq_params=packed_seq_params, shapes=shapes
        )

    if layout == "BNSD":
        per_head_dim = q.shape[-1]
        q, k, v = (
            rearrange(tensor, "b s n d -> s b (n d)").contiguous()
            for tensor in (q, k, v)
        )
        ring_out = ringattn_context_parallel(
            q, k, v, head_num, cp_para, softmax_scale, attn_mask, dropout_p,
            packed_seq_params=None, shapes=shapes
        )
        return rearrange(ring_out, "s b (n d) -> b s n d", d=per_head_dim).contiguous()

    raise NotImplementedError("LLM Ring CP only support TND and BNSD now")
class Qwen3VLTextAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Causal grouped-query attention for the text model with per-head RMSNorm on queries
    and keys, rotary position embedding, and optional context parallelism. With context
    parallelism enabled, the algorithm is selected by the Megatron argument
    `context_parallel_algo` (`ulysses_cp_algo`, `megatron_cp_algo` or `hybrid_cp_algo`).
    """

    def __init__(self, config: Qwen3VLTextConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.num_heads = config.num_attention_heads
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        # Both norms act on the head dimension only.
        self.q_norm = Qwen3VLTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Qwen3VLTextRMSNorm(
            self.head_dim, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: input of shape `(batch_size, seqlen, hidden)`.
            position_embeddings: `(cos, sin)` pair for the rotary embedding.
            attention_mask: mask forwarded to the attention kernel for the BNSD and BSND
                layouts; not used for TND or on the ring-based context-parallel paths.
            past_key_values: optional KV cache, updated with this layer's keys/values.
            cache_position: forwarded to the cache update.
            **kwargs: for layout TND without ring attention, must contain `cu_seqlens`
                and `indices`.

        Returns:
            The attention output after `o_proj`.

        Raises:
            ValueError: ring or hybrid context parallelism with a layout other than BNSD / TND.
            NotImplementedError: unsupported layout or context-parallel algorithm.
        """
        batch_size, seqlen = hidden_states.shape[:-1]
        hidden_shape = (batch_size, seqlen, -1, self.head_dim)

        # Projections are viewed as (batch, seq, heads, head_dim) before norm and RoPE.
        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape))
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape))
        value_states = self.v_proj(hidden_states).view(hidden_shape)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=2)

        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        layout = self.config.attn_layout.upper()
        dropout = 0.0 if not self.training else self.attention_dropout
        attention_kwargs = {
            "scale": self.scaling,
            "dropout": dropout,
            "is_causal": self.is_causal,
            "layout": layout,
            "enable_gqa": True
        }

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        total_seq_len = get_seq_len("total")
        if mpu.get_context_parallel_world_size() > 1:
            megatron_args = get_args()
            seq_dim, head_dim = 1, 2
            if megatron_args.context_parallel_algo == "ulysses_cp_algo":
                # When there are fewer KV heads than CP ranks, expand KV to the full head
                # count first and turn GQA off in the kernel.
                if mpu.get_context_parallel_world_size() > self.config.num_key_value_heads:
                    key_states = repeat_kv(key_states, self.num_key_value_groups, "BSND")
                    value_states = repeat_kv(value_states, self.num_key_value_groups, "BSND")
                    attention_kwargs["enable_gqa"] = False
                query_states, key_states, value_states = gather_seq_scatter_heads_qkv(
                    query_states,
                    key_states,
                    value_states,
                    seq_dim=seq_dim,
                    head_dim=head_dim,
                    gather_size=total_seq_len
                )
                # Falls through to the regular attention call below.
            elif megatron_args.context_parallel_algo == "megatron_cp_algo":
                if layout not in ["BNSD", "TND"]:
                    raise ValueError(f"TextAttention only support layout `BNSD` and `TND` when using Ring Attention.")
                if layout == "TND":
                    # Flatten batch and sequence into a single token axis.
                    query_states = query_states.view(-1, *query_states.shape[2:])
                    key_states = key_states.view(-1, *key_states.shape[2:])
                    value_states = value_states.view(-1, *value_states.shape[2:])
                attn_output = do_llm_ring_context_parallel(
                    query_states,
                    key_states,
                    value_states,
                    self.config.num_attention_heads,
                    softmax_scale=self.scaling,
                    attn_mask=None,
                    dropout_p=0.,
                    pse=None,
                    pse_type=None,
                    shapes=None,
                    layout=layout
                )
                attn_output = attn_output.reshape(batch_size, seqlen, -1).contiguous()
                attn_output = self.o_proj(attn_output)
                return attn_output
            elif megatron_args.context_parallel_algo == "hybrid_cp_algo":
                if layout not in ["BNSD", "TND"]:
                    raise ValueError(f"TextAttention only support layout `BNSD` and `TND` when using Ring Attention.")
                if layout == "TND" or get_context_parallel_for_hybrid_ulysses_world_size() > self.config.num_key_value_heads:
                    key_states = repeat_kv(key_states, self.num_key_value_groups, layout="BSND")
                    value_states = repeat_kv(value_states, self.num_key_value_groups, layout="BSND")

                actual_seq_len = get_actual_seq_len()
                if actual_seq_len is not None:
                    total_seq_len = get_packed_seq_len(actual_seq_len, get_context_parallel_for_hybrid_ring_world_size())
                else:
                    total_seq_len = get_seq_len("total")
                # The Ulysses gather only spans this rank's ring partition of the sequence.
                seq_len_per_ring = total_seq_len // get_context_parallel_for_hybrid_ring_world_size()

                query_states, key_states, value_states = gather_seq_scatter_heads_qkv(
                    query_states,
                    key_states,
                    value_states,
                    seq_dim=seq_dim,
                    head_dim=head_dim,
                    gather_size=seq_len_per_ring,
                    group=get_context_parallel_group_for_hybrid_ulysses()
                )
                if layout == "TND":
                    query_states = query_states.view(-1, *query_states.shape[2:])
                    key_states = key_states.view(-1, *key_states.shape[2:])
                    value_states = value_states.view(-1, *value_states.shape[2:])
                attn_output = do_llm_ring_context_parallel(
                    query_states,
                    key_states,
                    value_states,
                    self.config.num_attention_heads // get_context_parallel_for_hybrid_ulysses_world_size(),
                    softmax_scale=self.scaling,
                    attn_mask=None,
                    dropout_p=0.,
                    pse=None,
                    pse_type=None,
                    shapes=None,
                    layout=layout
                )
                if layout == "TND":
                    attn_output = rearrange(attn_output, '(b s) n d -> b s n d', b=batch_size).contiguous()
                # For BNSD, do_llm_ring_context_parallel has already rearranged its result
                # to (b, s, n, d). Applying the 3-axis pattern "s b (n d) -> b s n d" again
                # to that 4-D tensor cannot match, so no further reshaping is done here.

                attn_output = gather_heads_scatter_seq(
                    attn_output,
                    seq_dim=seq_dim,
                    head_dim=head_dim,
                    gather_size=self.config.num_attention_heads,
                    group=get_context_parallel_group_for_hybrid_ulysses()
                )
                attn_output = attn_output.reshape(batch_size, seqlen, -1).contiguous()
                attn_output = self.o_proj(attn_output)
                return attn_output
            else:
                raise NotImplementedError(f"Only support `ulysses_cp_algo`,`megatron_cp_algo`,`hybrid_cp_algo`, but got {megatron_args.context_parallel_algo}")

        if layout == "BNSD":
            query_states = query_states.transpose(1, 2)
            key_states = key_states.transpose(1, 2)
            value_states = value_states.transpose(1, 2)
            attention_kwargs["attention_mask"] = attention_mask
        elif layout == "BSND":
            attention_kwargs["attention_mask"] = attention_mask
        elif layout == "TND":
            attention_kwargs["actual_seq_qlen"] = kwargs["cu_seqlens"]
            attention_kwargs["actual_seq_kvlen"] = kwargs["cu_seqlens"]
            # Select the tokens addressed by `indices` from the flattened (batch * seq) axis.
            indices = kwargs["indices"]
            query_states = query_states.view(-1, *query_states.shape[2:])[indices]
            key_states = key_states.view(-1, *key_states.shape[2:])[indices]
            value_states = value_states.view(-1, *value_states.shape[2:])[indices]
        else:
            raise NotImplementedError(
                f"Unsupported Attention layout: {layout}, "
                "Qwen3VLTextAttention only support ['BNSD', 'BSND', 'TND'] now.")

        attn_output = attention_interface(
            query_states,
            key_states,
            value_states,
            **attention_kwargs,
        )

        if layout == "BNSD":
            attn_output = attn_output.transpose(1, 2)
        if layout == "TND":
            # Restore the padded layout and reshape to (batch, total_seq_len, ...).
            attn_output = pad_out(attn_output, indices, batch_size, total_seq_len)
            attn_output = attn_output.view(batch_size, total_seq_len, *attn_output.shape[1:])
        if mpu.get_context_parallel_world_size() > 1:
            # Ulysses path only (the ring-based branches returned above): undo the head scatter.
            attn_output = gather_heads_scatter_seq(
                attn_output,
                seq_dim=seq_dim,
                head_dim=head_dim,
                gather_size=self.num_heads
            )

        attn_output = attn_output.reshape(batch_size, seqlen, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output
class Qwen3VLTextMLP(nn.Module):
    """Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`, all bias-free."""

    def __init__(self, config):
        """
        Args:
            config: object exposing `hidden_size`, `intermediate_size` and `hidden_act`
                (a key of `ACT2FN`).
        """
        super().__init__()
        self.config = config
        hidden, inner = config.hidden_size, config.intermediate_size
        self.hidden_size = hidden
        self.intermediate_size = inner
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Gate the up-projection with the activated gate projection, then project down."""
        gate = self.act_fn(self.gate_proj(x))
        gated = gate * self.up_proj(x)
        return self.down_proj(gated)
class Qwen3VLLMHead(nn.Linear):
def forward(self, hidden_states: torch.Tensor, loss_ctx: callable = None):
if isinstance(self.weight, DTensor):
w = self.weight.to_local()
if self.bias is not None:
if not isinstance(self.bias, DTensor):
raise TypeError(
f"Expected bias to be a DTensor when weight is a DTensor, "
f"but got bias of type {type(self.bias)}."
)
b = self.bias.to_local()
else:
b = None
else:
w = self.weight
b = self.bias
if loss_ctx is None:
logits = F.linear(hidden_states, w, b)
return logits, None
else:
return None, loss_ctx(hidden_states, w, b)