from collections.abc import Callable
from typing import Any, Optional, Union
from contextlib import nullcontext
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.masking_utils import create_causal_mask
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, is_torchdynamo_compiling
from transformers.utils.generic import check_model_inputs
from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig, Qwen3VLTextConfig, Qwen3VLVisionConfig
from megatron.core import mpu
from megatron.training import get_args
from mindspeed.core.context_parallel.model_parallel_utils import (
get_context_parallel_group_for_hybrid_ulysses,
get_context_parallel_group_for_hybrid_ring,
get_context_parallel_for_hybrid_ring_world_size,
get_context_parallel_for_hybrid_ulysses_world_size,
get_context_parallel_for_hybrid_ring_rank
)
from mindspeed.utils import set_actual_seq_len, get_actual_seq_len
from mindspeed_mm.models.common.communications import (
cal_split_sizes,
gather_forward_split_backward,
cal_split_sizes_multi,
split_forward_gather_backward_with_cp
)
from mindspeed_mm.utils.async_offload import async_save_on_cpu
from mindspeed_mm.utils.data_balance.data_balance import MBSImageDataBalance
from mindspeed_mm.utils.utils import gather_forward_split_backward_with_megatron_cp, get_packed_seq_len
from ..cp_utils import get_seq_len, set_seq_len, split_visual_seqs_with_cp
from .output import (
Qwen3VLCausalLMOutputWithPast,
Qwen3VLModelOutputWithPast
)
from .modules import (
Qwen3VLTextAttention,
Qwen3VLTextMLP,
Qwen3VLTextRMSNorm,
Qwen3VLVisionPatchEmbed,
Qwen3VLVisionRotaryEmbedding,
Qwen3VLVisionBlock,
Qwen3VLVisionPatchMerger,
Qwen3VLTextRotaryEmbedding,
Qwen3VLLMHead,
Qwen3VLEmptyModule,
)
class Qwen3VLTextDecoderLayer(nn.Module):
    """One decoder layer of the Qwen3-VL text model.

    The layer normalises its input, runs self-attention and adds the input
    back; it then normalises that sum, runs the MLP and adds the sum back.
    """

    def __init__(self, config: Qwen3VLTextConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size

        # Creation order determines the order in which sub-modules are registered.
        self.self_attn = Qwen3VLTextAttention(config=config, layer_idx=layer_idx)
        self.mlp = Qwen3VLTextMLP(config)

        norm_eps = config.rms_norm_eps
        self.input_layernorm = Qwen3VLTextRMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = Qwen3VLTextRMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Run attention and MLP sub-blocks, each with a residual connection.

        Args:
            hidden_states: Layer input.
            position_embeddings: Pair of tensors handed to the attention module.
            attention_mask, position_ids, past_key_values, use_cache,
            cache_position, **kwargs: Passed through to the attention module.

        Returns:
            The layer output (same variable flow as the input hidden states).
        """
        if self.config.synchronize_per_layer:
            # Optional barrier on the current NPU stream before the layer runs.
            torch.npu.current_stream().synchronize()

        # Attention sub-block: norm -> attention -> residual add.
        attn_output = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_output

        # Feed-forward sub-block: norm -> MLP -> residual add.
        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_output
@auto_docstring
class Qwen3VLPreTrainedModel(PreTrainedModel):
    # Common base class of the Qwen3-VL models defined in this file. It adds no
    # behaviour of its own; it only sets class-level attributes that the
    # `PreTrainedModel` machinery reads (their exact semantics are defined by
    # transformers, not here).
    config: Qwen3VLConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Module class names listed here: the text decoder layer and the vision block.
    _no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"]
    _skip_keys_device_placement = "past_key_values"
    # Capability flags (attention backends, compilation).
    _supports_flash_attn = True
    _supports_sdpa = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # Maps an output name to the module class associated with it:
    # hidden states <-> decoder layer, attentions <-> text attention module.
    _can_record_outputs = {
        "hidden_states": Qwen3VLTextDecoderLayer,
        "attentions": Qwen3VLTextAttention,
    }
class Qwen3VLVisionModel(Qwen3VLPreTrainedModel):
    """Vision encoder: patch embedding, a stack of vision blocks and patch mergers.

    `forward` returns the output of the final merger together with the
    intermediate ("deepstack") features produced by a dedicated merger after
    each block whose index is listed in `config.deepstack_visual_indexes`.
    """

    config: Qwen3VLVisionConfig
    _no_split_modules = ["Qwen3VLVisionBlock"]

    def __init__(self, config, *inputs, **kwargs) -> None:
        """Create patch/position embeddings, vision blocks, mergers and the optional data balancer."""
        super().__init__(config, *inputs, **kwargs)
        self.spatial_merge_size = config.spatial_merge_size
        self.patch_size = config.patch_size
        self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size

        self.patch_embed = Qwen3VLVisionPatchEmbed(
            config=config,
        )

        self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size)
        # The learned position table is addressed as a square, row-major grid
        # (see `fast_pos_embed_interpolate`: index = row * num_grid_per_side + col).
        self.num_grid_per_side = int(config.num_position_embeddings**0.5)

        head_dim = config.hidden_size // config.num_heads
        self.rotary_pos_emb = Qwen3VLVisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList([Qwen3VLVisionBlock(config) for _ in range(config.depth)])
        self.merger = Qwen3VLVisionPatchMerger(
            config=config,
            use_postshuffle_norm=False,
        )

        # One extra merger per deepstack tap point.
        self.deepstack_visual_indexes = config.deepstack_visual_indexes
        self.deepstack_merger_list = nn.ModuleList(
            [
                Qwen3VLVisionPatchMerger(
                    config=config,
                    use_postshuffle_norm=True,
                )
                for _ in range(len(config.deepstack_visual_indexes))
            ]
        )

        if config.use_image_mbs_data_balance:
            if torch.distributed.get_rank() == 0:
                print("[INFO] initialize image mbs data balance")
            self.data_balance = MBSImageDataBalance(
                sorting_algo_name=config.mbs_data_balance_sorting_algo,
                spatial_merge_size=config.spatial_merge_size
            )
        else:
            self.data_balance = None

        self.gradient_checkpointing = False

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Look up rotary frequencies for the (row, col) coordinate of every patch.

        Patches are enumerated merge-block by merge-block: for each
        `merge_size x merge_size` block (row-major over blocks) the patches
        inside the block are listed row-major. The coordinate list of one
        frame is repeated for every frame of the item.

        Args:
            grid_thw: `(num_items, 3)` tensor of (frames, height, width) per item.

        Returns:
            Tensor with one row per patch; the row and column frequency
            vectors are flattened into a single last dimension.
        """
        merge_size = self.spatial_merge_size

        # The frequency table only has to cover the largest height/width seen.
        max_hw = int(grid_thw[:, 1:].max().item())
        freq_table = self.rotary_pos_emb(max_hw)
        device = freq_table.device

        total_tokens = int(torch.prod(grid_thw, dim=1).sum().item())
        pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)

        offset = 0
        for num_frames, height, width in grid_thw:
            merged_h, merged_w = height // merge_size, width // merge_size

            block_rows = torch.arange(merged_h, device=device)
            block_cols = torch.arange(merged_w, device=device)
            intra_row = torch.arange(merge_size, device=device)
            intra_col = torch.arange(merge_size, device=device)

            # Absolute coordinate = block index * merge_size + offset inside the block.
            row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
            col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]

            row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
            col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)

            coords = torch.stack((row_idx, col_idx), dim=-1)

            if num_frames > 1:
                coords = coords.repeat(num_frames, 1)

            num_tokens = coords.shape[0]
            pos_ids[offset: offset + num_tokens] = coords
            offset += num_tokens

        embeddings = freq_table[pos_ids]
        embeddings = embeddings.flatten(1)
        return embeddings

    def fast_pos_embed_interpolate(self, grid_thw):
        """Bilinearly interpolate the learned position grid to each item's patch grid.

        For every item the `num_grid_per_side x num_grid_per_side` position
        table is sampled at `h x w` evenly spaced points; each sample is the
        weighted sum of its four neighbouring table entries. The result is
        repeated per frame and reordered into merge-block order. When context
        parallelism is active the indices/weights are split across ranks
        before the embedding lookup.

        Args:
            grid_thw: `(num_items, 3)` tensor of (frames, height, width) per item.

        Returns:
            Position embeddings, one row per (local) patch.
        """
        grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]

        # Four parallel lists: one per bilinear corner (floor/floor, floor/ceil, ceil/floor, ceil/ceil).
        idx_list = [[] for _ in range(4)]
        weight_list = [[] for _ in range(4)]

        for _, h, w in zip(grid_ts, grid_hs, grid_ws):
            h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
            w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)

            h_idxs_floor = h_idxs.int()
            w_idxs_floor = w_idxs.int()
            h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
            w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)

            # Fractional distance from the floor neighbour.
            dh = h_idxs - h_idxs_floor
            dw = w_idxs - w_idxs_floor

            base_h = h_idxs_floor * self.num_grid_per_side
            base_h_ceil = h_idxs_ceil * self.num_grid_per_side

            indices = [
                (base_h[None].T + w_idxs_floor[None]).flatten(),
                (base_h[None].T + w_idxs_ceil[None]).flatten(),
                (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
                (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
            ]

            weights = [
                ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
                ((1 - dh)[None].T * dw[None]).flatten(),
                (dh[None].T * (1 - dw)[None]).flatten(),
                (dh[None].T * dw[None]).flatten(),
            ]

            for i in range(4):
                idx_list[i].extend(indices[i].tolist())
                weight_list[i].extend(weights[i].tolist())

        idx_tensors = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device)
        weight_tensors = torch.tensor(
            weight_list, dtype=self.pos_embed.weight.dtype, device=self.pos_embed.weight.device
        )

        # Number of patches of a single frame of each item; used to cut the
        # concatenated index/weight tensors back into per-item pieces.
        patch_counts = [h * w for h, w in zip(grid_hs, grid_ws)]
        patch_idx_tensors = idx_tensors.split(patch_counts, dim=1)
        patch_weight_tensors = weight_tensors.split(patch_counts, dim=1)

        patch_idx_tensors_permute = []
        patch_weight_tensors_permute = []
        merge_size = self.config.spatial_merge_size
        for idx_tensor, weight_tensor, t, h, w in zip(patch_idx_tensors, patch_weight_tensors, grid_ts, grid_hs, grid_ws):
            # Same spatial positions for every frame of the item.
            idx_tensor = idx_tensor.repeat(1, t)
            weight_tensor = weight_tensor.repeat(1, t)
            # Reorder from row-major patches to merge-block order.
            idx_tensor = (
                idx_tensor.view(4, t, h // merge_size, merge_size, w // merge_size, merge_size)
                .permute(0, 1, 2, 4, 3, 5)
                .flatten(1, 5)
            )
            weight_tensor = (
                weight_tensor.view(4, t, h // merge_size, merge_size, w // merge_size, merge_size)
                .permute(0, 1, 2, 4, 3, 5)
                .flatten(1, 5)
            )
            patch_idx_tensors_permute.append(idx_tensor)
            patch_weight_tensors_permute.append(weight_tensor)

        patch_idx_tensors_permute = torch.cat(patch_idx_tensors_permute, dim=1)
        patch_weight_tensors_permute = torch.cat(patch_weight_tensors_permute, dim=1)

        if mpu.get_context_parallel_world_size() > 1:
            patch_idx_tensors_permute = split_visual_seqs_with_cp(patch_idx_tensors_permute, dim=1)
            patch_weight_tensors_permute = split_visual_seqs_with_cp(patch_weight_tensors_permute, dim=1)

        pos_embeds = self.pos_embed(patch_idx_tensors_permute) * patch_weight_tensors_permute[:, :, None]
        patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]

        return patch_pos_embeds

    def forward(
        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """Encode visual patches into merged features plus deepstack features.

        Args:
            hidden_states (`torch.Tensor`):
                Visual input handed to the patch embedding (it is passed to the optional data balancer
                under the key `pixel_values`). The patch-embedding output must be 2-D `(seq_len, dim)`.
            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
                The temporal, height and width of feature shape of each image in LLM.
            **kwargs:
                Forwarded unchanged to every vision block.

        Returns:
            `tuple[torch.Tensor, list[torch.Tensor]]`: the output of the final merger, and one merged
            feature tensor per block index found in `deepstack_visual_indexes` (in block order).
            With context parallelism only the first element is gathered across ranks here; the
            deepstack features are returned as produced locally.
        """
        if self.data_balance is not None:
            hidden_states, grid_thw = self.data_balance.get_image_balance_data(
                {'pixel_values': hidden_states, 'image_grid_thw': grid_thw}
            )
        hidden_states = self.patch_embed(hidden_states)

        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        seq_len, _ = hidden_states.size()

        # One entry per frame: each frame of an item contributes h * w patches.
        patches_per_frame = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
        sequence_lengths = patches_per_frame.cpu()
        set_seq_len("per_visual", sequence_lengths)
        set_seq_len("visual", seq_len)

        hidden_states = hidden_states.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)

        cu_seqlens = patches_per_frame.cumsum(
            dim=0,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        # Pad with a leading 0, then drop it again unless that would leave an
        # empty tensor (i.e. there are no frames at all).
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
        cu_seqlens = cu_seqlens[1:] if len(cu_seqlens) > 1 else cu_seqlens
        cu_seqlens = tuple(cu_seqlens.cpu().numpy().tolist())

        cp_size = mpu.get_context_parallel_world_size()
        if cp_size > 1:
            # Each rank keeps only its slice of the visual sequence.
            rotary_pos_emb = split_visual_seqs_with_cp(rotary_pos_emb, dim=0)
            hidden_states = split_visual_seqs_with_cp(hidden_states, dim=0)
            # For the ring-style algorithms the cumulative lengths are
            # recomputed from this rank's share of every frame.
            cp_algo = get_args().context_parallel_algo
            if cp_algo == "megatron_cp_algo":
                all_split_sizes_tensor = cal_split_sizes_multi(sequence_lengths, cp_size)
                cu_seqlens = all_split_sizes_tensor.cumsum(dim=1)[mpu.get_context_parallel_rank()]
            elif cp_algo == "hybrid_cp_algo":
                all_split_sizes_tensor = cal_split_sizes_multi(sequence_lengths, get_context_parallel_for_hybrid_ring_world_size())
                cu_seqlens = all_split_sizes_tensor.cumsum(dim=1)[get_context_parallel_for_hybrid_ring_rank()]

        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        hidden_states = hidden_states + self.fast_pos_embed_interpolate(grid_thw)

        deepstack_feature_lists = []
        for layer_num, blk in enumerate(self.blocks):
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens,
                position_embeddings=position_embeddings,
                **kwargs,
            )
            if layer_num in self.deepstack_visual_indexes:
                # The n-th tap point uses the n-th deepstack merger.
                deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)](
                    hidden_states
                )
                deepstack_feature_lists.append(deepstack_feature)

        hidden_states = self.merger(hidden_states)

        if self.data_balance is not None:
            hidden_states, deepstack_feature_lists = self.data_balance.reverse_img_balance_data(
                hidden_states, deepstack_feature_lists
            )

        # After merging, spatial_merge_size ** 2 patches have become one token.
        set_seq_len("visual", seq_len // self.spatial_merge_size ** 2)
        if cp_size > 1:
            gather_sizes = cal_split_sizes(get_seq_len("visual"), cp_size)
            hidden_states = gather_forward_split_backward(
                hidden_states,
                mpu.get_context_parallel_group(),
                dim=0,
                grad_scale="up",
                gather_sizes=gather_sizes
            )

        return hidden_states, deepstack_feature_lists
@auto_docstring(
    custom_intro=(
        "Text part of Qwen3VL, "
        "not a pure text-only model, as DeepStack integrates visual features into the early hidden states."
    )
)
class Qwen3VLTextModel(Qwen3VLPreTrainedModel):
    # Decoder stack of Qwen3-VL. Besides the usual embedding -> layers -> norm
    # pipeline it (a) adds "deepstack" visual features to the hidden states
    # after the first decoder layers and (b) splits / gathers the sequence
    # across context-parallel ranks when context parallelism is enabled.
    config: Qwen3VLTextConfig
    _no_split_modules = ["Qwen3VLTextDecoderLayer"]

    def __init__(self, config: Qwen3VLTextConfig):
        # Builds the token embedding, decoder layers, final norm and rotary
        # embedding; allocates an extra NPU stream when activation offload is on.
        super().__init__(config)
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.norm_hook_module = Qwen3VLEmptyModule()
        self.layers = nn.ModuleList(
            [Qwen3VLTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen3VLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3VLTextRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        if config.activation_offload:
            # Stream used for both host->device and device->host copies of offloaded activations.
            self.swap_stream = torch.npu.Stream()
        self.post_init()

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        visual_pos_masks: Optional[torch.Tensor] = None,
        deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[tuple, BaseModelOutputWithPast]:
        r"""
        visual_pos_masks (`torch.Tensor` of shape `(batch_size, seqlen)`, *optional*):
        The mask of the visual positions.
        deepstack_visual_embeds (`list[torch.Tensor]`, *optional*):
        The deepstack visual embeddings. The shape is (num_layers, visual_seqlen, embed_dim).
        The feature is extracted from the different visual encoder layers, and fed to the decoder
        hidden states.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if use_cache and past_key_values is None and not torch.jit.is_tracing():
            past_key_values = DynamicCache(config=self.config)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        # Position ids are normalised to a leading axis of 3; a 2-D input is
        # broadcast to all three.
        if position_ids is None:
            position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
        elif position_ids.ndim == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)

        # Row 0 always serves as the plain text positions. When four rows are
        # supplied, row 0 is text-only and rows 1..3 feed the rotary embedding.
        text_position_ids = position_ids[0]
        if position_ids.ndim == 3 and position_ids.shape[0] == 4:
            position_ids = position_ids[1:]

        total_seq_len = inputs_embeds.shape[1]
        set_seq_len("total", total_seq_len)

        if self.config.attn_layout == "TND":
            # Packed layout: sequence lengths and token indices come from the
            # caller when provided, otherwise from the attention mask.
            if attention_mask is None and not ("seqlens" in kwargs and "indices" in kwargs):
                raise ValueError(
                    "attention_mask is required for the TND attention layout unless both "
                    "`seqlens` and `indices` are passed as keyword arguments"
                )
            if "seqlens" in kwargs:
                seqlens_in_batch = kwargs["seqlens"]
            else:
                seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)

            # Pad with a leading 0, then drop it again unless that would leave
            # an empty tensor: only the end offsets are kept.
            cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
            cu_seqlens = cu_seqlens[1:] if len(cu_seqlens) > 1 else cu_seqlens
            set_actual_seq_len(actual_seq_len=cu_seqlens)
            set_seq_len("total", total_seq_len)
            kwargs["cu_seqlens"] = tuple(cu_seqlens.cpu().numpy().tolist())

            if "indices" not in kwargs:
                # Flat indices of the non-padding tokens.
                kwargs["indices"] = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
        elif not self.config.is_causal:
            attention_mask = create_causal_mask(
                config=self.config,
                input_embeds=inputs_embeds,
                attention_mask=attention_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
                position_ids=text_position_ids,
            )
        else:
            # With `config.is_causal` set, no mask tensor is handed to the layers.
            attention_mask = None

        if mpu.get_context_parallel_world_size() > 1:
            # Every rank continues with its own slice along the sequence axis.
            position_ids = split_forward_gather_backward_with_cp(position_ids, dim=2)
            text_position_ids = split_forward_gather_backward_with_cp(text_position_ids, dim=1)
            inputs_embeds = split_forward_gather_backward_with_cp(inputs_embeds, dim=1)

        hidden_states = inputs_embeds

        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        self.norm_hook_module(hidden_states)

        num_layers = len(self.layers)
        for layer_idx, decoder_layer in enumerate(self.layers):
            if self.config.activation_offload:
                # The check function reads `hidden_states` lazily; it still
                # refers to this layer's input while the block below runs
                # because the variable is only rebound after the `with` exits.
                layer_context = async_save_on_cpu(
                    h2d_stream=self.swap_stream,
                    d2h_stream=self.swap_stream,
                    block_idx=layer_idx,
                    depth=num_layers,
                    custom_check_fn=lambda x: x.data_ptr() == hidden_states.data_ptr(),
                    prefetch=True,
                )
            else:
                layer_context = nullcontext()

            with layer_context:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=text_position_ids,
                    past_key_values=past_key_values,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )
            hidden_states = layer_outputs

            # The i-th deepstack feature is injected after the i-th decoder layer.
            if deepstack_visual_embeds is not None and layer_idx < len(deepstack_visual_embeds):
                hidden_states = self._deepstack_process(
                    hidden_states,
                    visual_pos_masks,
                    deepstack_visual_embeds[layer_idx],
                )

        hidden_states = self.norm(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )

    def _deepstack_process(
        self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, visual_embeds: torch.Tensor
    ):
        """Add `visual_embeds` to the hidden states at the masked (visual) positions.

        With context parallelism both the visual features and the hidden
        states are first gathered to the full sequence, the addition is
        applied, and the hidden states are split across ranks again.

        Args:
            hidden_states: Current decoder hidden states (sequence on dim 1).
            visual_pos_masks: Boolean mask selecting the visual positions.
            visual_embeds: Features to add; one row per selected position.

        Returns:
            Hidden states with the visual features added.

        Raises:
            NotImplementedError: For an unknown `context_parallel_algo`.
        """
        cp_size = mpu.get_context_parallel_world_size()

        if cp_size > 1:
            visual_seq_len = get_seq_len("visual")
            visual_gather_sizes = cal_split_sizes(visual_seq_len, cp_size)
            visual_embeds = gather_forward_split_backward(
                visual_embeds,
                mpu.get_context_parallel_group(),
                dim=0,
                grad_scale="up",
                gather_sizes=visual_gather_sizes
            )

        visual_pos_masks = visual_pos_masks.to(hidden_states.device)
        visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)

        if cp_size > 1:
            megatron_args = get_args()
            if megatron_args.context_parallel_algo == "ulysses_cp_algo":
                gather_sizes = cal_split_sizes(get_seq_len("total"), cp_size)
                hidden_states = gather_forward_split_backward(hidden_states, mpu.get_context_parallel_group(), dim=1, grad_scale="up", gather_sizes=gather_sizes)
            elif megatron_args.context_parallel_algo == "megatron_cp_algo":
                hidden_states = gather_forward_split_backward_with_megatron_cp(hidden_states, mpu.get_context_parallel_group(), dim=1)
            elif megatron_args.context_parallel_algo == "hybrid_cp_algo":
                # Hybrid: undo the ulysses split inside each ring rank first,
                # then the ring split across ring ranks.
                actual_seq_len = get_actual_seq_len()
                if actual_seq_len is not None:
                    total_seq_len = get_packed_seq_len(actual_seq_len, get_context_parallel_for_hybrid_ring_world_size())
                else:
                    total_seq_len = get_seq_len("total")
                seq_len_per_ring = total_seq_len // get_context_parallel_for_hybrid_ring_world_size()
                gather_sizes = cal_split_sizes(seq_len_per_ring, get_context_parallel_for_hybrid_ulysses_world_size())
                hidden_states = gather_forward_split_backward(hidden_states, get_context_parallel_group_for_hybrid_ulysses(), dim=1, grad_scale="up", gather_sizes=gather_sizes)
                hidden_states = gather_forward_split_backward_with_megatron_cp(hidden_states, get_context_parallel_group_for_hybrid_ring(), dim=1)
            else:
                raise NotImplementedError(f"Only support `ulysses_cp_algo`,`megatron_cp_algo`,`hybrid_cp_algo`, but got {megatron_args.context_parallel_algo}")

        # Compute the sum out of place, then write it back into the selected rows.
        local_this = hidden_states[visual_pos_masks, :].clone() + visual_embeds
        hidden_states[visual_pos_masks, :] = local_this

        if cp_size > 1:
            hidden_states = split_forward_gather_backward_with_cp(hidden_states, dim=1)

        return hidden_states
@auto_docstring
class Qwen3VLModel(Qwen3VLPreTrainedModel):
base_model_prefix = ""
_checkpoint_conversion_mapping = {}
accepts_loss_kwargs = False
config: Qwen3VLConfig
_no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"]
def __init__(self, config):
    """Build the vision tower and the text decoder from their sub-configs."""
    super().__init__(config)
    # Vision encoder first, then the language model.
    vision_model = Qwen3VLVisionModel._from_config(config.vision_config)
    self.visual = vision_model
    text_model = Qwen3VLTextModel._from_config(config.text_config)
    self.language_model = text_model
    # No rope deltas are cached at construction time.
    self.rope_deltas = None
    self.post_init()
def get_input_embeddings(self):
    """Return whatever the language model reports as its input embeddings."""
    text_model = self.language_model
    return text_model.get_input_embeddings()
def set_input_embeddings(self, value):
    """Forward `value` to the language model's `set_input_embeddings`."""
    text_model = self.language_model
    text_model.set_input_embeddings(value)
def set_decoder(self, decoder):
    """Replace the language model with `decoder`."""
    # Stored under the same attribute that `get_decoder` reads.
    setattr(self, "language_model", decoder)
def get_decoder(self):
    """Return the language model held by this wrapper."""
    decoder = self.language_model
    return decoder
# Computes 3-axis (temporal, height, width) rotary position ids and the
# per-sequence "rope delta" (max position + 1 - sequence length).
#
# Two input modes are handled when vision grids are supplied:
#   * `sequence_length is None`: `input_ids` is a padded batch; padding is
#     identified through `attention_mask` (all ones when not given).
#   * `sequence_length` given: row 0 of `input_ids` holds several samples
#     packed back to back; it is cut into samples of those lengths.
# Without vision grids (or without `input_ids`) plain text positions are
# derived from `attention_mask`, or from `arange` when there is no mask.
#
# Returns `(position_ids, mrope_position_deltas)`.
def get_rope_index(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    sequence_length: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Different from the original implementation, Qwen3VL use timestamps rather than absolute time position ids."""
    # Treat every video frame as its own entry: a video with t frames becomes
    # t rows, each with a temporal extent of 1.
    if video_grid_thw is not None:
        video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0)
        video_grid_thw[:, 0] = 1
    spatial_merge_size = self.config.vision_config.spatial_merge_size
    image_token_id = self.config.image_token_id
    video_token_id = self.config.video_token_id
    vision_start_token_id = self.config.vision_start_token_id
    mrope_position_deltas = []
    if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
        if sequence_length is None:
            # Padded-batch mode: positions are written into a preallocated
            # (3, batch, seq) tensor.
            total_input_ids = input_ids
            if attention_mask is None:
                attention_mask = torch.ones_like(total_input_ids)
            position_ids = torch.ones(
                3,
                input_ids.shape[0],
                input_ids.shape[1],
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
            attention_mask = attention_mask.to(total_input_ids.device)
        else:
            # Packed mode: one position tensor per sample, concatenated at the end.
            total_input_ids = input_ids[0].split(sequence_length.tolist())
            max_input_ids_len = max(sequence_length)
            position_ids = [None] * len(sequence_length)
        # Running cursors into image_grid_thw / video_grid_thw across all samples.
        image_index, video_index = 0, 0
        # NOTE: the loop variable rebinds the `input_ids` parameter name.
        for i, input_ids in enumerate(total_input_ids):
            if sequence_length is None:
                # Drop padded positions of this row.
                input_ids = input_ids[attention_mask[i] == 1]
            image_nums, video_nums = 0, 0
            # The token right after each vision-start token tells whether the
            # block is an image or a video.
            vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
            vision_tokens = input_ids[vision_start_indices + 1]
            image_nums = (vision_tokens == image_token_id).sum()
            video_nums = (vision_tokens == video_token_id).sum()
            input_tokens = input_ids.tolist()
            llm_pos_ids_list: list = []
            # `st` is the index of the first token not yet assigned a position.
            st = 0
            remain_images, remain_videos = image_nums, video_nums
            for _ in range(image_nums + video_nums):
                # Locate the next image / video placeholder at or after `st`;
                # an out-of-range sentinel is used when none remains.
                if image_token_id in input_tokens and remain_images > 0:
                    ed_image = input_tokens.index(image_token_id, st)
                else:
                    ed_image = len(input_tokens) + 1
                if video_token_id in input_tokens and remain_videos > 0:
                    ed_video = input_tokens.index(video_token_id, st)
                else:
                    ed_video = len(input_tokens) + 1
                # Whichever placeholder comes first is processed next.
                if ed_image < ed_video:
                    t, h, w = (
                        image_grid_thw[image_index][0],
                        image_grid_thw[image_index][1],
                        image_grid_thw[image_index][2],
                    )
                    image_index += 1
                    remain_images -= 1
                    ed = ed_image
                else:
                    t, h, w = (
                        video_grid_thw[video_index][0],
                        video_grid_thw[video_index][1],
                        video_grid_thw[video_index][2],
                    )
                    video_index += 1
                    remain_videos -= 1
                    ed = ed_video
                # Grid size after spatial merging.
                llm_grid_t, llm_grid_h, llm_grid_w = (
                    t.item(),
                    h.item() // spatial_merge_size,
                    w.item() // spatial_merge_size,
                )
                # Text before the vision block: identical positions on all three axes,
                # continuing after the largest position used so far.
                text_len = ed - st
                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
                # Vision block: per-axis grid indices, offset past the preceding text.
                t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
                h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
                st = ed + llm_grid_t * llm_grid_h * llm_grid_w
            # Trailing text after the last vision block.
            if st < len(input_tokens):
                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                text_len = len(input_tokens) - st
                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
            llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
            if sequence_length is None:
                # Write positions only at the non-padded slots of row i.
                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
                mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
            else:
                position_ids[i] = llm_positions.to(input_ids.device)
                # In packed mode the delta is taken relative to the longest sample.
                mrope_position_deltas.append(llm_positions.max() + 1 - max_input_ids_len)
        if sequence_length is not None:
            # Join the per-sample positions along the sequence axis and add a batch axis of 1.
            position_ids = torch.cat(position_ids, dim=-1).unsqueeze(1)
        mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
        return position_ids, mrope_position_deltas
    else:
        if attention_mask is not None:
            # Text-only with a mask: running count of attended tokens; masked
            # slots are filled with 1.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
            max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
            mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
        else:
            # Text-only without a mask: 0..seq_len-1 on all three axes, zero deltas.
            position_ids = (
                torch.arange(input_ids.shape[1], device=input_ids.device)
                .view(1, 1, -1)
                .expand(3, input_ids.shape[0], -1)
            )
            mrope_position_deltas = torch.zeros(
                [input_ids.shape[0], 1],
                device=input_ids.device,
                dtype=input_ids.dtype,
            )
        return position_ids, mrope_position_deltas
def get_video_features(
    self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
):
    """
    Encode videos by delegating to `get_image_features`; videos and images share the same encoder path.

    Args:
        pixel_values_videos (`torch.FloatTensor`):
            The tensors corresponding to the input videos.
        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
            The temporal, height and width of feature shape of each video in LLM.

    Returns:
        Exactly what `get_image_features` returns for these inputs: the per-item embeddings and the
        deepstack visual features.
    """
    video_outputs = self.get_image_features(pixel_values_videos, video_grid_thw)
    return video_outputs
def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
    """
    Run the vision encoder and cut its output into one embedding tensor per image.

    Args:
        pixel_values (`torch.FloatTensor`):
            The tensors corresponding to the input images.
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM. Although it has a
            default of `None`, it is dereferenced below to compute the split sizes.

    Returns:
        A tuple `(image_embeds, deepstack_image_embeds)`: the encoder output split per image, and the
        deepstack features exactly as returned by the vision encoder.
    """
    vision_tower = self.visual

    # When the vision tower exposes an FSDP state, cast the pixels to its
    # mixed-precision parameter dtype (if one is set and differs).
    if hasattr(vision_tower, "_get_fsdp_state"):
        param_dtype = vision_tower._get_fsdp_state()._mp_policy.param_dtype
        if param_dtype is not None and param_dtype != pixel_values.dtype:
            pixel_values = pixel_values.type(param_dtype)

    image_embeds, deepstack_image_embeds = vision_tower(pixel_values, grid_thw=image_grid_thw)

    # Each image yields t * h * w patches, merged in groups of spatial_merge_size ** 2.
    tokens_per_image = image_grid_thw.prod(-1) // vision_tower.spatial_merge_size**2
    image_embeds = torch.split(image_embeds, tokens_per_image.tolist())
    return image_embeds, deepstack_image_embeds
def get_placeholder_mask(
    self,
    input_ids: torch.LongTensor,
    inputs_embeds: torch.FloatTensor,
    image_features: Optional[torch.FloatTensor] = None,
    video_features: Optional[torch.FloatTensor] = None,
):
    """
    Build boolean masks marking image and video placeholder positions in `inputs_embeds`.

    Placeholders are found by comparing `input_ids` with the configured token ids, or, when `input_ids`
    is `None`, by comparing `inputs_embeds` with the embedding of each placeholder token. Both masks are
    expanded to the shape of `inputs_embeds`.

    Raises:
        ValueError: if features are given and the number of masked elements differs from the number of
        feature elements.
    """
    image_token_id = self.config.image_token_id
    video_token_id = self.config.video_token_id

    if input_ids is not None:
        image_positions = input_ids == image_token_id
        video_positions = input_ids == video_token_id
    else:
        # No ids available: a position is a placeholder when its whole
        # embedding vector equals the placeholder token's embedding.
        device = inputs_embeds.device
        image_token_embed = self.get_input_embeddings()(
            torch.tensor(image_token_id, dtype=torch.long, device=device)
        )
        image_positions = (inputs_embeds == image_token_embed).all(-1)
        video_token_embed = self.get_input_embeddings()(
            torch.tensor(video_token_id, dtype=torch.long, device=device)
        )
        video_positions = (inputs_embeds == video_token_embed).all(-1)

    # Image placeholders: count, expand over the hidden dimension, validate.
    n_image_tokens = image_positions.sum()
    special_image_mask = image_positions.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
    if image_features is not None:
        if special_image_mask.sum().item() != image_features.numel():
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}"
            )

    # Video placeholders: same procedure.
    n_video_tokens = video_positions.sum()
    special_video_mask = video_positions.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
    if video_features is not None:
        if special_video_mask.sum().item() != video_features.numel():
            raise ValueError(
                f"Videos features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}"
            )

    return special_image_mask, special_video_mask
# Multimodal forward pass: embeds the tokens, scatters image/video features into their placeholder
# positions, derives rope position ids when none are given, and runs the language model on the result.
@auto_docstring
@check_model_inputs
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, Qwen3VLModelOutputWithPast]:
r"""
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width of feature shape of each image in LLM.
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
The temporal, height and width of feature shape of each video in LLM.
"""
# Exactly one of `input_ids` / `inputs_embeds` must be supplied.
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
# Embed the token ids when the caller did not pass embeddings directly.
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
image_mask = None
video_mask = None
# Vision features are computed under `torch.no_grad()` only when the training args mark both the
# vision encoder and the vision projector as frozen; otherwise `context` is a no-op.
vit_config = get_args().mm.model.image_encoder
context = nullcontext()
if vit_config.vision_encoder.freeze and vit_config.vision_projector.freeze:
context = torch.no_grad()
# Images: encode, concatenate the per-image features, and scatter them into the image placeholders.
if pixel_values is not None:
with context:
image_embeds, deepstack_image_embeds = self.get_image_features(pixel_values, image_grid_thw)
image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
image_mask, _ = self.get_placeholder_mask(
input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
# The scattered copy lives in `inputs_embeds`; drop the local reference.
del image_embeds
# Videos: same procedure with the video placeholders.
if pixel_values_videos is not None:
with context:
video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
_, video_mask = self.get_placeholder_mask(
input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
del video_embeds
# Build the combined visual position mask and the deepstack features handed to the language model.
visual_pos_masks = None
deepstack_visual_embeds = None
if image_mask is not None and video_mask is not None:
# The masks were expanded over the last dim of `inputs_embeds`; index 0 of that dim gives a per-token mask.
image_mask = image_mask[..., 0]
video_mask = video_mask[..., 0]
visual_pos_masks = image_mask | video_mask
deepstack_visual_embeds = []
# Within the set of visual positions, mark which entries are image and which are video.
image_mask_joint = image_mask[visual_pos_masks]
video_mask_joint = video_mask[visual_pos_masks]
# Merge each pair of image/video deepstack tensors into one tensor with a row per visual position.
for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds):
embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device)
embed_joint[image_mask_joint, :] = img_embed
embed_joint[video_mask_joint, :] = vid_embed
deepstack_visual_embeds.append(embed_joint)
elif image_mask is not None:
image_mask = image_mask[..., 0]
visual_pos_masks = image_mask
deepstack_visual_embeds = deepstack_image_embeds
elif video_mask is not None:
video_mask = video_mask[..., 0]
visual_pos_masks = video_mask
deepstack_visual_embeds = deepstack_video_embeds
# No explicit `position_ids`: derive them here.
if position_ids is None:
# `attention_mask` may be a dict; in that case its "full_attention" entry is used.
attention_mask_tensor = (
attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
)
# A 4-D mask is reduced by taking index 0 along dim 1 and then the diagonal over the last two dims.
if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
# Floating-point masks are divided by the dtype minimum and flipped with `1 - x`, so entries equal to
# the dtype minimum become 0 and zero entries become 1.
if attention_mask_tensor.dtype.is_floating_point:
attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
attention_mask_tensor = (1.0 - attention_mask_tensor).int()
# Prefill detection while compiling: the input holds more than one position along dim 1.
prefill_compiled_stage = is_torchdynamo_compiling() and (
(input_ids is not None and input_ids.shape[1] != 1)
or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
)
# Prefill detection otherwise: cache position starts at 0, or there is no (or an empty) KV cache.
prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
(cache_position is not None and cache_position[0] == 0)
or (past_key_values is None or past_key_values.get_seq_length() == 0)
)
# On prefill, or when no rope deltas are cached yet, compute the rope index and cache the deltas on `self`.
if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
# When `seqlens` is present in kwargs it is forwarded as `sequence_length` and the attention mask is withheld.
position_ids, rope_deltas = self.get_rope_index(
input_ids,
image_grid_thw,
video_grid_thw,
sequence_length=kwargs.get('seqlens', None),
attention_mask=attention_mask_tensor if kwargs.get('seqlens', None) is None else None,
)
self.rope_deltas = rope_deltas
else:
# Otherwise positions are `arange(seq_length)` shifted by `cache_position[0] + self.rope_deltas`.
batch_size, seq_length, _ = inputs_embeds.shape
delta = (
(cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
if cache_position is not None
else 0
)
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
# `delta` is only a tensor when `cache_position` was given (otherwise it is the int 0).
if cache_position is not None:
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
position_ids = position_ids.add(delta)
# Replicate the positions along a new leading dim of size 3.
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
# `input_ids` is passed as None: the language model receives the merged embeddings instead.
outputs = self.language_model(
input_ids=None,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
visual_pos_masks=visual_pos_masks,
deepstack_visual_embeds=deepstack_visual_embeds,
**kwargs,
)
# Only the last hidden state, the KV cache and the cached rope deltas are returned.
return Qwen3VLModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
rope_deltas=self.rope_deltas,
)
# Qwen3-VL model with a language-modeling head: wraps a `Qwen3VLModel` backbone and a `Qwen3VLLMHead`
# (both created in `__init__`) and mixes in `GenerationMixin`.
class Qwen3VLForConditionalGeneration(Qwen3VLPreTrainedModel, GenerationMixin):
# Empty mapping: no checkpoint key conversions are declared here.
_checkpoint_conversion_mapping = {}
# NOTE(review): presumably consumed by the base class to tie `lm_head.weight` to the input embeddings — confirm.
_tied_weights_keys = ["lm_head.weight"]
# NOTE(review): presumably tells the training loop not to forward extra loss kwargs to `forward` — confirm.
accepts_loss_kwargs = False
# Type annotation only; no value is assigned at class level.
config: Qwen3VLConfig
def __init__(self, config):
    """
    Build the multimodal backbone and the LM head, then run `post_init()`.

    Args:
        config: model configuration; `config.text_config.hidden_size` and
            `config.text_config.vocab_size` give the LM head's input and output sizes.
    """
    super().__init__(config)
    self.model = Qwen3VLModel(config)
    text_config = config.text_config
    self.lm_head = Qwen3VLLMHead(
        text_config.hidden_size,
        text_config.vocab_size,
        bias=False,
    )
    self.post_init()
def get_input_embeddings(self):
    """Return the input embedding module of the wrapped `self.model`."""
    backbone = self.model
    return backbone.get_input_embeddings()
def set_input_embeddings(self, value):
    """Forward `value` to `self.model.set_input_embeddings`; returns nothing."""
    backbone = self.model
    backbone.set_input_embeddings(value)
def set_decoder(self, decoder):
    """Forward `decoder` to `self.model.set_decoder`; returns nothing."""
    backbone = self.model
    backbone.set_decoder(decoder)
def get_decoder(self):
    """Return whatever `self.model.get_decoder()` returns."""
    backbone = self.model
    return backbone.get_decoder()
def get_video_features(
    self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
):
    """Delegate video feature extraction to `self.model.get_video_features` and return its result."""
    backbone = self.model
    return backbone.get_video_features(pixel_values_videos, video_grid_thw)
def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
    """Delegate image feature extraction to `self.model.get_image_features` and return its result."""
    backbone = self.model
    return backbone.get_image_features(pixel_values, image_grid_thw)
@property
def language_model(self):
    """The `language_model` attribute of the wrapped `self.model`."""
    backbone = self.model
    return backbone.language_model
@property
def visual(self):
    """The `visual` attribute of the wrapped `self.model`."""
    backbone = self.model
    return backbone.visual
# Causal-LM forward pass: runs the multimodal backbone (`self.model`), applies `self.lm_head` to the
# selected hidden states, and returns logits together with an optional loss.
@check_model_inputs
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
# NOTE(review): `callable` here is the builtin function, not a type; `Callable` (imported at the top of
# the file) is likely what was intended.
loss_ctx: Optional[callable] = None,
**kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, Qwen3VLCausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width of feature shape of each image in LLM.
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
The temporal, height and width of feature shape of each video in LLM.
"""
# Run the multimodal backbone; `labels`, `logits_to_keep` and `loss_ctx` are consumed here, not forwarded.
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
**kwargs,
)
# First entry of the backbone output — presumably `last_hidden_state`; confirm against the output class.
hidden_states = outputs[0]
# An int keeps the last `logits_to_keep` positions (0 gives `slice(0, None)`, i.e. every position);
# any other value is used directly as the index along dim 1.
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
# `lm_head` returns a `(logits, loss)` pair; `loss_ctx` is forwarded to it only when truthy.
if loss_ctx:
logits, loss = self.lm_head(hidden_states[:, slice_indices, :], loss_ctx=loss_ctx)
else:
logits, loss = self.lm_head(hidden_states[:, slice_indices, :])
# When labels are given, the loss from `self.loss_function` replaces the one returned by `lm_head`.
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
return Qwen3VLCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
rope_deltas=outputs.rope_deltas,
)
def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    cache_position=None,
    position_ids=None,
    use_cache=True,
    pixel_values=None,
    pixel_values_videos=None,
    image_grid_thw=None,
    video_grid_thw=None,
    **kwargs,
):
    """
    Prepare the per-step model inputs for generation.

    Delegates to the parent implementation, then:
      * clears `position_ids` so that `forward` derives them itself;
      * drops `pixel_values` / `pixel_values_videos` on every step whose first cache position is
        non-zero, so visual inputs are only passed on the step that starts at position 0.

    `cache_position` may be `None` (its declared default); the visual inputs are then kept, the same
    as for a step that starts at position 0.

    Returns:
        dict of keyword arguments to pass to `forward`.
    """
    model_inputs = super().prepare_inputs_for_generation(
        input_ids,
        past_key_values=past_key_values,
        attention_mask=attention_mask,
        inputs_embeds=inputs_embeds,
        cache_position=cache_position,
        position_ids=position_ids,
        pixel_values=pixel_values,
        pixel_values_videos=pixel_values_videos,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        use_cache=use_cache,
        **kwargs,
    )

    # Position ids are recomputed inside `forward` when they are None.
    model_inputs["position_ids"] = None

    # Guard against the default `cache_position=None`, which previously raised a TypeError on indexing.
    if cache_position is not None and cache_position[0] != 0:
        model_inputs["pixel_values"] = None
        model_inputs["pixel_values_videos"] = None

    return model_inputs
def _get_image_nums_and_video_nums(
self,
input_ids: Optional[torch.LongTensor],
inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
Returns:
image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`)
video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
"""
image_token_id = self.config.image_token_id
video_token_id = self.config.video_token_id
vision_start_token_id = self.config.vision_start_token_id
if inputs_embeds is not None:
vision_start_mask = (
inputs_embeds
== self.get_input_embeddings()(
torch.tensor(vision_start_token_id, dtype=torch.long, device=inputs_embeds.device)
)
)[..., 0]
image_mask = (
inputs_embeds
== self.get_input_embeddings()(
torch.tensor(image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
)[..., 0]
video_mask = (
inputs_embeds
== self.get_input_embeddings()(
torch.tensor(video_token_id, dtype=torch.long, device=inputs_embeds.device)
)
)[..., 0]
else:
vision_start_mask = input_ids == vision_start_token_id
image_mask = input_ids == image_token_id
video_mask = input_ids == video_token_id
vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
return image_nums, video_nums
def _expand_inputs_for_generation(
self,
expand_size: int = 1,
is_encoder_decoder: bool = False,
input_ids: Optional[torch.LongTensor] = None,
**model_kwargs,
) -> tuple[torch.LongTensor, dict[str, Any]]:
if expand_size == 1:
return input_ids, model_kwargs
visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
def _expand_dict_for_generation_visual(dict_to_expand):
image_grid_thw = model_kwargs.get("image_grid_thw", None)
video_grid_thw = model_kwargs.get("video_grid_thw", None)
image_nums, video_nums = self._get_image_nums_and_video_nums(
input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
)
def _repeat_interleave_samples(x, lengths, repeat_times):
samples = torch.split(x, lengths)
repeat_args = [repeat_times] + [1] * (x.dim() - 1)
result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
return result
for key in dict_to_expand:
if key == "pixel_values":
samples = torch.split(image_grid_thw, list(image_nums))
lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
dict_to_expand[key] = _repeat_interleave_samples(
dict_to_expand[key], lengths=lengths, repeat_times=expand_size
)
elif key == "image_grid_thw":
lengths = list(image_nums)
dict_to_expand[key] = _repeat_interleave_samples(
dict_to_expand[key], lengths=lengths, repeat_times=expand_size
)
elif key == "pixel_values_videos":
samples = torch.split(video_grid_thw, list(video_nums))
lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
dict_to_expand[key] = _repeat_interleave_samples(
dict_to_expand[key], lengths=lengths, repeat_times=expand_size
)
elif key == "video_grid_thw":
lengths = list(video_nums)
dict_to_expand[key] = _repeat_interleave_samples(
dict_to_expand[key], lengths=lengths, repeat_times=expand_size
)
elif key == "second_per_grid_ts":
dict_to_expand[key] = _repeat_interleave_samples(
dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size
)
return dict_to_expand
def _expand_dict_for_generation(dict_to_expand):
for key in dict_to_expand:
if (
key != "cache_position"
and dict_to_expand[key] is not None
and isinstance(dict_to_expand[key], torch.Tensor)
and key not in visual_keys
):
dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
return dict_to_expand
model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
if input_ids is not None:
input_ids = input_ids.repeat_interleave(expand_size, dim=0)
model_kwargs = _expand_dict_for_generation(model_kwargs)
if is_encoder_decoder:
if model_kwargs.get("encoder_outputs") is None:
raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
return input_ids, model_kwargs