from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Optional, Union, List
import torch
import torch.nn.functional as F
from torch import nn
from transformers import initialization as init
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache
from transformers.generation import GenerationMixin
from transformers.integrations import use_kernelized_func
from transformers.masking_utils import create_causal_mask
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
BaseModelOutputWithPooling,
CausalLMOutputWithPast,
ModelOutput,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check
from transformers.utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults
from transformers.utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available
from transformers.utils.output_capturing import capture_outputs
from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5Config, Qwen3_5TextConfig, Qwen3_5VisionConfig
from mindspeed.fsdp.utils.log import print_rank
from mindspeed_mm.fsdp.utils.register import model_register
from mindspeed_mm.fsdp.utils.device import IS_NPU_AVAILABLE
if IS_NPU_AVAILABLE:
import torch_npu
from mindspeed_mm.fsdp.distributed.context_parallel.communication import (
split_forward_gather_backward,
gather_forward_split_backward,
split_forward_gather_backward_with_cp,
gather_forward_split_backward_with_cp,
packed_data_split_forward_gather_backward_with_cp,
packed_data_gather_forward_split_backward_with_cp
)
from mindspeed_mm.fsdp.distributed.parallel_state import get_parallel_state
from mindspeed_mm.fsdp.distributed.context_parallel.utils import cal_split_sizes, cal_split_sizes_multi
from mindspeed_mm.fsdp.distributed.context_parallel.utils import generate_ulysses_cu_seqlen_params
from mindspeed_mm.fsdp.distributed.context_parallel.communication import all_to_all
from mindspeed_mm.fsdp.models.mtp import MultiTokenPredictionBlock
# Module-level sequence-length registry: written by set_seq_len() and read by get_seq_len().
# NOTE(review): presumably populated per batch by the caller before the model runs -- confirm
# against the training/inference loop.
_TOTAL_SEQ_LEN = None  # key "total"; used as gather_size for the Ulysses all_to_all in Qwen3_5GatedDeltaNet.forward
_VISUAL_SEQ_LEN = None  # key "visual"
_VISUAL_PER_SEQ_LEN = None  # key "per_visual"; set_seq_len's annotation allows an int or a list of ints
IGNORE_INDEX = -100  # NOTE(review): presumably the label id skipped by the loss; not referenced in this part of the file
def get_seq_len(seq_type: Optional[str] = None) -> Optional[Union[int, List[int]]]:
    """Return the sequence length registered under ``seq_type`` via ``set_seq_len``.

    Args:
        seq_type: one of ``"total"``, ``"visual"`` or ``"per_visual"``.

    Returns:
        The stored value: an int, a list of ints, or ``None`` if it was never set.

    Raises:
        ValueError: if ``seq_type`` is not one of the accepted keys.
    """
    # Reading module globals needs no ``global`` declaration.
    if seq_type == "total":
        return _TOTAL_SEQ_LEN
    if seq_type == "visual":
        return _VISUAL_SEQ_LEN
    if seq_type == "per_visual":
        return _VISUAL_PER_SEQ_LEN
    raise ValueError(
        f"Invalid sequence type: '{seq_type}'. Expected 'total', 'visual' or 'per_visual'."
    )
def set_seq_len(seq_type: Optional[str] = None, seq_len: Optional[Union[int, List[int]]] = None) -> None:
    """Register ``seq_len`` under ``seq_type`` for later retrieval with ``get_seq_len``.

    Args:
        seq_type: one of ``"total"``, ``"visual"`` or ``"per_visual"``.
        seq_len: the value to store (an int, a list of ints, or ``None`` to clear it).

    Raises:
        ValueError: if ``seq_type`` is not one of the accepted keys.
    """
    global _TOTAL_SEQ_LEN, _VISUAL_SEQ_LEN, _VISUAL_PER_SEQ_LEN
    if seq_type == "total":
        _TOTAL_SEQ_LEN = seq_len
    elif seq_type == "visual":
        _VISUAL_SEQ_LEN = seq_len
    elif seq_type == "per_visual":
        _VISUAL_PER_SEQ_LEN = seq_len
    else:
        raise ValueError(
            f"Invalid sequence type: '{seq_type}'. Expected 'total', 'visual' or 'per_visual'."
        )
# Optional fast-path kernels. When a library is missing, its names are bound to None so that
# Qwen3_5GatedDeltaNet can fall back to the pure-PyTorch implementations defined in this module
# (torch_causal_conv1d_update, torch_recurrent_gated_delta_rule, Qwen3_5RMSNormGated).
if is_causal_conv1d_available():
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
    causal_conv1d_update, causal_conv1d_fn = None, None
# flash-linear-attention (fla) supplies the fused gated RMSNorm and the gated-delta-rule kernels.
if is_flash_linear_attention_available():
    from fla.modules import FusedRMSNormGated
    from fla.ops.gated_delta_rule import chunk_gated_delta_rule as fla_chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
else:
    fla_chunk_gated_delta_rule, fused_recurrent_gated_delta_rule = None, None
    FusedRMSNormGated = None
logger = logging.get_logger(__name__)
class Qwen3_5DynamicCache:
"""
A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the linear attention
cache (which has a constant shape regardless of seq_len).
This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
and `ssm_states` for gated deltanet cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor
For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
For linear attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,
and `recurrent_states` represents the recurrent state and has a shape of `(batch_size, d_inner, d_state)`.
"""
is_compileable = False
def __init__(self, config: Qwen3_5Config):
super().__init__()
self.layer_types = config.layer_types
self.transformer_layers = [
i for i in range(config.num_hidden_layers) if self.layer_types[i] == "full_attention"
]
self.last_linear_layer = len(self.layer_types) - 1 - self.layer_types[::-1].index("linear_attention")
self.conv_states = [None for _ in range(config.num_hidden_layers)]
self.recurrent_states = [None for _ in range(config.num_hidden_layers)]
self.key_cache = [None for _ in range(config.num_hidden_layers)]
self.value_cache = [None for _ in range(config.num_hidden_layers)]
def __len__(self):
return len(self.layer_types)
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if self.key_cache[layer_idx] is None:
self.key_cache[layer_idx] = key_states
self.value_cache[layer_idx] = value_states
else:
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)
return self.key_cache[layer_idx], self.value_cache[layer_idx]
def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
if self.key_cache[layer_idx] is not None:
device = self.key_cache[layer_idx].device
beam_idx = beam_idx.to(device)
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx)
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx)
if self.conv_states[layer_idx] is not None:
device = self.conv_states[layer_idx].device
beam_idx = beam_idx.to(device)
self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx)
self.recurrent_states[layer_idx] = self.recurrent_states[layer_idx].index_select(0, beam_idx)
def get_seq_length(self, layer_idx: int | None = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None:
return 0
return self.key_cache[layer_idx].shape[-2]
def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
"""
Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
the given layer at `layer_idx`.
The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer.
"""
kv_offset = 0
query_length = cache_position.shape[0]
past_seen_tokens = self.get_seq_length(layer_idx)
kv_length = query_length + past_seen_tokens
return kv_length, kv_offset
@property
def has_previous_state(self):
"""We have a previous state if the last linear (conv) layer was already updated."""
return self.conv_states[self.last_linear_layer] is not None
class Qwen3_5VisionRotaryEmbedding(nn.Module):
    """Rotary frequency table for the vision tower.

    Calling the module with ``seqlen`` yields a ``(seqlen, ceil(dim / 2))`` float tensor whose entry
    ``[p, i]`` is ``p / theta ** (2 * i / dim)``.
    """

    inv_freq: torch.Tensor

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        exponents = torch.arange(0, dim, 2, dtype=torch.float) / dim
        # Non-persistent: recomputed at construction, never stored in checkpoints.
        self.register_buffer("inv_freq", 1.0 / (theta ** exponents), persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        positions = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        # Outer product of positions and inverse frequencies via broadcasting.
        return positions[:, None] * self.inv_freq[None, :]
class Qwen3_5TextRotaryEmbedding(nn.Module):
    """Multimodal (3-axis) rotary position embedding for the text decoder.

    ``forward`` returns ``(cos, sin)`` tables built from ``position_ids`` after interleaving the
    per-axis frequency sections given by ``mrope_section`` (see ``apply_interleaved_mrope``).
    """

    inv_freq: torch.Tensor

    def __init__(self, config: Qwen3_5TextConfig, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_type = self.config.rope_parameters["rope_type"]
        # "default" uses the local implementation; any other rope type is looked up in transformers' registry.
        rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # NOTE(review): presumably kept so dynamic_rope_update can restore the unscaled frequencies; confirm in transformers.
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
        # Number of frequency slots per axis; defaults to [11, 11, 10] when the config omits it.
        self.mrope_section = config.rope_parameters.get("mrope_section", [11, 11, 10])

    @staticmethod
    def compute_default_rope_parameters(
        config: Qwen3_5TextConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters["rope_theta"]
        partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
        head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        # Only the first `dim` channels of each head are rotated when partial_rotary_factor < 1.
        dim = int(head_dim * partial_rotary_factor)
        attention_factor = 1.0
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        """Compute the rotary ``(cos, sin)`` tables.

        Args:
            x: tensor whose dtype sets the output dtype and whose device type selects the autocast context.
            position_ids: ``(3, batch, seq_len)`` per-axis positions, or ``(batch, seq_len)``; a 2D input
                is broadcast to all three axes.

        Returns:
            ``(cos, sin)``, each ``(batch, seq_len, 2 * len(inv_freq))``, scaled by ``attention_scaling`` and
            cast to ``x.dtype``.
        """
        if position_ids.ndim == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
        inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
        position_ids_expanded = position_ids[:, :, None, :].float()
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        # Autocast is disabled so the frequency products are computed in float32.
        with maybe_autocast(device_type=device_type, enabled=False):
            # (3, batch, dim/2, 1) @ (3, batch, 1, seq) -> (3, batch, dim/2, seq) -> (3, batch, seq, dim/2)
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
            freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

    def apply_interleaved_mrope(self, freqs, mrope_section):
        """Apply interleaved MRoPE to 3D rotary embeddings.
        Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
        interleaved [THWTHWTHW...TT], preserving frequency continuity.
        args:
            freqs: (3, bs, seq_len, head_dim // 2)
            mrope_section: (3,)
        returns:
            freqs_t: (bs, seq_len, head_dim // 2)
        Note: freqs_t is a view of freqs[0], so freqs is modified in place.
        """
        freqs_t = freqs[0]
        # Axis 1 fills slots 1, 4, 7, ... and axis 2 fills slots 2, 5, 8, ... up to 3 * mrope_section[axis].
        for dim, offset in enumerate((1, 2), start=1):
            length = mrope_section[dim] * 3
            idx = slice(offset, length, 3)
            freqs_t[..., idx] = freqs[dim, ..., idx]
        return freqs_t
class Qwen3_5RMSNormGated(nn.Module):
    """RMSNorm with an optional SiLU gate: ``weight * rms_norm(x) * silu(gate)``.

    On NPU the normalization uses the fused ``torch_npu.npu_rms_norm`` kernel. Otherwise it is
    computed in float32 and cast back to the input dtype. When ``gate`` is ``None`` the gating
    multiplication is skipped and only the weighted RMSNorm is returned.
    """

    def __init__(self, hidden_size, eps=1e-6, **kwargs):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states, gate=None):
        """Normalize ``hidden_states`` over the last dim and, if given, multiply by ``silu(gate)``."""
        if IS_NPU_AVAILABLE:
            hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
            if gate is None:
                return hidden_states
            return hidden_states * F.silu(gate)
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = self.weight * hidden_states.to(input_dtype)
        if gate is not None:
            # The gate is applied in float32 before the final cast back to the input dtype.
            hidden_states = hidden_states * F.silu(gate.to(torch.float32))
        return hidden_states.to(input_dtype)
def apply_mask_to_padding_states(hidden_states, attention_mask):
    """Zero the hidden states of padding tokens.

    Masking is only applied when a mask is given and both its batch and sequence dimensions
    exceed 1. With Ulysses sequence parallelism enabled, the mask is first cut along dim 1 to the
    chunk owned by this rank (sizes from ``cal_split_sizes``). The result keeps the dtype of
    ``hidden_states``.
    """
    parallel_state = get_parallel_state()
    needs_masking = (
        attention_mask is not None
        and attention_mask.shape[1] > 1
        and attention_mask.shape[0] > 1
    )
    if not needs_masking:
        return hidden_states
    original_dtype = hidden_states.dtype
    if parallel_state.is_ulysses_enable():
        chunk_sizes = cal_split_sizes(attention_mask.shape[1], world_size=parallel_state.get_ulysses_group_size())
        attention_mask = torch.split(attention_mask, chunk_sizes, dim=1)[parallel_state.get_ulysses_rank()]
    masked = hidden_states * attention_mask[:, :, None]
    return masked.to(original_dtype)
# True only when every optional fast-path kernel (causal_conv1d and fla) was imported successfully;
# Qwen3_5GatedDeltaNet logs a one-time warning when it is False.
is_fast_path_available = all(
    (causal_conv1d_fn, causal_conv1d_update, fla_chunk_gated_delta_rule, fused_recurrent_gated_delta_rule)
)
def torch_causal_conv1d_update(
    hidden_states,
    conv_state,
    weight,
    bias=None,
    activation=None,
):
    """Pure-PyTorch fallback for ``causal_conv1d_update`` (incremental depthwise causal convolution).

    Args:
        hidden_states: new inputs, ``(batch, channels, new_len)``.
        conv_state: rolling input window ``(batch, channels, state_len)``; overwritten in place with the
            last ``state_len`` timesteps of ``[conv_state, hidden_states]``.
        weight: depthwise kernel ``(channels, kernel_size)``.
        bias: optional per-channel bias.
        activation: accepted for signature compatibility only; SiLU is always applied.

    Returns:
        ``(batch, channels, new_len)`` activated convolution output in ``hidden_states``' dtype.
    """
    _, channels, new_len = hidden_states.shape
    window = conv_state.shape[-1]
    extended = torch.cat((conv_state, hidden_states), dim=-1).to(weight.dtype)
    conv_state.copy_(extended[:, :, -window:])
    conv_out = F.conv1d(extended, weight.unsqueeze(1), bias, padding=0, groups=channels)
    activated = F.silu(conv_out[:, :, -new_len:])
    return activated.to(hidden_states.dtype)
def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
    """Scale ``x`` to unit L2 norm along ``dim``, i.e. ``x / sqrt(sum(x**2) + eps)``.

    Written to match the l2norm used by the FLA library (eps is added inside the square root).
    """
    squared_sum = x.mul(x).sum(dim=dim, keepdim=True)
    return x * torch.rsqrt(squared_sum + eps)
def torch_chunk_gated_delta_rule(
    query,
    key,
    value,
    g,
    beta,
    chunk_size=64,
    initial_state=None,
    output_final_state=False,
    use_qk_l2norm_in_kernel=False,
):
    """Pure-PyTorch chunked gated delta rule (fallback / reference for the fused kernels).

    The sequence is split into blocks of ``chunk_size``. Interactions inside a chunk are solved
    with matrix operations, and the recurrent state is carried from chunk to chunk. All math runs
    in float32.

    Args:
        query, key: ``(batch, seq_len, num_heads, k_head_dim)`` (layout implied by the transpose(1, 2)
            and the 4-way unpack of ``key.shape`` below).
        value: ``(batch, seq_len, num_heads, v_head_dim)``.
        g: ``(batch, seq_len, num_heads)`` log of the per-step state decay (the state is multiplied by ``exp(g)``).
        beta: ``(batch, seq_len, num_heads)`` per-token factor applied to keys and values.
        chunk_size: block length; the sequence is right-padded to a multiple of it.
        initial_state: optional ``(batch, num_heads, k_head_dim, v_head_dim)`` starting state.
        output_final_state: when False, the returned state is ``None``.
        use_qk_l2norm_in_kernel: L2-normalise query and key over the head dim first.

    Returns:
        ``(core_attn_out, last_recurrent_state)``. ``core_attn_out`` has the layout of ``value`` and is
        cast back to the original dtype of ``query``.
    """
    initial_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1, eps=1e-6)
        key = l2norm(key, dim=-1, eps=1e-6)
    # Move to (batch, num_heads, seq_len, ...) in float32.
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
    ]
    batch_size, num_heads, sequence_length, k_head_dim = key.shape
    v_head_dim = value.shape[-1]
    # Right-pad the sequence axis to a multiple of chunk_size.
    pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
    query = F.pad(query, (0, 0, 0, pad_size))
    key = F.pad(key, (0, 0, 0, pad_size))
    value = F.pad(value, (0, 0, 0, pad_size))
    beta = F.pad(beta, (0, pad_size))
    g = F.pad(g, (0, pad_size))
    total_sequence_length = sequence_length + pad_size
    # Standard 1/sqrt(k_head_dim) query scaling.
    scale = 1 / (query.shape[-1] ** 0.5)
    query = query * scale
    v_beta = value * beta.unsqueeze(-1)
    k_beta = key * beta.unsqueeze(-1)
    # Split the padded sequence into chunks: (batch, heads, num_chunks, chunk_size, dim).
    query, key, value, k_beta, v_beta = [
        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)
    ]
    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
    # Upper triangle including the diagonal; zeroed so only strictly-earlier positions interact.
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)
    # Cumulative log-decay within each chunk; decay_mask[..., i, j] = exp(g_i - g_j) for j <= i, 0 otherwise.
    g = g.cumsum(dim=-1)
    decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
    attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
    # Row-by-row forward substitution over the strictly-lower-triangular chunk matrix; adding the
    # identity afterwards gives the chunk's triangular solve.
    # NOTE(review): presumably the inverse of (I - A) from the delta-rule chunk formulation; confirm against FLA.
    for i in range(1, chunk_size):
        row = attn[..., i, :i].clone()
        sub = attn[..., :i, :i].clone()
        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
    value = attn @ v_beta
    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
    last_recurrent_state = (
        torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
        if initial_state is None
        else initial_state.to(value)
    )
    core_attn_out = torch.zeros_like(value)
    # Strictly-upper mask: a query attends to keys at or before its own position within the chunk.
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)
    # Sequential pass over chunks, carrying the recurrent state.
    for i in range(0, total_sequence_length // chunk_size):
        q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
        attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
        # Remove the contribution already explained by the carried state.
        v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
        v_new = v_i - v_prime
        # Output = decayed read of the carried state + intra-chunk attention on the corrected values.
        attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
        core_attn_out[:, :, i] = attn_inter + attn @ v_new
        # Decay the state to the chunk end and add this chunk's (decay-weighted) key/value updates.
        last_recurrent_state = (
            last_recurrent_state * g[:, :, i, -1, None, None].exp()
            + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
        )
    if not output_final_state:
        last_recurrent_state = None
    # Merge chunks, drop padding, and restore (batch, seq_len, num_heads, v_head_dim).
    core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
    core_attn_out = core_attn_out[:, :, :sequence_length]
    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
    return core_attn_out, last_recurrent_state
def torch_recurrent_gated_delta_rule(
    query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False
):
    """Pure-PyTorch token-by-token gated delta rule (fallback for ``fused_recurrent_gated_delta_rule``).

    For every step t, per batch element and head, with state S of shape (k_head_dim, v_head_dim):
        S   <- S * exp(g_t)
        S   <- S + outer(k_t, (v_t - k_t^T S) * beta_t)
        o_t  = q_t^T S                      (q_t pre-scaled by 1/sqrt(k_head_dim))

    Args:
        query, key: ``(batch, seq_len, num_heads, k_head_dim)`` (implied by the transpose(1, 2) and shape unpack below).
        value: ``(batch, seq_len, num_heads, v_head_dim)``.
        g, beta: ``(batch, seq_len, num_heads)`` log-decay and update factor per token.
        initial_state: optional ``(batch, num_heads, k_head_dim, v_head_dim)`` starting state (zeros if None).
        output_final_state: when False, the returned state is ``None``.
        use_qk_l2norm_in_kernel: L2-normalise query and key over the head dim first.

    Returns:
        ``(core_attn_out, last_recurrent_state)``; ``core_attn_out`` is ``(batch, seq_len, num_heads, v_head_dim)``
        in the original dtype of ``query``.
    """
    initial_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1, eps=1e-6)
        key = l2norm(key, dim=-1, eps=1e-6)
    # Move to (batch, num_heads, seq_len, ...) in float32.
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
    ]
    batch_size, num_heads, sequence_length, k_head_dim = key.shape
    v_head_dim = value.shape[-1]
    scale = 1 / (query.shape[-1] ** 0.5)
    query = query * scale
    core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value)
    last_recurrent_state = (
        torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
        if initial_state is None
        else initial_state.to(value)
    )
    for i in range(sequence_length):
        q_t = query[:, :, i]
        k_t = key[:, :, i]
        v_t = value[:, :, i]
        g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1)
        beta_t = beta[:, :, i].unsqueeze(-1)
        # Decay the state, then correct it towards v_t along direction k_t (delta rule).
        last_recurrent_state = last_recurrent_state * g_t
        kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
        delta = (v_t - kv_mem) * beta_t
        last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
        # Read out with the (scaled) query.
        core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
    if not output_final_state:
        last_recurrent_state = None
    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
    return core_attn_out, last_recurrent_state
class Qwen3_5GatedDeltaNet(nn.Module):
    """Gated DeltaNet (linear attention) token mixer.

    Pipeline: fused q/k/v projection -> depthwise causal conv1d -> gated delta rule -> gated RMSNorm
    -> output projection. With Ulysses sequence parallelism enabled, q/k/v and the ``a``/``b`` gate
    projections are exchanged with ``all_to_all`` (scatter heads, gather sequence), so each rank
    processes the gathered sequence for ``num_heads / ulysses_size`` heads. The result is exchanged
    back before the norm.
    """

    def __init__(self, config: Qwen3_5Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_v_heads = config.linear_num_value_heads
        self.num_k_heads = config.linear_num_key_heads
        self.head_k_dim = config.linear_key_head_dim
        self.head_v_dim = config.linear_value_head_dim
        self.key_dim = self.head_k_dim * self.num_k_heads
        self.value_dim = self.head_v_dim * self.num_v_heads
        self.conv_kernel_size = config.linear_conv_kernel_dim
        self.layer_idx = layer_idx
        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]
        self.layer_norm_epsilon = config.rms_norm_eps
        # q, k and v share one depthwise convolution over their concatenated channels.
        self.conv_dim = self.key_dim * 2 + self.value_dim
        self.conv1d = nn.Conv1d(
            in_channels=self.conv_dim,
            out_channels=self.conv_dim,
            bias=False,
            kernel_size=self.conv_kernel_size,
            groups=self.conv_dim,
            padding=self.conv_kernel_size - 1,
        )
        self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads))
        A = torch.empty(self.num_v_heads).uniform_(0, 16)
        self.A_log = nn.Parameter(torch.log(A))
        self.norm = (
            Qwen3_5RMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon)
            if FusedRMSNormGated is None
            else FusedRMSNormGated(
                self.head_v_dim,
                eps=self.layer_norm_epsilon,
                activation=self.activation,
                device=torch.cuda.current_device(),
                dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(),
            )
        )
        self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
        self.causal_conv1d_implementation = config.causal_conv1d_implementation
        if self.causal_conv1d_implementation == "triton" and IS_NPU_AVAILABLE:
            from mindspeed_mm.fsdp.models.qwen3_5.causal_conv1d import causal_conv1d
            print_rank(logger.info, "Qwen3_5Moe causal_conv1d use NPU triton ops")
            self.causal_conv1d_fn = causal_conv1d
        else:
            self.causal_conv1d_fn = causal_conv1d_fn
        self.causal_conv1d_update = causal_conv1d_update or torch_causal_conv1d_update
        self.gdn_implementation = config.gdn_implementation
        if self.gdn_implementation == "triton":
            if IS_NPU_AVAILABLE:
                from mindspeed_mm.fsdp.ops.gdn.chunk_gated_delta_rule import chunk_gated_delta_rule
                self.chunk_gated_delta_rule = chunk_gated_delta_rule
            elif is_flash_linear_attention_available():
                self.chunk_gated_delta_rule = fla_chunk_gated_delta_rule
            else:
                raise ValueError(
                    f"gdn_implementation='triton' requires NPU or flash_linear_attention, "
                    f"but neither is available. Please install the required dependency or use gdn_implementation='eager'."
                )
        elif self.gdn_implementation == "AscendC":
            if IS_NPU_AVAILABLE:
                from mindspeed_mm.fsdp.ops.gdn.flash_chunk_gated_delta_rule import chunk_gated_delta_rule
                self.chunk_gated_delta_rule = chunk_gated_delta_rule
            else:
                raise ValueError(
                    f"gdn_implementation='AscendC' requires NPU, but NPU is not available. "
                    f"Please use gdn_implementation='eager' or gdn_implementation='triton' instead."
                )
        elif self.gdn_implementation == "eager":
            self.chunk_gated_delta_rule = torch_chunk_gated_delta_rule
        else:
            raise ValueError(
                f"Invalid gdn_implementation='{self.gdn_implementation}'. Must be one of: 'eager', 'triton', 'AscendC'."
            )
        self.recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule or torch_recurrent_gated_delta_rule
        if not is_fast_path_available:
            logger.warning_once(
                "The fast path is not available because one of the required library is not installed. Falling back to "
                "torch implementation."
            )
        self.in_proj_qkv = nn.Linear(self.hidden_size, self.key_dim * 2 + self.value_dim, bias=False)
        self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False)
        self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
        self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)

    def _get_local_conv1d_weight(self, ulysses_rank: int, local_key_dim: int, local_value_dim: int) -> torch.Tensor:
        """Slice the q, k and v channel groups owned by ``ulysses_rank`` out of the full conv weight."""
        w_full = self.conv1d.weight
        if w_full.shape[0] != self.key_dim * 2 + self.value_dim:
            raise ValueError(
                f"conv1d weight dim ({w_full.shape[0]}) must match "
                f"(2 * key_dim + value_dim) ({self.key_dim * 2 + self.value_dim})"
            )
        k_off = ulysses_rank * local_key_dim
        v_off = ulysses_rank * local_value_dim
        w_q = w_full[k_off: k_off + local_key_dim]
        w_k = w_full[self.key_dim + k_off: self.key_dim + k_off + local_key_dim]
        w_v = w_full[2 * self.key_dim + v_off: 2 * self.key_dim + v_off + local_value_dim]
        return torch.cat([w_q, w_k, w_v], dim=0)

    @staticmethod
    def _get_cu_seqlens(kwargs) -> torch.Tensor | None:
        """Return int64 cumulative sequence boundaries for packed batches, or None when not packed.

        Packing is signalled by a non-None ``cu_seqlens`` kwarg. The boundaries are read from
        ``cu_seq_lens_q``, falling back to ``cu_seqlens`` itself when that key is absent.
        """
        if kwargs.get("cu_seqlens") is None:
            return None
        boundaries = kwargs.get("cu_seq_lens_q")
        if boundaries is None:
            boundaries = kwargs["cu_seqlens"]
        return boundaries.to(torch.int64)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Qwen3_5DynamicCache | None = None,
        cache_position: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """Run the gated delta-rule mixer.

        Args:
            hidden_states: ``(batch, seq_len, hidden_size)``; under Ulysses, the local sequence shard.
            cache_params: optional hybrid cache. When given, the conv window and recurrent state are written back.
            cache_position: positions of the current tokens. It must be non-None for the single-token decode path.
            attention_mask: optional padding mask applied to ``hidden_states`` before the projections.
            **kwargs: may carry ``cu_seqlens`` / ``cu_seq_lens_q`` for packed sequences.

        Returns:
            ``(batch, seq_len, hidden_size)`` output of ``out_proj``.
        """
        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
        batch_size, seq_len, _ = hidden_states.shape
        # Single-token decoding continues from the cached conv/recurrent states.
        use_precomputed_states = (
            cache_params is not None
            and cache_params.has_previous_state
            and seq_len == 1
            and cache_position is not None
        )
        if cache_params is not None:
            conv_state = cache_params.conv_states[self.layer_idx]
            recurrent_state = cache_params.recurrent_states[self.layer_idx]

        # All projections are in (batch, seq_len, channels) layout.
        mixed_qkv = self.in_proj_qkv(hidden_states)
        z = self.in_proj_z(hidden_states)
        z = z.reshape(batch_size, seq_len, -1, self.head_v_dim)
        b = self.in_proj_b(hidden_states)
        a = self.in_proj_a(hidden_states)

        ps = get_parallel_state()
        if ps.is_ulysses_enable():
            ulysses_group = ps.get_ulysses_group()
            ulysses_size = ps.get_ulysses_group_size()
            ulysses_rank = ps.get_ulysses_rank()
            if self.num_k_heads % ulysses_size != 0 or self.num_v_heads % ulysses_size != 0:
                raise ValueError(
                    f"SP size ({ulysses_size}) must divide num_k_heads ({self.num_k_heads}) "
                    f"and num_v_heads ({self.num_v_heads}) for gated deltanet LASP"
                )
            local_num_k_heads = self.num_k_heads // ulysses_size
            local_num_v_heads = self.num_v_heads // ulysses_size
            local_key_dim = self.head_k_dim * local_num_k_heads
            local_value_dim = self.head_v_dim * local_num_v_heads
            # Scatter channels (heads) and gather the full sequence on every rank.
            q_proj, k_proj, v_proj = torch.split(mixed_qkv, [self.key_dim, self.key_dim, self.value_dim], dim=-1)
            q_proj = all_to_all(q_proj, process_group=ulysses_group, scatter_dim=2, gather_dim=1, gather_size=get_seq_len("total"))
            k_proj = all_to_all(k_proj, process_group=ulysses_group, scatter_dim=2, gather_dim=1, gather_size=get_seq_len("total"))
            v_proj = all_to_all(v_proj, process_group=ulysses_group, scatter_dim=2, gather_dim=1, gather_size=get_seq_len("total"))
            b = b.reshape(batch_size, seq_len, self.num_v_heads)
            a = a.reshape(batch_size, seq_len, self.num_v_heads)
            b = all_to_all(b, process_group=ulysses_group, scatter_dim=2, gather_dim=1, gather_size=get_seq_len("total"))
            a = all_to_all(a, process_group=ulysses_group, scatter_dim=2, gather_dim=1, gather_size=get_seq_len("total"))
            mixed_qkv = torch.cat((q_proj, k_proj, v_proj), dim=-1)
            # The conv weight must cover exactly the channels held on this rank (prefill and decode alike).
            conv_weight = self._get_local_conv1d_weight(
                ulysses_rank=ulysses_rank,
                local_key_dim=local_key_dim,
                local_value_dim=local_value_dim,
            )
        else:
            local_num_k_heads = self.num_k_heads
            local_num_v_heads = self.num_v_heads
            local_key_dim = self.key_dim
            local_value_dim = self.value_dim
            conv_weight = self.conv1d.weight

        cu_seqlens = None
        if use_precomputed_states:
            # causal_conv1d_update (library kernel and torch_causal_conv1d_update alike) works on
            # (batch, channels, seq_len), so transpose in and back out.
            mixed_qkv = self.causal_conv1d_update(
                mixed_qkv.transpose(1, 2),
                conv_state,
                conv_weight.squeeze(1),
                self.conv1d.bias,
                self.activation,
            ).transpose(1, 2)
        else:
            if cache_params is not None:
                # Keep the last conv_kernel_size inputs (F.pad crops with negative padding, zero-pads short prompts).
                mixed_qkv_t = mixed_qkv.transpose(1, 2)
                conv_state = F.pad(mixed_qkv_t, (self.conv_kernel_size - mixed_qkv_t.shape[-1], 0))
                cache_params.conv_states[self.layer_idx] = conv_state
            cu_seqlens = self._get_cu_seqlens(kwargs)
            if self.causal_conv1d_implementation == "triton":
                # This kernel takes (batch, seq_len, channels) input and a (kernel_size, channels) weight.
                conv_weight = conv_weight.squeeze(1)
                mixed_qkv, _ = self.causal_conv1d_fn(
                    x=mixed_qkv,
                    weight=conv_weight.transpose(-1, -2).contiguous(),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                    cu_seqlens=cu_seqlens,
                )
            elif self.causal_conv1d_implementation == "eager" and self.causal_conv1d_fn is not None:
                # The causal_conv1d library kernel expects (batch, channels, seq_len).
                # NOTE(review): packed-sequence boundaries (cu_seqlens) are not passed on this path.
                conv_weight = conv_weight.squeeze(1)
                mixed_qkv = self.causal_conv1d_fn(
                    x=mixed_qkv.transpose(1, 2),
                    weight=conv_weight,
                    bias=self.conv1d.bias,
                    activation=self.activation,
                    seq_idx=None,
                )
                mixed_qkv = mixed_qkv.transpose(1, 2)
            else:
                # Plain PyTorch fallback: pad both sides, keep the first seq_len outputs (causal), then SiLU.
                # NOTE(review): packed-sequence boundaries (cu_seqlens) are not respected on this path.
                mixed_qkv = mixed_qkv.transpose(1, 2)
                mixed_qkv = F.silu(F.conv1d(mixed_qkv, weight=conv_weight, bias=self.conv1d.bias, padding=self.conv_kernel_size - 1, groups=local_key_dim * 2 + local_value_dim)[:, :, :mixed_qkv.shape[-1]])
                mixed_qkv = mixed_qkv.transpose(1, 2)

        query, key, value = torch.split(
            mixed_qkv,
            [
                local_key_dim,
                local_key_dim,
                local_value_dim,
            ],
            dim=-1,
        )
        query = query.reshape(query.shape[0], query.shape[1], local_num_k_heads, self.head_k_dim)
        key = key.reshape(key.shape[0], key.shape[1], local_num_k_heads, self.head_k_dim)
        value = value.reshape(value.shape[0], value.shape[1], local_num_v_heads, self.head_v_dim)
        beta = b.sigmoid()
        if ps.is_ulysses_enable():
            # Only this rank's value heads participate, so slice the per-head decay parameters.
            v_head_offset = ulysses_rank * local_num_v_heads
            v_head_slice = slice(v_head_offset, v_head_offset + local_num_v_heads)
            g = -self.A_log[v_head_slice].float().exp() * F.softplus(a.float() + self.dt_bias[v_head_slice])
        else:
            g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
        # Grouped heads: repeat q/k so every value head has a matching key head.
        if self.num_v_heads // self.num_k_heads > 1:
            query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
            key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
        if not use_precomputed_states:
            if self.gdn_implementation == "triton" or self.gdn_implementation == "AscendC":
                core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
                    query,
                    key,
                    value,
                    g=g,
                    beta=beta,
                    cu_seqlens=cu_seqlens,
                    initial_state=None,
                    output_final_state=cache_params is not None,
                    use_qk_l2norm_in_kernel=True,
                )
            else:
                # NOTE(review): the eager chunk rule has no cu_seqlens argument, so packed sequences are not separated.
                core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
                    query,
                    key,
                    value,
                    g=g,
                    beta=beta,
                    initial_state=None,
                    output_final_state=cache_params is not None,
                    use_qk_l2norm_in_kernel=True,
                )
        else:
            core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(
                query,
                key,
                value,
                g=g,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=cache_params is not None,
                use_qk_l2norm_in_kernel=True,
            )
        if cache_params is not None:
            cache_params.recurrent_states[self.layer_idx] = last_recurrent_state
        if ps.is_ulysses_enable():
            # Inverse exchange: scatter the sequence back, gather all heads.
            core_attn_out = all_to_all(core_attn_out, process_group=ulysses_group, scatter_dim=1, gather_dim=2)
        core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
        z = z.reshape(-1, self.head_v_dim)
        core_attn_out = self.norm(core_attn_out, z)
        core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1)
        output = self.out_proj(core_attn_out)
        return output
def rotate_half(x):
    """Map the last dim ``[a, b]`` (split at ``size // 2``) to ``[-b, a]``, as used by rotary embeddings."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((second.neg(), first), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Apply (possibly partial) rotary position embedding to ``q`` and ``k``.

    Only the first ``cos.shape[-1]`` channels of the last dim are rotated; the remaining channels
    pass through unchanged. On NPU the rotation uses ``torch_npu.npu_rotary_mul``; otherwise it is
    ``x * cos + rotate_half(x) * sin``.

    Args:
        q, k: query and key tensors.
        cos, sin: rotary tables; ``unsqueeze_dim`` is inserted into them so they broadcast against
            ``q``/``k`` (use 1 for ``(batch, heads, seq, dim)`` and 2 for ``(batch, seq, heads, dim)``).
        unsqueeze_dim: position of the inserted head axis in ``cos``/``sin``.

    Returns:
        ``(q_embed, k_embed)`` with the same shapes as ``q`` and ``k``.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    rotary_dim = cos.shape[-1]

    def _rotate(tensor):
        rotated, passthrough = tensor[..., :rotary_dim], tensor[..., rotary_dim:]
        if IS_NPU_AVAILABLE:
            rotated = torch_npu.npu_rotary_mul(rotated, cos, sin)
        else:
            rotated = (rotated * cos) + (rotate_half(rotated) * sin)
        return torch.cat([rotated, passthrough], dim=-1)

    return _rotate(q), _rotate(k)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along dim 1.

    Equivalent to ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``:
    ``(batch, num_key_value_heads, seqlen, head_dim)`` -> ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``.
    The input is returned unchanged (same object) when ``n_rep == 1``.
    """
    batch, kv_heads, seq, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq, dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq, dim)
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Reference attention: ``softmax(Q K^T * scaling + mask) V``.

    Key/value heads are repeated ``module.num_key_value_groups`` times to match the query heads.
    The softmax is computed in float32 and cast back to ``query.dtype``; dropout is active only
    while ``module.training`` is True.

    Returns:
        ``(attn_output, attn_weights)`` where ``attn_output`` has dims 1 and 2 swapped relative to
        ``query`` (heads <-> sequence) and is contiguous.
    """
    n_rep = module.num_key_value_groups
    key_states = repeat_kv(key, n_rep)
    value_states = repeat_kv(value, n_rep)
    scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)
    context = torch.matmul(probs, value_states).transpose(1, 2).contiguous()
    return context, probs
@use_kernelized_func(apply_rotary_pos_emb)
class Qwen3_5Attention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper, extended with an output gate.

    Specifics visible in this class:
      * ``q_proj`` emits ``2 * num_attention_heads * head_dim`` features. Per head, the first ``head_dim`` values
        are the query and the second ``head_dim`` values are a gate. ``sigmoid(gate)`` multiplies the attention
        output before ``o_proj``.
      * Per-head RMSNorm (``q_norm`` / ``k_norm``) is applied to queries and keys before rotary embedding.
      * Grouped key/value heads: ``num_key_value_groups = num_attention_heads // num_key_value_heads``.
      * Under Ulysses sequence parallelism, key/value heads may be expanded before the attention backend is
        called (see ``forward``).
    """

    def __init__(self, config: Qwen3_5Config, layer_idx: int):
        super().__init__()
        self.config = config
        # Index of this layer; used as the key when updating ``past_key_values``.
        self.layer_idx = layer_idx
        # Falls back to hidden_size // num_attention_heads when the config has no ``head_dim``.
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
        # Doubled output width: query and gate are interleaved per head (split in ``forward``).
        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim * 2, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.q_norm = Qwen3_5RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Qwen3_5RMSNorm(self.head_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Gated self-attention over ``hidden_states``.

        Args:
            hidden_states: input activations. Assumed (batch, seq_len, hidden_size); the ``transpose(1, 2)``
                calls below rely on two leading dims. TODO confirm for all callers.
            position_embeddings: ``(cos, sin)`` passed to ``apply_rotary_pos_emb`` and to the cache update.
            attention_mask: forwarded unchanged to the attention backend (may be None).
            past_key_values: when given, ``update`` is called with this layer's new keys/values, and the
                returned (cached) keys/values are used for attention.
            cache_position: forwarded to the cache update.
            **kwargs: forwarded to the attention backend.

        Returns:
            ``(attn_output, attn_weights)``: the gated, ``o_proj``-projected output and whatever weights the
            backend returns (may be None).
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        # Split the doubled q_proj output per head into the query half and the gate half.
        query_states, gate = torch.chunk(
            self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
        )
        # Flatten the gate back to (*input_shape, num_heads * head_dim) to match the reshaped attention output.
        gate = gate.reshape(*input_shape, -1)
        query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        if past_key_values is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        # Sequence length registered under the "total" key; presumably the full, pre-Ulysses-split length.
        total_seq_len = get_seq_len("total")
        # When there are more Ulysses ranks than KV heads, expand K/V to the full query head count up front
        # (presumably so heads can be scattered across Ulysses ranks — confirm in the backend).
        # NOTE(review): the eager fallback (eager_attention_forward) repeats K/V again by num_key_value_groups;
        # confirm the backend used under Ulysses does not double-expand.
        if get_parallel_state().get_ulysses_group_size() > self.config.num_key_value_heads:
            key_states = repeat_kv(key_states, self.num_key_value_groups)
            value_states = repeat_kv(value_states, self.num_key_value_groups)
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            is_causal=True,
            total_seq_len=total_seq_len,
            seq_split_lens=None,
            **kwargs,
        )
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        # Output gating: elementwise sigmoid gate computed from the q_proj gate half.
        attn_output = attn_output * torch.sigmoid(gate)
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
class Qwen3_5MLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act(gate_proj(x)) * up_proj(x))``; all projections are bias-free."""

    def __init__(self, config: Qwen3_5Config, intermediate_size: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = intermediate_size
        width, inner = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(width, inner, bias=False)
        self.up_proj = nn.Linear(width, inner, bias=False)
        self.down_proj = nn.Linear(inner, width, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        lifted = self.up_proj(x)
        return self.down_proj(gated * lifted)
class Qwen3_5RMSNorm(nn.Module):
    """
    Root-mean-square normalisation whose learnable scale is stored as an offset from one.

    ``weight`` starts at zeros and the applied gain is ``1 + weight``. On NPU the fused ``torch_npu.npu_rms_norm``
    kernel is used (only its first output is kept). Elsewhere the normalisation is computed in float32 and the
    result is cast back to the input dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        if not IS_NPU_AVAILABLE:
            scaled = self._norm(x.float()) * (1.0 + self.weight.float())
            return scaled.type_as(x)
        return torch_npu.npu_rms_norm(x, 1.0 + self.weight, self.eps)[0]

    def extra_repr(self):
        weight_shape = tuple(self.weight.shape)
        return f"{weight_shape}, eps={self.eps}"
class Qwen3_5DecoderLayer(GradientCheckpointingLayer):
    """
    One decoder layer of the Qwen3.5 text model.

    The token mixer is chosen by ``config.layer_types[layer_idx]``:
      * ``"linear_attention"`` builds a ``Qwen3_5GatedDeltaNet``;
      * ``"full_attention"`` builds a ``Qwen3_5Attention``.

    With ``is_mtp=True`` the layer is always full attention. The token mixer and the MLP each use
    pre-normalisation followed by a residual add.
    """

    def __init__(self, config: Qwen3_5TextConfig, layer_idx: int, is_mtp: bool = False):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layer_type = "full_attention" if is_mtp else config.layer_types[layer_idx]
        if self.layer_type == "linear_attention":
            self.linear_attn = Qwen3_5GatedDeltaNet(config, layer_idx)
        elif self.layer_type == "full_attention":
            self.self_attn = Qwen3_5Attention(config, layer_idx)
        self.mlp = Qwen3_5MLP(config, config.intermediate_size)
        self.input_layernorm = Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.FloatTensor:
        # Token mixer: pre-norm, mixer, residual.
        residual = hidden_states
        mixed = self.input_layernorm(hidden_states)
        if self.layer_type == "full_attention":
            mixed, _ = self.self_attn(
                hidden_states=mixed,
                position_embeddings=position_embeddings,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                **kwargs,
            )
        elif self.layer_type == "linear_attention":
            mixed = self.linear_attn(
                hidden_states=mixed,
                attention_mask=attention_mask,
                cache_params=past_key_values,
                cache_position=cache_position,
                **kwargs,
            )
        hidden_states = residual + mixed

        # Feed-forward: pre-norm, MLP, residual.
        ffn_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + ffn_out
class Qwen3_5PreTrainedModel(PreTrainedModel):
    """Shared base class for the Qwen3.5 text, vision and multimodal models (config, flags, weight init)."""

    config: Qwen3_5Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen3_5DecoderLayer", "Qwen3_5VisionBlock"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_sdpa = True
    _keys_to_ignore_on_load_unexpected = [r"^mtp.*"]
    _can_record_outputs = {
        "hidden_states": Qwen3_5DecoderLayer,
        "attentions": Qwen3_5Attention,
    }
    _is_stateful = True

    @torch.no_grad()
    def _init_weights(self, module):
        """Apply the base initialisation, then the Qwen3.5-specific overrides for the module types below."""
        super()._init_weights(module)
        if isinstance(module, Qwen3_5GatedDeltaNet):
            init.ones_(module.dt_bias)
            # A_log <- log(U(0, 16)), sampled with the same shape/dtype/device as A_log.
            a_log_init = torch.empty_like(module.A_log).uniform_(0, 16).log_()
            init.copy_(module.A_log, a_log_init)
        elif isinstance(module, Qwen3_5RMSNorm):
            init.zeros_(module.weight)
        elif isinstance(module, Qwen3_5VisionRotaryEmbedding):
            exponents = torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim
            init.copy_(module.inv_freq, 1.0 / (module.theta ** exponents))
class Qwen3_5VisionMLP(nn.Module):
    """Two-layer feed-forward network of the vision blocks: ``linear_fc2(act(linear_fc1(x)))``, both with bias."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
        self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        expanded = self.linear_fc1(hidden_state)
        activated = self.act_fn(expanded)
        return self.linear_fc2(activated)
class Qwen3_5VisionPatchEmbed(nn.Module):
    """
    Embed spatio-temporal pixel patches with a non-overlapping 3D convolution.

    The input is viewed as ``(-1, in_channels, temporal_patch_size, patch_size, patch_size)``, so its element
    count must be a multiple of ``in_channels * temporal_patch_size * patch_size**2``. The input is cast to the
    convolution weight dtype. Because kernel size equals stride and matches the patch shape, each patch yields
    exactly one ``hidden_size`` vector.
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.patch_size = config.patch_size
        self.temporal_patch_size = config.temporal_patch_size
        self.in_channels = config.in_channels
        self.embed_dim = config.hidden_size
        kernel = [self.temporal_patch_size, self.patch_size, self.patch_size]
        self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel, stride=kernel, bias=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        weight_dtype = self.proj.weight.dtype
        patch_shape = (self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size)
        patches = hidden_states.view(-1, *patch_shape).to(dtype=weight_dtype)
        return self.proj(patches).view(-1, self.embed_dim)
class Qwen3_5VisionPatchMerger(nn.Module):
    """
    Fold each group of ``spatial_merge_size**2`` consecutive vision tokens into one token and project it to
    ``out_hidden_size`` with a two-layer GELU MLP.

    ``use_postshuffle_norm`` decides whether LayerNorm runs after the fold (over the folded width) or before it
    (over ``config.hidden_size``). When context parallelism is enabled, the packed visual sequence is first
    gathered (split lengths from the "per_visual" seq-len entry), folded, and re-split across the Ulysses group.
    """

    def __init__(self, config: Qwen3_5VisionConfig, use_postshuffle_norm=False) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
        self.use_postshuffle_norm = use_postshuffle_norm
        norm_width = self.hidden_size if use_postshuffle_norm else config.hidden_size
        self.norm = nn.LayerNorm(norm_width, eps=1e-6)
        self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size)
        self.act_fn = nn.GELU()
        self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size)

    def _fold_across_cp(self, x: torch.Tensor, ps) -> torch.Tensor:
        """Gather packed tokens over CP, fold them into merge groups, then re-split evenly over the Ulysses group."""
        x = packed_data_gather_forward_split_backward_with_cp(x, dim=0, seq_lens=get_seq_len("per_visual"))
        x = x.view(-1, self.hidden_size)
        split_sizes = cal_split_sizes(x.shape[0], ps.get_ulysses_group_size())
        return split_forward_gather_backward(
            x, ps.get_ulysses_group(), dim=0, grad_scale="down", split_sizes=split_sizes
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ps = get_parallel_state()
        if not ps.is_cp_enable():
            pre_norm = x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x
            x = self.norm(pre_norm).view(-1, self.hidden_size)
            return self.linear_fc2(self.act_fn(self.linear_fc1(x)))
        if self.use_postshuffle_norm:
            x = self.norm(self._fold_across_cp(x, ps))
        else:
            x = self._fold_across_cp(self.norm(x), ps)
        return self.linear_fc2(self.act_fn(self.linear_fc1(x)))
def apply_rotary_pos_emb_vision(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Rotate vision queries and keys with precomputed rotary ``cos``/``sin`` tables.

    The rotation is computed in float32 and each result is cast back to the original dtype of ``q`` / ``k``.
    ``cos``/``sin`` get a singleton axis inserted at dim -2 so they broadcast over that axis (presumably the head
    axis of ``(seq, heads, head_dim)`` inputs — confirm against callers). On NPU ``torch_npu.npu_rotary_mul`` is
    used; otherwise ``x * cos + rotate_half(x) * sin``.
    """
    q_dtype, k_dtype = q.dtype, k.dtype
    q32, k32 = q.float(), k.float()
    cos32 = cos.unsqueeze(-2).float()
    sin32 = sin.unsqueeze(-2).float()
    if IS_NPU_AVAILABLE:
        rotated = [torch_npu.npu_rotary_mul(t, cos32, sin32) for t in (q32, k32)]
    else:
        rotated = [(t * cos32) + (rotate_half(t) * sin32) for t in (q32, k32)]
    return rotated[0].to(q_dtype), rotated[1].to(k_dtype)
class Qwen3_5VisionAttention(nn.Module):
    """
    Bidirectional multi-head self-attention over packed vision tokens.

    Tokens from several images/frames are concatenated along the sequence axis. ``cu_seqlens`` holds the
    cumulative segment boundaries, so attention never crosses segments:
      * with ``flash_attention_2`` the packed sequence goes to the backend in one call with varlen arguments;
      * any other backend is called once per segment and the outputs are concatenated.
    """

    def __init__(self, config: Qwen3_5VisionConfig) -> None:
        super().__init__()
        self.dim = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = self.dim // self.num_heads
        self.num_key_value_groups = 1
        self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
        self.proj = nn.Linear(self.dim, self.dim)
        self.scaling = self.head_dim**-0.5
        self.config = config
        self.attention_dropout = 0.0
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs,
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
        fused = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1)
        q, k, v = fused.permute(1, 0, 2, 3).unbind(0)
        cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        dropout_p = self.attention_dropout if self.training else 0.0

        if self.config._attn_implementation == "flash_attention_2":
            # Varlen path: one call over the whole packed sequence with a leading batch dim of 1.
            longest = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
            attn_output, _ = attention_interface(
                self,
                q.unsqueeze(0),
                k.unsqueeze(0),
                v.unsqueeze(0),
                attention_mask=None,
                scaling=self.scaling,
                dropout=dropout_p,
                cu_seq_lens_q=cu_seqlens,
                cu_seq_lens_k=cu_seqlens,
                max_length_q=longest,
                max_length_k=longest,
                is_causal=False,
                total_seq_len=get_seq_len("visual"),
                input_layout="1TND",
                **kwargs,
            )
        else:
            # Per-segment path: move heads before sequence, add a batch dim, then split by segment length.
            segment_lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
            chunked = [
                torch.split(t.transpose(0, 1).unsqueeze(0), segment_lengths, dim=2) for t in (q, k, v)
            ]
            pieces = []
            for q_seg, k_seg, v_seg in zip(*chunked):
                seg_out = attention_interface(
                    self,
                    q_seg,
                    k_seg,
                    v_seg,
                    attention_mask=None,
                    scaling=self.scaling,
                    dropout=dropout_p,
                    is_causal=False,
                    **kwargs,
                )[0]
                pieces.append(seg_out)
            attn_output = torch.cat(pieces, dim=1)

        flat = attn_output.reshape(seq_length, -1).contiguous()
        return self.proj(flat)
class Qwen3_5VisionBlock(GradientCheckpointingLayer):
    """
    Pre-norm vision transformer block:
      * ``x + attn(norm1(x))``, then
      * ``x + mlp(norm2(x))``.

    ``attn_implementation`` is accepted for signature compatibility but is not used here. The attention backend
    is read from the config inside ``Qwen3_5VisionAttention``.
    """

    def __init__(self, config, attn_implementation: str = "sdpa") -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6)
        self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6)
        self.attn = Qwen3_5VisionAttention(config=config)
        self.mlp = Qwen3_5VisionMLP(config=config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs,
    ) -> torch.Tensor:
        attn_out = self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out
        mlp_out = self.mlp(self.norm2(hidden_states))
        return hidden_states + mlp_out
class Qwen3_5VisionModel(Qwen3_5PreTrainedModel):
    """
    Qwen3.5 vision tower.

    ``forward`` runs these stages in order:
      1. 3D patch embedding plus bilinearly interpolated learned position embeddings.
      2. A stack of ``Qwen3_5VisionBlock`` layers with 2D rotary position embeddings and one attention segment
         per frame.
      3. ``Qwen3_5VisionPatchMerger``, which folds each ``spatial_merge_size x spatial_merge_size`` group of
         patches into one output token.

    When context parallelism is enabled, the packed visual sequence is split before the blocks and the merged
    tokens are gathered over the Ulysses group afterwards.
    """
    config: Qwen3_5VisionConfig
    _no_split_modules = ["Qwen3_5VisionBlock"]
    _can_record_outputs = {
        "hidden_states": Qwen3_5VisionBlock,
        "attentions": Qwen3_5VisionAttention,
    }

    def __init__(self, config, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)
        self.spatial_merge_size = config.spatial_merge_size
        self.patch_size = config.patch_size
        # Number of patches per merge group (spatial_merge_size squared).
        self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
        self.patch_embed = Qwen3_5VisionPatchEmbed(
            config=config,
        )
        # Learned absolute position table, treated as a square grid of num_grid_per_side x num_grid_per_side
        # entries (assumes num_position_embeddings is a perfect square).
        self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size)
        self.num_grid_per_side = int(config.num_position_embeddings**0.5)
        head_dim = config.hidden_size // config.num_heads
        self.rotary_pos_emb = Qwen3_5VisionRotaryEmbedding(head_dim // 2)
        self.blocks = nn.ModuleList([Qwen3_5VisionBlock(config) for _ in range(config.depth)])
        self.merger = Qwen3_5VisionPatchMerger(
            config=config,
            use_postshuffle_norm=False,
        )
        self.gradient_checkpointing = False
        self.post_init()

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """
        Build per-patch 2D rotary frequencies.

        Each row of ``grid_thw`` is ``(t, h, w)`` for one item. Every patch gets a ``(row, col)`` coordinate, and
        patches are enumerated in this order:
          * merge blocks of ``spatial_merge_size x spatial_merge_size`` patches, row-major;
          * patches inside a block, row-major;
          * the same coordinates repeated for every frame.

        Row and column frequencies are looked up in ``self.rotary_pos_emb(max_hw)`` and concatenated along the
        last axis (assuming that table is 2-D — confirm).

        Returns:
            Tensor with one row per patch, ``sum(t * h * w)`` rows in total.
        """
        merge_size = self.spatial_merge_size
        grid_thw_list = grid_thw.tolist()
        # The frequency table only needs to cover the largest spatial extent in the batch.
        max_hw = max(max(h, w) for _, h, w in grid_thw_list)
        freq_table = self.rotary_pos_emb(max_hw)
        device = freq_table.device
        total_tokens = sum(t * h * w for t, h, w in grid_thw_list)
        pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)
        offset = 0
        for num_frames, height, width in grid_thw_list:
            merged_h, merged_w = height // merge_size, width // merge_size
            block_rows = torch.arange(merged_h, device=device)
            block_cols = torch.arange(merged_w, device=device)
            intra_row = torch.arange(merge_size, device=device)
            intra_col = torch.arange(merge_size, device=device)
            # Absolute patch row/col laid out as (block_row, block_col, intra_row, intra_col) so that flattening
            # walks merge blocks first, then patches inside each block.
            row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
            col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
            row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
            col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
            coords = torch.stack((row_idx, col_idx), dim=-1)
            if num_frames > 1:
                # Every frame reuses the same spatial coordinates.
                coords = coords.repeat(num_frames, 1)
            num_tokens = coords.shape[0]
            pos_ids[offset : offset + num_tokens] = coords
            offset += num_tokens
        embeddings = freq_table[pos_ids]
        embeddings = embeddings.flatten(1)
        return embeddings

    def fast_pos_embed_interpolate(self, grid_thw):
        """
        Bilinearly resample the learned ``num_grid_per_side x num_grid_per_side`` position table to each item's
        ``h x w`` patch grid.

        For each item, ``h`` (resp. ``w``) evenly spaced sample points in ``[0, num_grid_per_side - 1]`` are
        taken, and the four surrounding table entries are blended with bilinear weights. The result is repeated
        for all ``t`` frames and reordered into the same merge-block order used by ``rot_pos_emb``. All items
        are then concatenated along dim 0.
        """
        grid_thw_list = grid_thw.tolist()
        grid_ts = [row[0] for row in grid_thw_list]
        grid_hs = [row[1] for row in grid_thw_list]
        grid_ws = [row[2] for row in grid_thw_list]
        device = self.pos_embed.weight.device
        # Four corners per sample point: (floor, floor), (floor, ceil), (ceil, floor), (ceil, ceil).
        idx_list = [[] for _ in range(4)]
        weight_list = [[] for _ in range(4)]
        for t, h, w in grid_thw_list:
            h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
            w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)
            # Sample points are non-negative, so int() truncation equals floor.
            h_idxs_floor = h_idxs.int()
            w_idxs_floor = w_idxs.int()
            h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
            w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
            # Fractional offsets used as bilinear weights.
            dh = h_idxs - h_idxs_floor
            dw = w_idxs - w_idxs_floor
            # Flat indices into the row-major grid: row * num_grid_per_side + col.
            base_h = h_idxs_floor * self.num_grid_per_side
            base_h_ceil = h_idxs_ceil * self.num_grid_per_side
            indices = [
                (base_h[None].T + w_idxs_floor[None]).flatten(),
                (base_h[None].T + w_idxs_ceil[None]).flatten(),
                (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
                (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
            ]
            weights = [
                ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
                ((1 - dh)[None].T * dw[None]).flatten(),
                (dh[None].T * (1 - dw)[None]).flatten(),
                (dh[None].T * dw[None]).flatten(),
            ]
            for i in range(4):
                idx_list[i].extend(indices[i].tolist())
                weight_list[i].extend(weights[i].tolist())
        idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device)
        weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device)
        # Weighted sum of the four corner embeddings = bilinear interpolation.
        pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None]
        patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
        patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])
        patch_pos_embeds_permute = []
        merge_size = self.config.spatial_merge_size
        for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
            pos_embed = pos_embed.repeat(t, 1)
            # Reorder from plain row-major (h, w) into merge-block order, matching rot_pos_emb.
            pos_embed = (
                pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
                .permute(0, 1, 3, 2, 4, 5)
                .flatten(0, 4)
            )
            patch_pos_embeds_permute.append(pos_embed)
        patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
        return patch_pos_embeds

    @merge_with_config_defaults
    @capture_outputs
    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.Tensor`):
                Pixel patches to embed. ``patch_embed`` views them as
                ``(-1, in_channels, temporal_patch_size, patch_size, patch_size)``.
            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
                Per-item (temporal, height, width) sizes measured in patches.
        Returns:
            `BaseModelOutputWithPooling`:
                - ``last_hidden_state``: per-patch outputs of the last block (the local shard when context
                  parallelism is enabled).
                - ``pooler_output``: merged tokens from ``self.merger`` (gathered over the Ulysses group when
                  context parallelism is enabled).
        """
        hidden_states = self.patch_embed(hidden_states)
        pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
        hidden_states = hidden_states + pos_embeds
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        seq_len, _ = hidden_states.size()
        hidden_states = hidden_states.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        # One attention segment per frame: each item contributes t segments of h * w patches.
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
        if IS_NPU_AVAILABLE:
            # Presumably the NPU attention kernel expects host-side sequence offsets — confirm.
            cu_seqlens = cu_seqlens.cpu()
        # Per-frame segment lengths (CPU), consumed by the CP split here and by the merger's CP gather.
        sequence_lengths = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cpu()
        set_seq_len("visual", seq_len)
        set_seq_len("per_visual", sequence_lengths)
        ps = get_parallel_state()
        if ps.is_cp_enable():
            # Split rotary tables and activations over CP ranks along the packed sequence axis.
            rotary_pos_emb = packed_data_split_forward_gather_backward_with_cp(rotary_pos_emb, dim=0, seq_lens=sequence_lengths)
            hidden_states = packed_data_split_forward_gather_backward_with_cp(hidden_states, dim=0, seq_lens=sequence_lengths)
        # Duplicate the frequencies along the last axis before taking cos/sin.
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())
        for blk in self.blocks:
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens,
                position_embeddings=position_embeddings,
                **kwargs,
            )
        merged_hidden_states = self.merger(hidden_states)
        # Temporarily record the merged (post-fold) length so the gather sizes below match it.
        set_seq_len("visual", seq_len // self.spatial_merge_size ** 2)
        if ps.is_cp_enable():
            gather_sizes = cal_split_sizes(get_seq_len("visual"), ps.get_ulysses_group_size())
            merged_hidden_states = gather_forward_split_backward(
                merged_hidden_states,
                ps.get_ulysses_group(),
                dim=0,
                grad_scale="up",
                gather_sizes=gather_sizes,
            )
        # Restore the unmerged length for later readers of the "visual" entry.
        set_seq_len("visual", seq_len)
        return BaseModelOutputWithPooling(
            last_hidden_state=hidden_states,
            pooler_output=merged_hidden_states,
        )
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for Qwen3.5 model outputs, with hidden states and attentions.
    """
)
class Qwen3_5ModelOutputWithPast(ModelOutput):
    r"""
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance.

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
        The rope index difference between sequence length and multimodal rope.
    position_ids (`torch.LongTensor`, *optional*):
        Position ids carried alongside the outputs, when the producing model sets them.
    """

    last_hidden_state: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    rope_deltas: torch.LongTensor | None = None
    position_ids: torch.LongTensor | None = None
class Qwen3_5TextModel(Qwen3_5PreTrainedModel):
    """
    Qwen3.5 text decoder.

    Components: token embeddings, a stack of ``Qwen3_5DecoderLayer`` (each linear attention or full attention
    per ``config.layer_types``), and a final RMSNorm. This variant adds:
      * sequence packing, triggered by a non-None ``cu_seqlens`` kwarg;
      * Ulysses sequence parallelism: embeddings and position ids are split along the sequence axis before the
        decoder stack.
    """

    def __init__(self, config: Qwen3_5TextConfig):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
        self.layers = nn.ModuleList(
            [Qwen3_5DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen3_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3_5TextRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        # Exactly one of input_ids / inputs_embeds must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        if use_cache and past_key_values is None:
            past_key_values = Qwen3_5DynamicCache(config=self.config)
        if cache_position is None:
            # Continue numbering after any tokens already held in the cache.
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        # Normalise position_ids to 3 (or 4) leading rows for multimodal rotary embeddings:
        # default -> cache_position broadcast to (3, batch, seq); 2-D -> the same ids on all 3 rows.
        if position_ids is None:
            position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
        elif position_ids.ndim == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
        # With 4 rows, row 0 holds plain text positions and rows 1..3 feed the rotary embedding.
        # Otherwise row 0 doubles as the text positions and all rows feed the rotary embedding.
        if position_ids.ndim == 3 and position_ids.shape[0] == 4:
            text_position_ids = position_ids[0]
            position_ids = position_ids[1:]
        else:
            text_position_ids = position_ids[0]
        use_packing = "cu_seqlens" in kwargs and kwargs["cu_seqlens"] is not None
        if use_packing:
            # Packed sequences: no dense causal mask; segment boundaries are derived from text positions
            # and merged into kwargs (device tensors, need_cpu_tensor=False).
            causal_mask = None
            kwargs.update(generate_ulysses_cu_seqlen_params(text_position_ids, need_cpu_tensor=False))
        else:
            causal_mask = create_causal_mask(
                config=self.config,
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
                position_ids=text_position_ids,
            )
        linear_attn_mask = self._update_linear_attn_mask(attention_mask, cache_position)
        # kwargs_fa is the kwargs dict for full-attention layers. It ALIASES kwargs (same object) unless
        # replaced by the copy below, so kwargs.update(...) in the non-packing Ulysses branch reaches both
        # layer types.
        kwargs_fa = kwargs
        ps = get_parallel_state()
        if ps.is_ulysses_enable():
            if not use_packing:
                kwargs.update(generate_ulysses_cu_seqlen_params(text_position_ids))
            else:
                # Packing + Ulysses: full-attention layers get CPU copies of cu_seq_lens_q/k (shallow dict
                # copy); linear-attention layers keep the device tensors in kwargs.
                kwargs_fa = kwargs.copy()
                kwargs_fa["cu_seq_lens_q"] = kwargs_fa["cu_seq_lens_q"].cpu()
                kwargs_fa["cu_seq_lens_k"] = kwargs_fa["cu_seq_lens_k"].cpu()
        # Record the full (pre-split) sequence length before any Ulysses split below.
        total_seq_len = inputs_embeds.shape[1]
        set_seq_len("total", total_seq_len)
        if ps.is_ulysses_enable():
            # Shard along the sequence axis: dim 2 for (rows, batch, seq) ids, dim 1 for (batch, seq[, hidden]).
            position_ids = split_forward_gather_backward_with_cp(position_ids, dim=2)
            text_position_ids = split_forward_gather_backward_with_cp(text_position_ids, dim=1)
            inputs_embeds = split_forward_gather_backward_with_cp(inputs_embeds, dim=1)
        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)
        for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
            # Each layer type gets its own mask and kwargs (see kwargs_fa above).
            layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask
            new_kwargs = kwargs if decoder_layer.layer_type == "linear_attention" else kwargs_fa
            hidden_states = decoder_layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=layer_mask,
                position_ids=text_position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                **new_kwargs,
            )
        hidden_states = self.norm(hidden_states)
        # Only last_hidden_state and past_key_values are filled here; hidden_states/attentions are
        # presumably collected by the capture_outputs decorator.
        return Qwen3_5ModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )

    def _update_linear_attn_mask(self, attention_mask, cache_position):
        """
        Return the mask to pass to linear-attention layers.

        NOTE: Left-padding is used for linear attention mask.
        Returns ``attention_mask`` unchanged, except ``None`` (no need to zero states) when:
          1. this is a cached forward (``cache_position[0] > 0``), or
          2. every entry of ``attention_mask`` equals 1 (all inputs attended).
        ``cache_position`` must be non-empty because its first element is read.
        """
        linear_attn_mask = attention_mask
        if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
            linear_attn_mask = None
        return linear_attn_mask
@auto_docstring
class Qwen3_5Model(Qwen3_5PreTrainedModel):
base_model_prefix = "model"
_checkpoint_conversion_mapping = {}
accepts_loss_kwargs = False
config: Qwen3_5Config
_no_split_modules = ["Qwen3_5TextDecoderLayer", "Qwen3_5VisionBlock"]
def __init__(self, config):
    """Build the vision tower and the text model from their own sub-configs."""
    super().__init__(config)
    vision_cfg, text_cfg = config.vision_config, config.text_config
    self.visual = Qwen3_5VisionModel._from_config(vision_cfg)
    self.language_model = Qwen3_5TextModel._from_config(text_cfg)
    # Cached multimodal rope deltas; starts empty.
    self.rope_deltas = None
    self.post_init()
def get_input_embeddings(self):
    """Return the token embedding module of the wrapped language model."""
    text_model = self.language_model
    return text_model.get_input_embeddings()
def set_input_embeddings(self, value):
    """Replace the token embedding module of the wrapped language model with ``value``."""
    text_model = self.language_model
    text_model.set_input_embeddings(value)
def get_rope_index(
    self,
    input_ids: torch.LongTensor | None = None,
    image_grid_thw: torch.LongTensor | None = None,
    video_grid_thw: torch.LongTensor | None = None,
    attention_mask: torch.Tensor | None = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute 3-axis (temporal, height, width) multimodal rotary position ids.

    Differs from the original implementation: videos are expanded so every frame is a separate item with
    ``t == 1``. Temporal order is therefore presumably conveyed by timestamp tokens in the text rather than
    absolute temporal position ids.

    Position assignment:
      * Text tokens get the same increasing position on all three axes.
      * A vision placeholder run of ``t * (h // spatial_merge_size) * (w // spatial_merge_size)`` tokens gets
        (t, h, w) grid indices, offset by the position reached before it.
      * Text after a run continues from the maximum position used so far + 1.

    Args:
        input_ids: token ids indexed as ``(batch_size, seq_len)``. Required; there is no fallback for None.
        image_grid_thw / video_grid_thw: one ``(t, h, w)`` row per image / video, consumed in order of
            appearance across the whole batch.
        attention_mask: only positions equal to 1 are assigned; others keep position id 0. Defaults to ones.

    Returns:
        ``(position_ids, mrope_position_deltas)``:
          * ``position_ids`` has shape ``(3, batch_size, seq_len)``;
          * ``mrope_position_deltas`` has shape ``(batch_size, 1)`` and equals ``max position + 1 - seq_len``
            per sample.
    """
    if video_grid_thw is not None:
        # One row per frame, each with t == 1 (repeat_interleave returns a new tensor, so the caller's is untouched).
        video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0)
        video_grid_thw[:, 0] = 1
    image_grid_thw_list = image_grid_thw.tolist() if image_grid_thw is not None else None
    video_grid_thw_list = video_grid_thw.tolist() if video_grid_thw is not None else None
    spatial_merge_size = self.config.vision_config.spatial_merge_size
    image_token_id = self.config.image_token_id
    video_token_id = self.config.video_token_id
    vision_start_token_id = self.config.vision_start_token_id
    mrope_position_deltas = []
    total_input_ids = input_ids
    if attention_mask is None:
        attention_mask = torch.ones_like(total_input_ids)
    position_ids = torch.zeros(
        3,
        input_ids.shape[0],
        input_ids.shape[1],
        dtype=input_ids.dtype,
        device=input_ids.device,
    )
    # Global cursors into the grid lists; they keep advancing across batch samples.
    image_index, video_index = 0, 0
    attention_mask = attention_mask.to(total_input_ids.device)
    for i, input_ids in enumerate(total_input_ids):
        # Drop masked-out positions for this sample (note: shadows the input_ids argument).
        input_ids = input_ids[attention_mask[i] == 1]
        image_nums, video_nums = 0, 0
        # The token right after each vision-start marker tells whether the run is an image or a video.
        vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
        vision_tokens = input_ids[vision_start_indices + 1]
        image_nums = (vision_tokens == image_token_id).sum()
        video_nums = (vision_tokens == video_token_id).sum()
        input_tokens = input_ids.tolist()
        llm_pos_ids_list: list = []
        st = 0
        remain_images, remain_videos = image_nums, video_nums
        for _ in range(image_nums + video_nums):
            # Locate the next image and next video placeholder at or after st; whichever comes first is handled.
            if image_token_id in input_tokens and remain_images > 0:
                ed_image = input_tokens.index(image_token_id, st)
            else:
                ed_image = len(input_tokens) + 1
            if video_token_id in input_tokens and remain_videos > 0:
                ed_video = input_tokens.index(video_token_id, st)
            else:
                ed_video = len(input_tokens) + 1
            if ed_image < ed_video:
                t, h, w = image_grid_thw_list[image_index]
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = video_grid_thw_list[video_index]
                video_index += 1
                remain_videos -= 1
                ed = ed_video
            # Placeholder grid after spatial merging.
            llm_grid_t, llm_grid_h, llm_grid_w = (
                t,
                h // spatial_merge_size,
                w // spatial_merge_size,
            )
            # Text between the previous run and this placeholder: same id on all three axes.
            text_len = ed - st
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
            # Vision tokens: (t, h, w) grid coordinates, offset past the preceding text.
            t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
            llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w
        if st < len(input_tokens):
            # Trailing text after the last vision run.
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
        # Pieces are built on the default device and moved to position_ids' device on write.
        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
        mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
    mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
    return position_ids, mrope_position_deltas
@can_return_tuple
@auto_docstring
def get_video_features(
    self,
    pixel_values_videos: torch.FloatTensor,
    video_grid_thw: torch.LongTensor | None = None,
    **kwargs: Unpack[TransformersKwargs],
) -> tuple | BaseModelOutputWithPooling:
    r"""
    pixel_values_videos (`torch.FloatTensor`):
        Pixel values of the input videos. The vision tower's patch embedding views them as patches of
        `in_channels x temporal_patch_size x patch_size x patch_size` values.
    video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
        The temporal, height and width of feature shape of each video in LLM.
    """
    # Videos take exactly the same vision-tower path as images.
    video_inputs, video_grid = pixel_values_videos, video_grid_thw
    return self.get_image_features(video_inputs, video_grid, **kwargs)
@can_return_tuple
@auto_docstring
def get_image_features(
    self,
    pixel_values: torch.FloatTensor,
    image_grid_thw: torch.LongTensor | None = None,
    **kwargs: Unpack[TransformersKwargs],
) -> tuple | BaseModelOutputWithPooling:
    r"""
    pixel_values (`torch.FloatTensor`):
        Pixel values of the input images. The vision tower's patch embedding views them as patches of
        `in_channels x temporal_patch_size x patch_size x patch_size` values.
    image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
        The temporal, height and width of feature shape of each image in LLM.
    """
    visual_inputs = pixel_values.type(self.visual.dtype)
    vision_output: BaseModelOutputWithPooling = self.visual(
        visual_inputs, grid_thw=image_grid_thw, return_dict=True, **kwargs
    )
    # Each item contributes t*h*w / spatial_merge_size**2 merged tokens; split the flat sequence per item.
    tokens_per_item = image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2
    vision_output.pooler_output = torch.split(vision_output.pooler_output, tokens_per_item.tolist())
    return vision_output
def get_placeholder_mask(
self,
input_ids: torch.LongTensor,
inputs_embeds: torch.FloatTensor,
image_features: torch.FloatTensor | None = None,
video_features: torch.FloatTensor | None = None,
):
"""
Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
equal to the length of multimodal features. If the lengths are different, an error is raised.
"""
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_image_mask = special_image_mask.all(-1)
special_video_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
)
special_video_mask = special_video_mask.all(-1)
else:
special_image_mask = input_ids == self.config.image_token_id
special_video_mask = input_ids == self.config.video_token_id
if image_features is not None:
n_image_tokens = special_image_mask.sum()
if n_image_tokens != image_features.shape[0]:
raise ValueError(
f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}"
)
else:
special_image_mask = None
if video_features is not None:
n_video_tokens = special_video_mask.sum()
if n_video_tokens != video_features.shape[0]:
raise ValueError(
f"Video features and video tokens do not match, tokens: {n_video_tokens}, features: {video_features.shape[0]}"
)
else:
special_video_mask = None
return special_image_mask, special_video_mask
def compute_3d_position_ids(
    self,
    input_ids: torch.Tensor | None,
    inputs_embeds: torch.Tensor | None,
    image_grid_thw: torch.Tensor | None = None,
    video_grid_thw: torch.Tensor | None = None,
    attention_mask: torch.Tensor | None = None,
    past_key_values: "Cache | None" = None,
) -> torch.Tensor | None:
    """
    Derive 3-row (mrope) position ids for the language model, caching ``self.rope_deltas``.

    Three cases:
      * ``input_ids`` plus at least one grid is available, and either no deltas are cached yet
        or the cache is empty: positions and deltas come from ``self.get_rope_index`` and the
        deltas are stored on ``self.rope_deltas``.
      * Otherwise, if ``self.rope_deltas`` is set: text-style positions (from the attention
        mask's cumulative sum, or from ``arange`` offset by the cached length) are repeated
        over 3 rows and shifted by the cached deltas.
      * Otherwise: returns None.

    Args:
        past_key_values: cache object; only ``get_seq_length()`` is used here.
    """
    past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
    can_compute_mrope = input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None)
    if can_compute_mrope and (self.rope_deltas is None or past_key_values_length == 0):
        # Full multimodal position computation; refreshes the cached deltas.
        position_ids, rope_deltas = self.get_rope_index(
            input_ids,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            attention_mask=attention_mask,
        )
        self.rope_deltas = rope_deltas
    elif self.rope_deltas is not None:
        # Reuse previously computed deltas on top of plain text positions.
        batch_size, seq_length, _ = inputs_embeds.shape
        if attention_mask is not None:
            # Padding positions get 0; real tokens are numbered 0, 1, 2, ... per row.
            # NOTE(review): the length here follows attention_mask, not seq_length; if the mask
            # also covers cached tokens (past_key_values_length > 0) this would not match
            # inputs_embeds — confirm callers never reach this branch in that state.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids = position_ids.masked_fill(attention_mask == 0, 0)
            position_ids = position_ids.view(1, batch_size, -1).repeat(3, 1, 1).to(inputs_embeds.device)
        else:
            position_ids = torch.arange(past_key_values_length, past_key_values_length + seq_length)
            position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1).to(inputs_embeds.device)
        # Broadcast cached deltas when the batch was expanded (e.g. beam search / num_return_sequences).
        delta = self.rope_deltas.repeat_interleave(batch_size // self.rope_deltas.shape[0], dim=0)
        position_ids = position_ids + delta.to(device=position_ids.device)
    else:
        # Nothing to build 3-D positions from; the caller's language model decides.
        position_ids = None
    return position_ids
@auto_docstring
@can_return_tuple
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: torch.Tensor | None = None,
    position_ids: torch.LongTensor | None = None,
    past_key_values: Cache | None = None,
    inputs_embeds: torch.FloatTensor | None = None,
    pixel_values: torch.Tensor | None = None,
    pixel_values_videos: torch.FloatTensor | None = None,
    image_grid_thw: torch.LongTensor | None = None,
    video_grid_thw: torch.LongTensor | None = None,
    cache_position: torch.LongTensor | None = None,
    **kwargs: Unpack[TransformersKwargs],
) -> tuple | Qwen3_5ModelOutputWithPast:
    r"""
    image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
        The temporal, height and width of feature shape of each image in LLM.
    video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
        The temporal, height and width of feature shape of each video in LLM.
    """
    # Exactly one of input_ids / inputs_embeds must be given.
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
    if inputs_embeds is None:
        inputs_embeds = self.get_input_embeddings()(input_ids)
    if pixel_values is not None:
        # Encode images; get_image_features returns one chunk per image, merged here into a
        # single tensor in inputs_embeds' device/dtype.
        image_outputs: BaseModelOutputWithPooling = self.get_image_features(
            pixel_values, image_grid_thw, return_dict=True
        )
        image_embeds = image_outputs.pooler_output
        image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
        image_mask, _ = self.get_placeholder_mask(
            input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
        )
        # NOTE(review): this writes into inputs_embeds in place; when the caller supplied
        # inputs_embeds, their tensor is modified.
        image_indices_tuple = torch.nonzero(image_mask, as_tuple=True)
        inputs_embeds[image_indices_tuple] = image_embeds
    if pixel_values_videos is not None:
        # Same procedure for videos. When input_ids is None, placeholder detection compares
        # embeddings, so it runs on inputs_embeds after image features were written above.
        video_outputs: BaseModelOutputWithPooling = self.get_video_features(
            pixel_values_videos, video_grid_thw, return_dict=True
        )
        video_embeds = video_outputs.pooler_output
        video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
        _, video_mask = self.get_placeholder_mask(
            input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
        )
        video_indices_tuple = torch.nonzero(video_mask, as_tuple=True)
        inputs_embeds[video_indices_tuple] = video_embeds
    if position_ids is None:
        # Build mrope positions (may also refresh self.rope_deltas).
        position_ids = self.compute_3d_position_ids(
            input_ids=input_ids,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
        )
    # The text decoder only sees embeddings (input_ids=None) with visual features already merged in.
    outputs = self.language_model(
        input_ids=None,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        cache_position=cache_position,
        **kwargs,
    )
    # position_ids is returned so callers (e.g. the MTP head) can reuse the computed positions.
    return Qwen3_5ModelOutputWithPast(
        **outputs,
        rope_deltas=self.rope_deltas,
        position_ids=position_ids,
    )
@auto_docstring
class Qwen3_5ForCausalLM(Qwen3_5PreTrainedModel, GenerationMixin):
    # Text-only Qwen3.5 decoder with a vocabulary projection head.
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    _tp_plan = {"lm_head": "colwise_gather_output"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
    config: Qwen3_5TextConfig
    # Multimodal checkpoints also contain MTP and vision weights, which this text-only
    # model does not have; they are ignored when loading.
    _keys_to_ignore_on_load_unexpected = [r"^mtp.*", r"^model.visual.*"]

    def __init__(self, config):
        super().__init__(config)
        self.vocab_size = config.vocab_size
        self.model = Qwen3_5TextModel(config)
        self.lm_head = nn.Linear(config.hidden_size, self.vocab_size, bias=False)
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Target token ids for the language-modeling loss. Values must lie in `[0, ...,
            config.vocab_size]` or be -100 (see `input_ids` docstring); positions labelled `-100`
            are ignored, so the loss covers only tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Qwen3_5ForCausalLM

        >>> model = Qwen3_5ForCausalLM.from_pretrained("Qwen/Qwen3_5-8B")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3_5-8B")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        model_outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        # An int keeps the last N positions (0 keeps all, since slice(-0, None) == slice(0, None));
        # a tensor is used directly as an index.
        if isinstance(logits_to_keep, int):
            keep = slice(-logits_to_keep, None)
        else:
            keep = logits_to_keep
        logits = self.lm_head(model_outputs.last_hidden_state[:, keep, :])

        loss = (
            None
            if labels is None
            else self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
        )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=model_outputs.past_key_values,
            hidden_states=model_outputs.hidden_states,
            attentions=model_outputs.attentions,
        )
@dataclass
@auto_docstring(
    custom_intro="""
Base class for Qwen3_5 causal language model (or autoregressive) outputs.
"""
)
class Qwen3_5CausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance.

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    rope_deltas (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
        The rope index difference between sequence length and multimodal rope.
    mtp_loss (`torch.FloatTensor` or `list[torch.FloatTensor]`, *optional*):
        Loss produced by the multi-token-prediction (MTP) block; `None` when MTP is disabled.
    """

    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    past_key_values: Cache | None = None
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
    # Cached mrope offsets from the underlying Qwen3_5Model (shape (batch_size, 1)).
    rope_deltas: torch.LongTensor | None = None
    # Auxiliary MTP loss, reported separately from `loss`.
    mtp_loss: torch.FloatTensor | list[torch.FloatTensor] | None = None
@model_register.register("qwen3_5")
class Qwen3_5ForConditionalGeneration(Qwen3_5PreTrainedModel, GenerationMixin):
    """
    Qwen3.5 multimodal model (``Qwen3_5Model``) with a language-modeling head.

    ``lm_head`` projects text hidden states to vocabulary logits. When
    ``config.text_config.mtp_num_layers`` is non-zero, a ``MultiTokenPredictionBlock`` is also
    built and its loss is reported separately as ``mtp_loss``.
    """

    _checkpoint_conversion_mapping = {}
    _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
    accepts_loss_kwargs = False
    config: Qwen3_5Config

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen3_5Model(config)
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
        # `mtp_num_layers` is put on the text config by `overwrite_transformer_config`; a config
        # that never went through it is treated as "MTP disabled" instead of raising AttributeError.
        self.enable_mtp = bool(getattr(config.text_config, "mtp_num_layers", 0))
        self.mtp = MultiTokenPredictionBlock(
            config.text_config, Qwen3_5DecoderLayer, Qwen3_5RMSNorm
        ) if self.enable_mtp else None
        self.post_init()

    @staticmethod
    def overwrite_transformer_config(transformer_config, model_args, feature_args):
        """
        Copy implementation and feature switches from launcher arguments onto the text config.

        Reads ``gdn_implementation``, ``causal_conv1d_implementation`` and ``mtp_num_layers``
        from ``model_args``, and ``enable_chunk_loss`` / ``enable_dynamic_chunk_loss`` from
        ``feature_args``, each falling back to a default when absent. Validates the first
        three and stores all five on ``transformer_config.text_config``.

        Returns:
            The same ``transformer_config`` object, modified in place.

        Raises:
            ValueError: if a validated option has an unsupported value.
        """
        gdn_implementation = getattr(model_args, "gdn_implementation", "eager")
        if gdn_implementation not in ("eager", "triton", "AscendC"):
            raise ValueError(f"Invalid gdn_implementation='{gdn_implementation}'. Must be one of: 'eager', 'triton', 'AscendC'.")
        transformer_config.text_config.gdn_implementation = gdn_implementation

        causal_conv1d_implementation = getattr(model_args, "causal_conv1d_implementation", "eager")
        if causal_conv1d_implementation not in ("eager", "triton"):
            raise ValueError(
                f"Invalid causal_conv1d_implementation='{causal_conv1d_implementation}'. "
                "Must be one of: 'eager', 'triton'."
            )
        transformer_config.text_config.causal_conv1d_implementation = causal_conv1d_implementation

        mtp_num_layers = getattr(model_args, "mtp_num_layers", 0)
        if mtp_num_layers not in (0, 1):
            raise ValueError(f"Invalid mtp_num_layers='{mtp_num_layers}'. Must be one of: 0, 1.")
        transformer_config.text_config.mtp_num_layers = mtp_num_layers

        transformer_config.text_config.enable_chunk_loss = getattr(feature_args, "enable_chunk_loss", False)
        transformer_config.text_config.enable_dynamic_chunk_loss = getattr(feature_args, "enable_dynamic_chunk_loss", False)
        return transformer_config

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    @auto_docstring
    def get_video_features(
        self,
        pixel_values_videos: torch.FloatTensor,
        video_grid_thw: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
            The tensors corresponding to the input videos.
        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
            The temporal, height and width of feature shape of each video in LLM.
        """
        return self.model.get_video_features(
            pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs
        )

    @auto_docstring
    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_grid_thw: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
            The tensors corresponding to the input images.
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
        """
        return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs)

    def _compute_mtp_loss(self, hidden_states: torch.Tensor, **kwargs):
        """
        Return the MTP block's loss for ``hidden_states``, or None when MTP is disabled.

        The MTP block is given the main model's text embedding, rotary embedding, ``lm_head``
        and ``loss_function``, plus ``seq_len=get_seq_len("total")``; remaining keyword
        arguments are forwarded unchanged.
        """
        if not self.enable_mtp:
            return None
        mtp_loss = self.mtp(
            hidden_states,
            embed_tokens=self.model.language_model.embed_tokens,
            rotary_emb=self.model.language_model.rotary_emb,
            output_layer=self.lm_head,
            loss_function=self.loss_function,
            seq_len=get_seq_len("total"),
            **kwargs,
        )
        return mtp_loss

    @can_return_tuple
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        pixel_values: torch.Tensor | None = None,
        pixel_values_videos: torch.FloatTensor | None = None,
        image_grid_thw: torch.LongTensor | None = None,
        video_grid_thw: torch.LongTensor | None = None,
        cache_position: torch.LongTensor | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | Qwen3_5CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
            The temporal, height and width of feature shape of each video in LLM.

        Example:

        ```python
        >>> from transformers import AutoProcessor, Qwen3_5ForConditionalGeneration

        >>> model = Qwen3_5ForConditionalGeneration.from_pretrained("Qwen/Qwen3-VL-8B-Instruct")
        >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-8B-Instruct")

        >>> messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": "./pipeline-cat-chonk.jpeg",
                    },
                    {"type": "text", "text": "Describe the image."},
                ],
            }
        ]

        >>> inputs = processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt"
        )

        >>> # Generate
        >>> generated_ids = model.generate(**inputs, max_new_tokens=1024)
        >>> generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
        >>> output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        >>> print(output_text)
        ```
        """
        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]

        # An int keeps the last N positions (0 keeps all); a tensor is used directly as an index.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        # NOTE(review): the chunk-loss flags are read from the model instance here, while
        # overwrite_transformer_config stores them on config.text_config; presumably the
        # chunk-loss feature sets these attributes and swaps lm_head for a module that accepts
        # (hidden_states, loss_function) — confirm.
        if getattr(self, "enable_chunk_loss", False) or getattr(self, "enable_dynamic_chunk_loss", False):
            logits = None
            loss = self.lm_head(hidden_states[:, slice_indices, :], self.loss_function)
        else:
            logits = self.lm_head(hidden_states[:, slice_indices, :])
            loss = None
            if labels is not None:
                loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)

        # Under context parallelism each rank holds the loss for its own shard; gather the
        # per-rank values across the CP group and sum them.
        ps = get_parallel_state()
        if loss is not None and ps.is_cp_enable():
            loss = gather_forward_split_backward(loss.unsqueeze(0), ps.get_cp_group(), dim=0)
            loss = loss.sum()

        # The MTP head reuses the position ids computed by the base model.
        final_mtp_loss = self._compute_mtp_loss(
            hidden_states,
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask,
            position_ids=outputs.position_ids,
            past_key_values=outputs.past_key_values,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )

        return Qwen3_5CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            mtp_loss=final_mtp_loss,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=outputs.rope_deltas,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
        is_first_iteration=False,
        **kwargs,
    ):
        """
        Build per-step generation inputs, dropping pixel inputs after the first cached step.

        Delegates to ``GenerationMixin.prepare_inputs_for_generation``; then, when caching is on
        and this is not the first iteration, sets ``pixel_values`` and ``pixel_values_videos``
        to None (presumably because their features already live in the KV cache).
        """
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            use_cache=use_cache,
            is_first_iteration=is_first_iteration,
            **kwargs,
        )
        if not is_first_iteration and use_cache:
            model_inputs["pixel_values"] = None
            model_inputs["pixel_values_videos"] = None
        return model_inputs

    def _prepare_position_ids_for_generation(self, inputs_tensor, model_kwargs):
        """
        Build the initial position ids for generation.

        Normally returns 4 rows stacked on dim 0: the plain text positions followed by the
        3 mrope rows (from ``get_rope_index`` when grids are present, otherwise the text
        positions repeated 3 times). ``self.model.rope_deltas`` is refreshed as a side effect.
        When continuing from a non-empty cache with cached deltas, the text positions shifted
        by those deltas are returned instead.
        """
        text_positions = super()._prepare_position_ids_for_generation(inputs_tensor, model_kwargs)

        past_length = 0
        if (cache := model_kwargs.get("past_key_values")) is not None:
            past_length = cache.get_seq_length()
        if past_length != 0 and self.model.rope_deltas is not None:
            # NOTE(review): this early return keeps the rank of `text_positions`, unlike the
            # 4-row result below — confirm downstream code accepts both.
            text_positions += self.model.rope_deltas
            return text_positions

        if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0:
            inputs_tensor = model_kwargs["input_ids"]

        is_input_ids = len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in [torch.int, torch.long]
        if is_input_ids and (
            model_kwargs.get("image_grid_thw") is not None or model_kwargs.get("video_grid_thw") is not None
        ):
            model_kwargs = {k: v for k, v in model_kwargs.items() if k != "input_ids"}
            vision_positions, rope_deltas = self.model.get_rope_index(inputs_tensor, **model_kwargs)
            self.model.rope_deltas = rope_deltas
        else:
            # Text-only prompt: all three mrope rows equal the text positions, with zero deltas.
            vision_positions = text_positions.unsqueeze(0).expand(3, -1, -1)
            self.model.rope_deltas = torch.zeros(
                inputs_tensor.shape[0], 1, dtype=torch.long, device=inputs_tensor.device
            )

        text_positions = text_positions[None, ...]
        position_ids = torch.cat([text_positions, vision_positions], dim=0)
        return position_ids

    def _get_image_nums_and_video_nums(
        self,
        input_ids: torch.LongTensor | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
        These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.

        An image (video) is counted for every image (video) token that directly follows a
        vision-start token in the same row.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.
            inputs_embeds (`torch.Tensor`, *optional*):
                Used instead of `input_ids` when given; a position matches a special token only
                if its embedding equals that token's embedding in every hidden component.

        Returns:
            image_nums (`torch.LongTensor` of shape `(batch_size,)`)
            video_nums (`torch.LongTensor` of shape `(batch_size,)`)
        """
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id

        if inputs_embeds is not None:
            embed_tokens = self.get_input_embeddings()

            def _matches(token_id):
                # Compare the full hidden vector (as get_placeholder_mask does); checking only
                # the first component can also match unrelated tokens that share that value.
                target = embed_tokens(torch.tensor(token_id, dtype=torch.long, device=inputs_embeds.device))
                return (inputs_embeds == target).all(-1)

            vision_start_mask = _matches(vision_start_token_id)
            image_mask = _matches(image_token_id)
            video_mask = _matches(video_token_id)
        else:
            vision_start_mask = input_ids == vision_start_token_id
            image_mask = input_ids == image_token_id
            video_mask = input_ids == video_token_id

        # Shift right by one so each position sees whether its predecessor is a vision start.
        vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
        image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
        video_nums = torch.sum(vision_first_mask & video_mask, dim=1)

        return image_nums, video_nums

    def _expand_inputs_for_generation(
        self,
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
        input_ids: torch.LongTensor | None = None,
        **model_kwargs,
    ) -> tuple[torch.LongTensor, dict[str, Any]]:
        """
        Repeat every sample ``expand_size`` times (e.g. for beam search / multiple returns).

        Visual tensors are packed across the batch without a batch dim, so they are split
        per sample (using image/video counts and grid sizes) and each sample's chunk is
        repeated; other tensors are repeated along dim 0 (dim 1 for 3-D ``position_ids``).
        """
        if expand_size == 1:
            return input_ids, model_kwargs

        visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]

        def _expand_dict_for_generation_visual(dict_to_expand):
            image_grid_thw = model_kwargs.get("image_grid_thw", None)
            video_grid_thw = model_kwargs.get("video_grid_thw", None)
            image_nums, video_nums = self._get_image_nums_and_video_nums(
                input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
            )

            if video_grid_thw is not None:
                # The token-level count appears to be per frame (cumulated against the temporal
                # grid sizes); map it back to the number of video_grid_thw rows per sample.
                cumulative_frame_counts = torch.cumsum(video_grid_thw[:, 0], dim=0)
                cumulative_token_video_counts = torch.cumsum(video_nums, dim=0)
                video_boundary_indices = torch.searchsorted(cumulative_frame_counts, cumulative_token_video_counts)
                video_nums = torch.diff(torch.cat([-video_boundary_indices.new_ones(1), video_boundary_indices]))

            def _repeat_interleave_samples(x, lengths, repeat_times):
                # Split x into per-sample chunks and repeat each chunk `repeat_times` along dim 0.
                samples = torch.split(x, lengths)
                repeat_args = [repeat_times] + [1] * (x.dim() - 1)
                result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
                return result

            for key in dict_to_expand:
                if key == "pixel_values":
                    samples = torch.split(image_grid_thw, list(image_nums))
                    lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
                    dict_to_expand[key] = _repeat_interleave_samples(
                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
                    )
                elif key == "image_grid_thw":
                    lengths = list(image_nums)
                    dict_to_expand[key] = _repeat_interleave_samples(
                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
                    )
                elif key == "pixel_values_videos":
                    samples = torch.split(video_grid_thw, list(video_nums))
                    lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
                    dict_to_expand[key] = _repeat_interleave_samples(
                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
                    )
                elif key == "video_grid_thw":
                    lengths = list(video_nums)
                    dict_to_expand[key] = _repeat_interleave_samples(
                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
                    )
            return dict_to_expand

        def _expand_dict_for_generation(dict_to_expand):
            for key in dict_to_expand:
                if key == "position_ids" and dict_to_expand[key].ndim == 3:
                    # 3-D position ids carry the batch on dim 1.
                    dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=1)
                elif (
                    key != "cache_position"
                    and dict_to_expand[key] is not None
                    and isinstance(dict_to_expand[key], torch.Tensor)
                    and key not in visual_keys
                ):
                    dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
            return dict_to_expand

        model_kwargs = _expand_dict_for_generation_visual(model_kwargs)

        if input_ids is not None:
            input_ids = input_ids.repeat_interleave(expand_size, dim=0)

        model_kwargs = _expand_dict_for_generation(model_kwargs)

        if is_encoder_decoder:
            if model_kwargs.get("encoder_outputs") is None:
                raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
            model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])

        return input_ids, model_kwargs
# Public model classes exported by this module (used by `from ... import *`).
__all__ = [
    "Qwen3_5VisionModel",
    "Qwen3_5TextModel",
    "Qwen3_5Model",
    "Qwen3_5ForCausalLM",
    "Qwen3_5ForConditionalGeneration",
    "Qwen3_5PreTrainedModel",
]