# coding=utf-8

# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.



from collections.abc import Callable

from typing import Any, Optional, Union

from contextlib import nullcontext



import torch

import torch.nn as nn

import torch.nn.functional as F



from transformers.cache_utils import Cache, DynamicCache

from transformers.generation import GenerationMixin

from transformers.masking_utils import create_causal_mask

from transformers.modeling_flash_attention_utils import FlashAttentionKwargs

from transformers.modeling_outputs import BaseModelOutputWithPast

from transformers.modeling_utils import PreTrainedModel

from transformers.processing_utils import Unpack

from transformers.utils import TransformersKwargs, auto_docstring, is_torchdynamo_compiling

from transformers.utils.generic import check_model_inputs

from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig, Qwen3VLTextConfig, Qwen3VLVisionConfig



from megatron.core import mpu

from megatron.training import get_args

from mindspeed.core.context_parallel.model_parallel_utils import (

    get_context_parallel_group_for_hybrid_ulysses,

    get_context_parallel_group_for_hybrid_ring,

    get_context_parallel_for_hybrid_ring_world_size,

    get_context_parallel_for_hybrid_ulysses_world_size,

    get_context_parallel_for_hybrid_ring_rank

)

from mindspeed.utils import set_actual_seq_len, get_actual_seq_len



from mindspeed_mm.models.common.communications import (

    cal_split_sizes,

    gather_forward_split_backward,

    cal_split_sizes_multi,

    split_forward_gather_backward_with_cp

)

from mindspeed_mm.utils.async_offload import async_save_on_cpu

from mindspeed_mm.utils.data_balance.data_balance import MBSImageDataBalance

from mindspeed_mm.utils.utils import gather_forward_split_backward_with_megatron_cp, get_packed_seq_len



from ..cp_utils import get_seq_len, set_seq_len, split_visual_seqs_with_cp

from .output import (

    Qwen3VLCausalLMOutputWithPast,

    Qwen3VLModelOutputWithPast

)

from .modules import (

    Qwen3VLTextAttention,

    Qwen3VLTextMLP,

    Qwen3VLTextRMSNorm,

    Qwen3VLVisionPatchEmbed,

    Qwen3VLVisionRotaryEmbedding,

    Qwen3VLVisionBlock,

    Qwen3VLVisionPatchMerger,

    Qwen3VLTextRotaryEmbedding,

    Qwen3VLLMHead,

    Qwen3VLEmptyModule,

)





class Qwen3VLTextDecoderLayer(nn.Module):
    """One pre-norm decoder layer of the Qwen3-VL text stack.

    Each call runs ``RMSNorm -> self-attention -> residual add`` followed by
    ``RMSNorm -> MLP -> residual add``.
    """

    def __init__(self, config: Qwen3VLTextConfig, layer_idx: int):
        """Create the attention, MLP and the two RMSNorm sub-modules.

        Args:
            config: Text configuration; ``hidden_size`` and ``rms_norm_eps`` are read here.
            layer_idx: Index of this layer, forwarded to the attention module.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size

        # Sub-modules are created in a fixed order so parameter registration stays stable.
        self.self_attn = Qwen3VLTextAttention(config=config, layer_idx=layer_idx)
        self.mlp = Qwen3VLTextMLP(config)

        norm_width = config.hidden_size
        norm_eps = config.rms_norm_eps
        self.input_layernorm = Qwen3VLTextRMSNorm(norm_width, eps=norm_eps)
        self.post_attention_layernorm = Qwen3VLTextRMSNorm(norm_width, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Run attention and MLP sub-blocks, each wrapped in a residual connection.

        Args:
            hidden_states: Input activations of this layer.
            position_embeddings: Pair of tensors forwarded unchanged to the attention module.
            attention_mask: Optional mask forwarded to the attention module.
            position_ids: Optional position ids forwarded to the attention module.
            past_key_values: Optional cache forwarded to the attention module.
            use_cache: Forwarded to the attention module.
            cache_position: Forwarded to the attention module.
            **kwargs: Extra keyword arguments forwarded to the attention module.

        Returns:
            The activations after both residual sub-blocks.
        """
        if self.config.synchronize_per_layer:
            # Wait on the host until the current NPU stream has drained before this layer starts.
            torch.npu.current_stream().synchronize()

        # Attention sub-block: the skip connection uses the un-normalised input.
        hidden_states = hidden_states + self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )

        # Feed-forward sub-block, same pre-norm + skip pattern.
        return hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))





# Common base class for the Qwen3-VL models defined in this file. It declares
# only class-level settings and adds no methods of its own.
# NOTE(review): the flags below are presumably read by the `PreTrainedModel`
# machinery in transformers; their exact effect is defined there, not here.
@auto_docstring

class Qwen3VLPreTrainedModel(PreTrainedModel):

    # Type of the configuration object attached to the model.
    config: Qwen3VLConfig

    base_model_prefix = "model"

    supports_gradient_checkpointing = True

    # Names of module classes listed as not to be split (decoder layers and vision blocks).
    _no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"]

    _skip_keys_device_placement = "past_key_values"

    # Attention backends this model family declares support for.
    _supports_flash_attn = True

    _supports_sdpa = True



    _can_compile_fullgraph = True

    _supports_attention_backend = True

    # Maps an output name to the module class associated with it.
    _can_record_outputs = {

        "hidden_states": Qwen3VLTextDecoderLayer,

        "attentions": Qwen3VLTextAttention,

    }





class Qwen3VLVisionModel(Qwen3VLPreTrainedModel):
    """Vision tower: patch embedding, transformer blocks, a final patch merger and
    one extra ("deepstack") merger per entry of ``config.deepstack_visual_indexes``."""

    config: Qwen3VLVisionConfig
    _no_split_modules = ["Qwen3VLVisionBlock"]

    def __init__(self, config, *inputs, **kwargs) -> None:
        """Build all sub-modules from ``config``.

        Args:
            config: Vision configuration object.
            *inputs: Positional arguments forwarded to the parent constructor.
            **kwargs: Keyword arguments forwarded to the parent constructor.
        """
        super().__init__(config, *inputs, **kwargs)

        merge_size = config.spatial_merge_size
        self.spatial_merge_size = merge_size
        self.patch_size = config.patch_size
        # Number of patches that one merged token covers.
        self.spatial_merge_unit = merge_size * merge_size

        self.patch_embed = Qwen3VLVisionPatchEmbed(config=config)

        # Learned position table, treated as a square grid of side `num_grid_per_side`.
        self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size)
        self.num_grid_per_side = int(config.num_position_embeddings**0.5)

        per_head_dim = config.hidden_size // config.num_heads
        self.rotary_pos_emb = Qwen3VLVisionRotaryEmbedding(per_head_dim // 2)

        vision_blocks = [Qwen3VLVisionBlock(config) for _ in range(config.depth)]
        self.blocks = nn.ModuleList(vision_blocks)
        self.merger = Qwen3VLVisionPatchMerger(config=config, use_postshuffle_norm=False)

        # One additional merger for every block index whose output is tapped.
        self.deepstack_visual_indexes = config.deepstack_visual_indexes
        deepstack_mergers = [
            Qwen3VLVisionPatchMerger(config=config, use_postshuffle_norm=True)
            for _ in config.deepstack_visual_indexes
        ]
        self.deepstack_merger_list = nn.ModuleList(deepstack_mergers)

        # Optional helper used by `forward` to rebalance image data; None disables it.
        balancer = None
        if config.use_image_mbs_data_balance:
            if torch.distributed.get_rank() == 0:
                print("[INFO] initialize image mbs data balance")
            balancer = MBSImageDataBalance(
                sorting_algo_name=config.mbs_data_balance_sorting_algo,
                spatial_merge_size=config.spatial_merge_size,
            )
        self.data_balance = balancer

        self.gradient_checkpointing = False



    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:

        merge_size = self.spatial_merge_size



        max_hw = int(grid_thw[:, 1:].max().item())

        freq_table = self.rotary_pos_emb(max_hw)  # (max_hw, dim // 2)

        device = freq_table.device



        total_tokens = int(torch.prod(grid_thw, dim=1).sum().item())

        pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)



        offset = 0

        for num_frames, height, width in grid_thw:

            merged_h, merged_w = height // merge_size, width // merge_size



            block_rows = torch.arange(merged_h, device=device)  # block row indices

            block_cols = torch.arange(merged_w, device=device)  # block col indices

            intra_row = torch.arange(merge_size, device=device)  # intra-block row offsets

            intra_col = torch.arange(merge_size, device=device)  # intra-block col offsets



            # Compute full-resolution positions

            row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]

            col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]



            row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)

            col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)



            coords = torch.stack((row_idx, col_idx), dim=-1)



            if num_frames > 1:

                coords = coords.repeat(num_frames, 1)



            num_tokens = coords.shape[0]

            pos_ids[offset: offset + num_tokens] = coords

            offset += num_tokens



        embeddings = freq_table[pos_ids]  # lookup rotary embeddings

        embeddings = embeddings.flatten(1)

        return embeddings



    def fast_pos_embed_interpolate(self, grid_thw):
        """Bilinearly interpolate the learned position table onto every input grid.

        The table ``self.pos_embed`` is treated as a square grid of side
        ``self.num_grid_per_side``. For each ``(t, h, w)`` row of ``grid_thw`` the
        ``h x w`` target positions are mapped onto that grid, the four surrounding
        table entries are blended with bilinear weights, the result is repeated for
        the ``t`` frames and reordered so that the patches of one
        ``merge_size x merge_size`` block are contiguous.

        All intermediate data stays in tensors (no Python-list round trip) and is
        moved to the embedding's device once, after all samples are processed.

        Args:
            grid_thw: Integer tensor whose rows are ``(t, h, w)``.

        Returns:
            Tensor of interpolated position embeddings, one row per patch. When
            context parallelism is active only the local slice produced by
            ``split_visual_seqs_with_cp`` is returned.
        """
        side = self.num_grid_per_side
        merge_size = self.config.spatial_merge_size
        table_weight = self.pos_embed.weight

        per_sample_idx = []  # each entry: (4, t*h*w) table indices, one row per corner
        per_sample_weight = []  # each entry: (4, t*h*w) matching bilinear weights

        for t, h, w in zip(grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]):
            # Fractional source-grid coordinate of every target row / column,
            # e.g. side=16, h=24 -> [0, 0.652, 1.304, ..., 15].
            h_idxs = torch.linspace(0, side - 1, h)
            w_idxs = torch.linspace(0, side - 1, w)

            h_floor = h_idxs.int()
            w_floor = w_idxs.int()
            # Clamp so the last row/column does not index past the table.
            h_ceil = (h_floor + 1).clip(max=side - 1)
            w_ceil = (w_floor + 1).clip(max=side - 1)

            # Fractional parts, shaped for broadcasting to (h, w).
            dh = (h_idxs - h_floor)[:, None]
            dw = (w_idxs - w_floor)[None, :]

            # Flat table offset of the upper / lower source row.
            row_lo = (h_floor * side)[:, None]
            row_hi = (h_ceil * side)[:, None]

            corner_idx = torch.stack(
                [
                    (row_lo + w_floor[None, :]).flatten(),  # top-left
                    (row_lo + w_ceil[None, :]).flatten(),  # top-right
                    (row_hi + w_floor[None, :]).flatten(),  # bottom-left
                    (row_hi + w_ceil[None, :]).flatten(),  # bottom-right
                ]
            )
            corner_weight = torch.stack(
                [
                    ((1 - dh) * (1 - dw)).flatten(),  # top-left
                    ((1 - dh) * dw).flatten(),  # top-right
                    (dh * (1 - dw)).flatten(),  # bottom-left
                    (dh * dw).flatten(),  # bottom-right
                ]
            )

            # Repeat for every frame, then regroup so each merge block is contiguous:
            # (4, t, h/m, m, w/m, m) -> (4, t, h/m, w/m, m, m) -> (4, t*h*w)
            block_shape = (4, t, h // merge_size, merge_size, w // merge_size, merge_size)
            per_sample_idx.append(
                corner_idx.repeat(1, t).view(*block_shape).permute(0, 1, 2, 4, 3, 5).flatten(1, 5)
            )
            per_sample_weight.append(
                corner_weight.repeat(1, t).view(*block_shape).permute(0, 1, 2, 4, 3, 5).flatten(1, 5)
            )

        # [4, s1+s2+s3...], transferred / cast once for all samples.
        patch_idx = torch.cat(per_sample_idx, dim=1).to(device=table_weight.device, dtype=torch.long)
        patch_weight = torch.cat(per_sample_weight, dim=1).to(
            device=table_weight.device, dtype=table_weight.dtype
        )

        # Keep only this rank's slice when context parallelism is enabled.
        if mpu.get_context_parallel_world_size() > 1:
            patch_idx = split_visual_seqs_with_cp(patch_idx, dim=1)
            patch_weight = split_visual_seqs_with_cp(patch_weight, dim=1)

        # (4, tokens, hidden): weighted corner embeddings, summed over the four corners.
        pos_embeds = self.pos_embed(patch_idx) * patch_weight[:, :, None]
        return pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]



    def forward(
        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """
        Encode visual patches and collect the deepstack features.

        Args:
            hidden_states (`torch.Tensor`):
                Patch inputs fed to `self.patch_embed`; its output is handled as a 2-D
                `(seq_len, hidden)` tensor.
            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
                The temporal, height and width of feature shape of each image in LLM.
            **kwargs:
                Forwarded unchanged to every vision block.

        Returns:
            `tuple(torch.Tensor, list[torch.Tensor])`: the output of `self.merger` (gathered
            across context-parallel ranks when context parallelism is enabled) and the list of
            deepstack features, one per entry of `self.deepstack_visual_indexes` that was hit.
            The deepstack features are not gathered here.
        """
        if self.data_balance is not None:
            hidden_states, grid_thw = self.data_balance.get_image_balance_data(
                {'pixel_values': hidden_states, 'image_grid_thw': grid_thw}
            )

        hidden_states = self.patch_embed(hidden_states)  # s1+s2+s3..., h
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        seq_len, _ = hidden_states.size()

        # Patch count (h * w) of every single frame; computed once and reused below.
        frame_lengths = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
        sequence_lengths = frame_lengths.cpu()

        # Set global sequence length variables for context parallelism
        set_seq_len("per_visual", sequence_lengths)
        set_seq_len("visual", seq_len)
        hidden_states = hidden_states.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)

        cu_seqlens = frame_lengths.cumsum(
            dim=0,
            # Select dtype based on the following factors:
            #  - FA2 requires that cu_seqlens_q must have dtype int32
            #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
        # Drop the leading zero again unless it is the only element.
        cu_seqlens = cu_seqlens[1:] if len(cu_seqlens) > 1 else cu_seqlens
        cu_seqlens = tuple(cu_seqlens.cpu().numpy().tolist())

        # Split sequences across context parallel groups for distributed processing
        if mpu.get_context_parallel_world_size() > 1:
            rotary_pos_emb = split_visual_seqs_with_cp(rotary_pos_emb, dim=0)
            hidden_states = split_visual_seqs_with_cp(hidden_states, dim=0)

            cp_algo = get_args().context_parallel_algo
            if cp_algo == "megatron_cp_algo":
                all_split_sizes_tensor = cal_split_sizes_multi(
                    sequence_lengths, mpu.get_context_parallel_world_size()
                )
                # Get cumulative split sizes for the current ring cp rank
                cu_seqlens = all_split_sizes_tensor.cumsum(dim=1)[mpu.get_context_parallel_rank()]
            elif cp_algo == "hybrid_cp_algo":
                # Calculate split sizes for hybrid ring context parallelism
                all_split_sizes_tensor = cal_split_sizes_multi(
                    sequence_lengths, get_context_parallel_for_hybrid_ring_world_size()
                )
                # Get cumulative split sizes for the current hybrid ring rank
                cu_seqlens = all_split_sizes_tensor.cumsum(dim=1)[get_context_parallel_for_hybrid_ring_rank()]
            # Any other algorithm keeps the cu_seqlens computed above.

        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # Add fast position embeddings (interpolated based on grid dimensions)
        hidden_states = hidden_states + self.fast_pos_embed_interpolate(grid_thw)

        deepstack_feature_lists = []
        for layer_num, blk in enumerate(self.blocks):
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens,
                position_embeddings=position_embeddings,
                **kwargs,
            )
            if layer_num in self.deepstack_visual_indexes:
                # Tap this block's output through its dedicated merger.
                merger_idx = self.deepstack_visual_indexes.index(layer_num)
                deepstack_feature_lists.append(self.deepstack_merger_list[merger_idx](hidden_states))

        hidden_states = self.merger(hidden_states)
        if self.data_balance is not None:
            hidden_states, deepstack_feature_lists = self.data_balance.reverse_img_balance_data(
                hidden_states, deepstack_feature_lists
            )

        # Record the merged length, then gather outputs from all context parallel ranks.
        set_seq_len("visual", seq_len // self.spatial_merge_size ** 2)
        if mpu.get_context_parallel_world_size() > 1:
            gather_sizes = cal_split_sizes(get_seq_len("visual"), mpu.get_context_parallel_world_size())
            hidden_states = gather_forward_split_backward(
                hidden_states,
                mpu.get_context_parallel_group(),
                dim=0,
                grad_scale="up",
                gather_sizes=gather_sizes,
            )

        return hidden_states, deepstack_feature_lists





@auto_docstring(
    custom_intro=(
        "Text part of Qwen3VL, "
        "not a pure text-only model, as DeepStack integrates visual features into the early hidden states."
    )
)
class Qwen3VLTextModel(Qwen3VLPreTrainedModel):
    config: Qwen3VLTextConfig
    _no_split_modules = ["Qwen3VLTextDecoderLayer"]

    def __init__(self, config: Qwen3VLTextConfig):
        super().__init__(config)
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Token embedding table; the pad token id is passed as the embedding's padding index.
        self.embed_tokens = nn.Embedding(
            config.vocab_size,
            config.hidden_size,
            padding_idx=self.padding_idx,
        )

        # Placeholder for FSDP2 hook registration on norm/gate params when align_fsdp_param_groups is enabled.
        self.norm_hook_module = Qwen3VLEmptyModule()

        # Decoder stack followed by the final norm and the shared rotary embedding.
        decoder_layers = [Qwen3VLTextDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = Qwen3VLTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3VLTextRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # A dedicated NPU stream is only created when activation offload is enabled.
        if config.activation_offload:
            self.swap_stream = torch.npu.Stream()

        # Initialize weights and apply final processing
        self.post_init()



    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        # args for deepstack
        visual_pos_masks: Optional[torch.Tensor] = None,
        deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[tuple, BaseModelOutputWithPast]:
        r"""
        visual_pos_masks (`torch.Tensor` of shape `(batch_size, seqlen)`, *optional*):
            The mask of the visual positions.
        deepstack_visual_embeds (`list[torch.Tensor]`, *optional*):
            The deepstack visual embeddings. The shape is (num_layers, visual_seqlen, embed_dim).
            The feature is extracted from the different visual encoder layers, and fed to the decoder
            hidden states.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        # torch.jit.trace() doesn't support cache objects in the output
        if use_cache and past_key_values is None and not torch.jit.is_tracing():
            past_key_values = DynamicCache(config=self.config)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        # the hard coded `3` is for temporal, height and width.
        if position_ids is None:
            position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
        elif position_ids.ndim == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)

        # The first plane always serves as the text position ids; when four planes are
        # given, the remaining three are the ones handed to the rotary embedding.
        text_position_ids = position_ids[0]
        if position_ids.ndim == 3 and position_ids.shape[0] == 4:
            position_ids = position_ids[1:]

        total_seq_len = inputs_embeds.shape[1]
        set_seq_len("total", total_seq_len)

        if self.config.attn_layout == "TND":
            # Prefer caller-provided "seqlens" (input already in TND form); otherwise
            # derive the per-sample lengths from the attention mask.
            if "seqlens" in kwargs:
                seqlens_in_batch = kwargs["seqlens"]
            elif attention_mask is not None:
                seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
            else:
                raise ValueError(
                    "attention_mask is required for the TND layout when `seqlens` is not passed in kwargs"
                )
            cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
            # Drop the leading zero again unless it is the only element.
            cu_seqlens = cu_seqlens[1:] if len(cu_seqlens) > 1 else cu_seqlens
            set_actual_seq_len(actual_seq_len=cu_seqlens)
            # NOTE(review): repeats the call made before this branch with the same value;
            # kept in case set_actual_seq_len touches the same state - confirm and drop.
            set_seq_len("total", total_seq_len)
            kwargs["cu_seqlens"] = tuple(cu_seqlens.cpu().numpy().tolist())
            if "indices" not in kwargs:
                # Caller did not supply "indices": take the flat positions of non-zero mask entries.
                if attention_mask is None:
                    raise ValueError(
                        "attention_mask is required for the TND layout when `indices` is not passed in kwargs"
                    )
                kwargs["indices"] = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
        elif not self.config.is_causal:
            attention_mask = create_causal_mask(
                config=self.config,
                input_embeds=inputs_embeds,
                attention_mask=attention_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
                position_ids=text_position_ids,
            )
        else:
            # With `config.is_causal` set, no explicit mask is passed to the layers.
            attention_mask = None

        if mpu.get_context_parallel_world_size() > 1:
            # Keep only this rank's slice of the sequence dimension.
            position_ids = split_forward_gather_backward_with_cp(position_ids, dim=2)
            text_position_ids = split_forward_gather_backward_with_cp(text_position_ids, dim=1)
            inputs_embeds = split_forward_gather_backward_with_cp(inputs_embeds, dim=1)

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        self.norm_hook_module(hidden_states)

        # decoder layers
        offload_activations = self.config.activation_offload
        num_layers = len(self.layers)
        for layer_idx, decoder_layer in enumerate(self.layers):
            if offload_activations:
                layer_ctx = async_save_on_cpu(
                    h2d_stream=self.swap_stream,
                    d2h_stream=self.swap_stream,
                    block_idx=layer_idx,
                    depth=num_layers,
                    # Evaluated lazily: compares against whatever `hidden_states`
                    # refers to at the time the check is invoked.
                    custom_check_fn=lambda x: x.data_ptr() == hidden_states.data_ptr(),
                    prefetch=True,
                )
            else:
                layer_ctx = nullcontext()

            with layer_ctx:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=text_position_ids,
                    past_key_values=past_key_values,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )
            # Rebind only after the context has exited, as the check above reads this name.
            hidden_states = layer_outputs

            # add visual features to the hidden states of first several layers
            if deepstack_visual_embeds is not None and layer_idx < len(deepstack_visual_embeds):
                hidden_states = self._deepstack_process(
                    hidden_states,
                    visual_pos_masks,
                    deepstack_visual_embeds[layer_idx],
                )

        hidden_states = self.norm(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )



    def _deepstack_process(
        self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, visual_embeds: torch.Tensor
    ):
        """Add one level of deepstack visual features onto the visual token positions.

        With context parallelism enabled, the visual features and the hidden states
        are first gathered across ranks, the addition is done on the gathered
        sequence, and the result is split again before returning.

        Args:
            hidden_states: Decoder hidden states; dim 1 is the sequence dimension that
                is gathered/split under context parallelism. Modified in place when no
                gather takes place.
            visual_pos_masks: Boolean mask indexing the first dims of ``hidden_states``
                at the visual positions.
            visual_embeds: Visual features to add; dim 0 is gathered under context
                parallelism.

        Returns:
            The hidden states with ``visual_embeds`` added at the masked positions.

        Raises:
            NotImplementedError: If context parallelism is enabled with an algorithm
                other than ``ulysses_cp_algo``, ``megatron_cp_algo`` or ``hybrid_cp_algo``.
        """
        cp_size = mpu.get_context_parallel_world_size()
        use_cp = cp_size > 1

        if use_cp:
            # Each rank only holds a slice of the visual features; rebuild the full set.
            visual_gather_sizes = cal_split_sizes(get_seq_len("visual"), cp_size)
            visual_embeds = gather_forward_split_backward(
                visual_embeds,
                mpu.get_context_parallel_group(),
                dim=0,
                grad_scale="up",
                gather_sizes=visual_gather_sizes,
            )

        visual_pos_masks = visual_pos_masks.to(hidden_states.device)
        visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)

        if use_cp:
            cp_algo = get_args().context_parallel_algo
            if cp_algo == "ulysses_cp_algo":
                gather_sizes = cal_split_sizes(get_seq_len("total"), cp_size)
                hidden_states = gather_forward_split_backward(
                    hidden_states,
                    mpu.get_context_parallel_group(),
                    dim=1,
                    grad_scale="up",
                    gather_sizes=gather_sizes,
                )
            elif cp_algo == "megatron_cp_algo":
                hidden_states = gather_forward_split_backward_with_megatron_cp(
                    hidden_states, mpu.get_context_parallel_group(), dim=1
                )
            elif cp_algo == "hybrid_cp_algo":
                # Calculate the sequence length per ring CP group for Ulysses processing.
                # Since padding is applied in ring groups, the division yields an integer.
                ring_size = get_context_parallel_for_hybrid_ring_world_size()
                actual_seq_len = get_actual_seq_len()
                if actual_seq_len is not None:
                    total_seq_len = get_packed_seq_len(actual_seq_len, ring_size)
                else:
                    total_seq_len = get_seq_len("total")
                seq_len_per_ring = total_seq_len // ring_size

                # ulysses allgather
                gather_sizes = cal_split_sizes(seq_len_per_ring, get_context_parallel_for_hybrid_ulysses_world_size())
                hidden_states = gather_forward_split_backward(
                    hidden_states,
                    get_context_parallel_group_for_hybrid_ulysses(),
                    dim=1,
                    grad_scale="up",
                    gather_sizes=gather_sizes,
                )
                # ring allgather
                hidden_states = gather_forward_split_backward_with_megatron_cp(
                    hidden_states, get_context_parallel_group_for_hybrid_ring(), dim=1
                )
            else:
                raise NotImplementedError(f"Only support `ulysses_cp_algo`,`megatron_cp_algo`,`hybrid_cp_algo`, but got {cp_algo}")

        # Boolean-mask indexing already yields a fresh tensor, so no extra copy is
        # needed before the sum is written back into the masked positions.
        fused_visual = hidden_states[visual_pos_masks, :] + visual_embeds
        hidden_states[visual_pos_masks, :] = fused_visual

        # split again
        if use_cp:
            hidden_states = split_forward_gather_backward_with_cp(hidden_states, dim=1)

        return hidden_states





@auto_docstring

class Qwen3VLModel(Qwen3VLPreTrainedModel):

    base_model_prefix = ""

    _checkpoint_conversion_mapping = {}

    # Reference: fix gemma3 grad acc #37208

    accepts_loss_kwargs = False

    config: Qwen3VLConfig

    _no_split_modules = ["Qwen3VLTextDecoderLayer", "Qwen3VLVisionBlock"]



    def __init__(self, config):
        super().__init__(config)
        # Each tower is built from its own sub-config. The vision tower is
        # created (and therefore registered) before the text decoder.
        vision_cfg = config.vision_config
        self.visual = Qwen3VLVisionModel._from_config(vision_cfg)
        text_cfg = config.text_config
        self.language_model = Qwen3VLTextModel._from_config(text_cfg)
        # Cache slot for the rope deltas produced by `forward`.
        self.rope_deltas = None

        # Initialize weights and apply final processing
        self.post_init()



    def get_input_embeddings(self):
        """Return whatever the wrapped language model exposes as its input embeddings."""
        text_model = self.language_model
        return text_model.get_input_embeddings()



    def set_input_embeddings(self, value):
        """Forward `value` to the language model's `set_input_embeddings`."""
        text_model = self.language_model
        text_model.set_input_embeddings(value)



    def set_decoder(self, decoder):
        """Replace the text decoder held in `self.language_model` with `decoder`."""
        self.language_model = decoder



    def get_decoder(self):
        """Return the text decoder stored in `self.language_model`."""
        return self.language_model



    def get_rope_index(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sequence_length: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Build the 3-row (temporal, height, width) position ids and the per-sample rope deltas.

        Different from the original implementation, Qwen3VL uses timestamps rather than absolute time position ids.

        Three situations are handled:

        * `input_ids` and at least one grid are given, `sequence_length` is None:
          `input_ids` is iterated row by row; only positions where `attention_mask == 1`
          receive computed ids, the remaining entries keep the initial value 1. The result
          has shape `(3, input_ids.shape[0], input_ids.shape[1])`.
        * `input_ids` and at least one grid are given, `sequence_length` is given:
          `input_ids[0]` is split into chunks of the given lengths, every chunk gets its own
          ids starting from 0, and the chunks are concatenated along the last dimension and
          given a singleton dimension at index 1.
        * Otherwise: plain text positions are derived from `attention_mask` (cumulative sum)
          or, without a mask, from `torch.arange` over `input_ids.shape[1]`.

        Args:
            input_ids: Token ids; indexed as a 2-D tensor by this method.
            image_grid_thw: One `(t, h, w)` row per image, consumed in order of appearance.
            video_grid_thw: One `(t, h, w)` row per video; expanded to one row per frame below.
            attention_mask: Entries equal to 1 mark the tokens that receive position ids.
            sequence_length: Lengths used to split `input_ids[0]` into separate sequences.

        Returns:
            `(position_ids, mrope_position_deltas)`.
        """

        # Since we use timestamps to separate videos, like <t1> <vision_start> <frame1> <vision_end> <t2> <vision_start> <frame2> <vision_end>, the video_grid_thw should also be split
        if video_grid_thw is not None:
            # repeat_interleave returns a new tensor, so the in-place write below does not
            # touch the caller's tensor.
            video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0)
            video_grid_thw[:, 0] = 1

        spatial_merge_size = self.config.vision_config.spatial_merge_size
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id
        mrope_position_deltas = []
        if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
            if sequence_length is None:
                # Row-per-sample layout: pre-allocate the output and fill it per row.
                total_input_ids = input_ids
                if attention_mask is None:
                    attention_mask = torch.ones_like(total_input_ids)
                position_ids = torch.ones(
                    3,
                    input_ids.shape[0],
                    input_ids.shape[1],
                    dtype=input_ids.dtype,
                    device=input_ids.device,
                )
                attention_mask = attention_mask.to(total_input_ids.device)
            else:
                # Packed layout: row 0 holds several sequences back to back.
                total_input_ids = input_ids[0].split(sequence_length.tolist())
                max_input_ids_len = max(sequence_length)
                position_ids = [None] * len(sequence_length)
            # Cursors into image_grid_thw / video_grid_thw; they run across all samples.
            image_index, video_index = 0, 0
            # NOTE: the loop variable shadows the `input_ids` parameter from here on.
            for i, input_ids in enumerate(total_input_ids):
                if sequence_length is None:
                    # Drop the positions that are masked out for this row.
                    input_ids = input_ids[attention_mask[i] == 1]
                image_nums, video_nums = 0, 0
                # The token right after each vision-start token tells image from video.
                vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
                vision_tokens = input_ids[vision_start_indices + 1]
                image_nums = (vision_tokens == image_token_id).sum()
                video_nums = (vision_tokens == video_token_id).sum()
                input_tokens = input_ids.tolist()
                llm_pos_ids_list: list = []
                # `st` is the index of the first token not yet assigned a position.
                st = 0
                remain_images, remain_videos = image_nums, video_nums
                for _ in range(image_nums + video_nums):
                    # Find the next image / video placeholder at or after `st`; a value past
                    # the end of the sequence means "none left of that kind".
                    if image_token_id in input_tokens and remain_images > 0:
                        ed_image = input_tokens.index(image_token_id, st)
                    else:
                        ed_image = len(input_tokens) + 1
                    if video_token_id in input_tokens and remain_videos > 0:
                        ed_video = input_tokens.index(video_token_id, st)
                    else:
                        ed_video = len(input_tokens) + 1
                    # Whichever placeholder comes first is processed in this iteration.
                    if ed_image < ed_video:
                        t, h, w = (
                            image_grid_thw[image_index][0],
                            image_grid_thw[image_index][1],
                            image_grid_thw[image_index][2],
                        )
                        image_index += 1
                        remain_images -= 1
                        ed = ed_image

                    else:
                        t, h, w = (
                            video_grid_thw[video_index][0],
                            video_grid_thw[video_index][1],
                            video_grid_thw[video_index][2],
                        )
                        video_index += 1
                        remain_videos -= 1
                        ed = ed_video
                    # Grid size after spatial merging.
                    llm_grid_t, llm_grid_h, llm_grid_w = (
                        t.item(),
                        h.item() // spatial_merge_size,
                        w.item() // spatial_merge_size,
                    )
                    # Number of tokens between the previous block and this placeholder.
                    text_len = ed - st

                    # Those tokens get identical ids on all three rows, continuing after the
                    # largest id used so far.
                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                    # t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal information for videos)
                    t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
                    # Skip past the placeholder tokens of this image / frame.
                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w

                # Tokens after the last visual block.
                if st < len(input_tokens):
                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    text_len = len(input_tokens) - st
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
                if sequence_length is None:
                    position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
                    # Delta is measured against the full row length of total_input_ids[i].
                    mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
                else:
                    position_ids[i] = llm_positions.to(input_ids.device)
                    # Delta is measured against the longest packed sequence.
                    mrope_position_deltas.append(llm_positions.max() + 1 - max_input_ids_len)
            if sequence_length is not None:
                # Join the per-sequence ids back into one packed row.
                position_ids = torch.cat(position_ids, dim=-1).unsqueeze(1)
            mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
            return position_ids, mrope_position_deltas
        else:
            if attention_mask is not None:
                # Running count of attended tokens; masked positions are set to 1.
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
                max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
            else:
                # No mask: the same 0..seq_len-1 ramp for every row, and zero deltas.
                position_ids = (
                    torch.arange(input_ids.shape[1], device=input_ids.device)
                    .view(1, 1, -1)
                    .expand(3, input_ids.shape[0], -1)
                )
                mrope_position_deltas = torch.zeros(
                    [input_ids.shape[0], 1],
                    device=input_ids.device,
                    dtype=input_ids.dtype,
                )

            return position_ids, mrope_position_deltas



    def get_video_features(
        self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
    ):
        """
        Encode videos by delegating to `get_image_features`; the result of that call is returned unchanged,
        so the deepstack visual features are included.

        Args:
            pixel_values_videos (`torch.FloatTensor`):
                The tensors corresponding to the input videos.
            video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
                The temporal, height and width of feature shape of each video in LLM.
        """
        # Videos take exactly the same path as images.
        video_outputs = self.get_image_features(pixel_values_videos, video_grid_thw)
        return video_outputs



    def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
        """
        Run the vision tower on `pixel_values` and split its main output into one tensor per image.

        Args:
            pixel_values (`torch.FloatTensor`):
                The tensors corresponding to the input images.
            image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
                The temporal, height and width of feature shape of each image in LLM.

        Returns:
            A pair `(image_embeds, deepstack_image_embeds)`: the first is the tuple produced by
            `torch.split`, the second is passed through from the vision tower untouched.
        """
        vision_tower = self.visual
        # When the tower exposes an FSDP state with a parameter dtype that differs from the
        # input dtype, cast the input to it; a `None` parameter dtype leaves the input as is.
        if hasattr(vision_tower, "_get_fsdp_state"):
            param_dtype = vision_tower._get_fsdp_state()._mp_policy.param_dtype
            if param_dtype is not None and param_dtype != pixel_values.dtype:
                pixel_values = pixel_values.type(param_dtype)
        image_embeds, deepstack_image_embeds = vision_tower(pixel_values, grid_thw=image_grid_thw)

        # Tokens per image: t * h * w divided by the squared spatial merge size.
        merge_area = vision_tower.spatial_merge_size**2
        tokens_per_image = (image_grid_thw.prod(-1) // merge_area).tolist()
        return torch.split(image_embeds, tokens_per_image), deepstack_image_embeds



    def get_placeholder_mask(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: torch.FloatTensor,
        image_features: Optional[torch.FloatTensor] = None,
        video_features: Optional[torch.FloatTensor] = None,
    ):
        """
        Locate image and video placeholder positions and return them as boolean masks expanded to the
        shape of `inputs_embeds`.

        Positions come from `input_ids` when available; otherwise they are recovered by comparing
        `inputs_embeds` against the embedding of the placeholder token. When features are supplied, the
        number of selected mask elements must equal the number of feature elements, else `ValueError`.
        """
        if input_ids is not None:
            image_positions = input_ids == self.config.image_token_id
            video_positions = input_ids == self.config.video_token_id
        else:
            embed_tokens = self.get_input_embeddings()
            image_token = torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            # A position counts only if every hidden component matches the placeholder embedding.
            image_positions = (inputs_embeds == embed_tokens(image_token)).all(-1)
            video_token = torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
            video_positions = (inputs_embeds == embed_tokens(video_token)).all(-1)

        n_image_tokens = image_positions.sum()
        special_image_mask = image_positions.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)  # b s h
        if image_features is not None:
            if special_image_mask.sum().item() != image_features.numel():
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}"
                )

        n_video_tokens = video_positions.sum()
        special_video_mask = video_positions.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        if video_features is not None:
            if special_video_mask.sum().item() != video_features.numel():
                raise ValueError(
                    f"Videos features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}"
                )

        return special_image_mask, special_video_mask



    @auto_docstring
    @check_model_inputs
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, Qwen3VLModelOutputWithPast]:
        r"""
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
            The temporal, height and width of feature shape of each video in LLM.
        """
        # Exactly one of `input_ids` / `inputs_embeds` must be given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        image_mask = None
        video_mask = None

        # The vision tower runs under no_grad only when both the vision encoder and the
        # vision projector are marked frozen in the training arguments.
        vit_config = get_args().mm.model.image_encoder
        context = nullcontext()
        if vit_config.vision_encoder.freeze and vit_config.vision_projector.freeze:
            context = torch.no_grad()
        if pixel_values is not None:
            with context:
                image_embeds, deepstack_image_embeds = self.get_image_features(pixel_values, image_grid_thw)
            # Flatten the per-image tuple and write it into the image placeholder positions.
            image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
            image_mask, _ = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
            )
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
            # for releasing space
            del image_embeds

        if pixel_values_videos is not None:
            with context:
                video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
            # Same procedure as for images, using the video placeholder positions.
            video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
            _, video_mask = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
            )
            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
            # for releasing space
            del video_embeds

        visual_pos_masks = None
        deepstack_visual_embeds = None
        if image_mask is not None and video_mask is not None:
            # aggregate visual_pos_masks and deepstack_visual_embeds
            # `[..., 0]` undoes the expansion over the hidden dimension done by get_placeholder_mask.
            image_mask = image_mask[..., 0]
            video_mask = video_mask[..., 0]
            visual_pos_masks = image_mask | video_mask
            deepstack_visual_embeds = []
            # Within the visual positions only, which entries are image and which are video.
            image_mask_joint = image_mask[visual_pos_masks]
            video_mask_joint = video_mask[visual_pos_masks]
            for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds):
                # One row per visual position, filled from the matching modality.
                embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device)
                embed_joint[image_mask_joint, :] = img_embed
                embed_joint[video_mask_joint, :] = vid_embed
                deepstack_visual_embeds.append(embed_joint)
        elif image_mask is not None:
            image_mask = image_mask[..., 0]
            visual_pos_masks = image_mask
            deepstack_visual_embeds = deepstack_image_embeds
        elif video_mask is not None:
            video_mask = video_mask[..., 0]
            visual_pos_masks = video_mask
            deepstack_visual_embeds = deepstack_video_embeds

        if position_ids is None:
            # A dict-valued attention mask is reduced to its "full_attention" entry.
            attention_mask_tensor = (
                attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
            )
            if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
                # Take the diagonal of the first slice along dim 1 to get one value per position.
                attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
                # Only apply conversion for floating point tensors (inverted masks)
                if attention_mask_tensor.dtype.is_floating_point:
                    attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
                    attention_mask_tensor = (1.0 - attention_mask_tensor).int()

            # Calculate RoPE index once per generation in the pre-fill stage only.
            # When compiling, we can't check tensor values thus we check only input length
            # It is safe to assume that `length!=1` means we're in pre-fill because compiled
            # models currently cannot do assisted decoding
            prefill_compiled_stage = is_torchdynamo_compiling() and (
                (input_ids is not None and input_ids.shape[1] != 1)
                or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
            )
            prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
                (cache_position is not None and cache_position[0] == 0)
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            )
            if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
                # When `seqlens` is supplied through kwargs it is passed as `sequence_length`
                # and the attention mask is withheld from get_rope_index.
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    sequence_length=kwargs.get('seqlens', None),
                    attention_mask=attention_mask_tensor if kwargs.get('seqlens', None) is None else None,
                )
                self.rope_deltas = rope_deltas
            # then use the prev pre-calculated rope-deltas to get the correct position ids
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (
                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                    if cache_position is not None
                    else 0
                )
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `deltas` is an int `0`
                    # Repeat the cached deltas so that there is one entry per batch row.
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                # Use the same ids on all three position rows.
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        # The text model always receives embeddings; `input_ids` is deliberately not forwarded.
        outputs = self.language_model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            visual_pos_masks=visual_pos_masks,
            deepstack_visual_embeds=deepstack_visual_embeds,
            **kwargs,
        )

        return Qwen3VLModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            rope_deltas=self.rope_deltas,
        )





class Qwen3VLForConditionalGeneration(Qwen3VLPreTrainedModel, GenerationMixin):

    _checkpoint_conversion_mapping = {}

    _tied_weights_keys = ["lm_head.weight"]

    # Reference: fix gemma3 grad acc #37208

    accepts_loss_kwargs = False

    config: Qwen3VLConfig



    def __init__(self, config):
        """Create the multimodal backbone and the LM head on top of it."""
        super().__init__(config)
        self.model = Qwen3VLModel(config)
        # The head maps text hidden states to vocabulary size and has no bias term.
        text_cfg = config.text_config
        self.lm_head = Qwen3VLLMHead(text_cfg.hidden_size, text_cfg.vocab_size, bias=False)

        self.post_init()



    def get_input_embeddings(self):
        """Return the input embeddings reported by the wrapped `self.model`."""
        backbone = self.model
        return backbone.get_input_embeddings()



    def set_input_embeddings(self, value):
        """Forward `value` to `self.model.set_input_embeddings`."""
        backbone = self.model
        backbone.set_input_embeddings(value)



    def set_decoder(self, decoder):
        """Forward `decoder` to `self.model.set_decoder`."""
        backbone = self.model
        backbone.set_decoder(decoder)



    def get_decoder(self):
        """Return the decoder reported by `self.model.get_decoder`."""
        backbone = self.model
        return backbone.get_decoder()



    def get_video_features(
        self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
    ):
        """Delegate video encoding to `self.model.get_video_features` and return its result."""
        backbone = self.model
        return backbone.get_video_features(pixel_values_videos, video_grid_thw)



    def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
        """Delegate image encoding to `self.model.get_image_features` and return its result."""
        backbone = self.model
        return backbone.get_image_features(pixel_values, image_grid_thw)



    # Make modules available through conditional class for BC

    @property
    def language_model(self):
        """The `language_model` attribute of the wrapped `self.model`."""
        backbone = self.model
        return backbone.language_model



    @property
    def visual(self):
        """The `visual` attribute of the wrapped `self.model`."""
        backbone = self.model
        return backbone.visual



    @check_model_inputs
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        loss_ctx: Optional[Callable] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, Qwen3VLCausalLMOutputWithPast]:
        r"""
        Run the multimodal backbone, then the LM head, and optionally compute a loss.

        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
            The temporal, height and width of feature shape of each image in LLM.
        video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
            The temporal, height and width of feature shape of each video in LLM.
        logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0):
            An `int` keeps the last `logits_to_keep` sequence positions (0 keeps all of them); any other
            value is used directly as the index along the sequence dimension.
        loss_ctx (`Callable`, *optional*):
            When truthy it is handed to the LM head as `loss_ctx` and the head's returned loss is used
            as is; `labels` is not consulted in that case.
        """
        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]

        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        head_inputs = hidden_states[:, slice_indices, :]

        # The head returns a `(logits, loss)` pair in both call forms.
        if loss_ctx:
            logits, loss = self.lm_head(head_inputs, loss_ctx=loss_ctx)
        else:
            logits, loss = self.lm_head(head_inputs)
            # Without a loss context, labels (if any) go through the model's loss function.
            if labels is not None:
                loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)

        return Qwen3VLCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            rope_deltas=outputs.rope_deltas,
        )



    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
        **kwargs,
    ):
        """Prepare the model inputs for one generation step.

        The parent implementation builds the input dict; this override then clears
        `position_ids` and, on steps whose first cache position is not 0, drops the
        pixel inputs.

        When `cache_position` is `None` (its default) the step is treated like the
        first one and the pixel inputs are kept, instead of failing on `None[0]`.
        """
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            use_cache=use_cache,
            **kwargs,
        )

        # Qwen3VL position_ids are prepared with rope_deltas in forward
        model_inputs["position_ids"] = None

        # Guard against the default `cache_position=None` before subscripting it.
        if cache_position is not None and cache_position[0] != 0:
            model_inputs["pixel_values"] = None
            model_inputs["pixel_values_videos"] = None

        return model_inputs



    def _get_image_nums_and_video_nums(
        self,
        input_ids: Optional[torch.LongTensor],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Count, per row, how many images and how many videos appear in the input.

        An image (video) is counted wherever an image (video) placeholder token directly follows a
        vision-start token. These counts are derived here rather than passed through the processor to
        avoid unpredictable impacts from interface modifications.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Used only when `inputs_embeds` is None.
            inputs_embeds (`torch.Tensor`, *optional*):
                When given, tokens are identified by comparing the first hidden component against the
                embedding of the corresponding special token.

        Returns:
            image_nums, video_nums: one count per row (the result of summing over dimension 1).
        """
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id

        if inputs_embeds is None:
            vision_start_mask = input_ids == vision_start_token_id
            image_mask = input_ids == image_token_id
            video_mask = input_ids == video_token_id
        else:

            def _first_component_match(token_id):
                # Embed the single token id and compare; keep only the first hidden component.
                token = torch.tensor(token_id, dtype=torch.long, device=inputs_embeds.device)
                return (inputs_embeds == self.get_input_embeddings()(token))[..., 0]

            vision_start_mask = _first_component_match(vision_start_token_id)
            image_mask = _first_component_match(image_token_id)
            video_mask = _first_component_match(video_token_id)

        # Shift the vision-start mask one step along dim 1 so it lines up with the following token.
        vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
        image_nums = (vision_first_mask & image_mask).sum(dim=1)
        video_nums = (vision_first_mask & video_mask).sum(dim=1)

        return image_nums, video_nums



    def _expand_inputs_for_generation(
        self,
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
        input_ids: Optional[torch.LongTensor] = None,
        **model_kwargs,
    ) -> tuple[torch.LongTensor, dict[str, Any]]:
        """Repeat every generation input `expand_size` times.

        Tensors that carry a batch dimension are repeated with `repeat_interleave` along dim 0.
        The visual tensors have no batch dimension, so they are first cut into per-sample pieces
        and each piece is tiled `expand_size` times:

        * pixel_values.shape[0] is sum(seqlen_images for samples)
        * image_grid_thw.shape[0] is sum(num_images for samples)

        Returns the (possibly expanded) `input_ids` and the updated `model_kwargs`.
        """
        # Overwritten -- Support for expanding tensors without a batch size dimension
        if expand_size == 1:
            return input_ids, model_kwargs

        visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]

        def _tile_per_sample(tensor, lengths, repeat_times):
            # Split into per-sample pieces, repeat each piece along dim 0, and glue back together.
            pieces = torch.split(tensor, lengths)
            repeat_args = [repeat_times] + [1] * (tensor.dim() - 1)
            return torch.cat([piece.repeat(*repeat_args) for piece in pieces], dim=0)

        def _expand_visual_entries(dict_to_expand):
            # Grids and counts are captured before any entry is rewritten, so the
            # pixel tensors are always split according to the unexpanded grids.
            image_grid_thw = model_kwargs.get("image_grid_thw", None)
            video_grid_thw = model_kwargs.get("video_grid_thw", None)
            image_nums, video_nums = self._get_image_nums_and_video_nums(
                input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
            )

            for key in dict_to_expand:
                if key not in visual_keys:
                    continue
                if key in ("pixel_values", "pixel_values_videos"):
                    is_image = key == "pixel_values"
                    grid = image_grid_thw if is_image else video_grid_thw
                    counts = image_nums if is_image else video_nums
                    # split the grids into samples, then compute each sample's total sequence length
                    per_sample_grids = torch.split(grid, list(counts))
                    lengths = [torch.prod(sample, dim=1).sum() for sample in per_sample_grids]
                elif key == "image_grid_thw":
                    # one grid row per image
                    lengths = list(image_nums)
                else:
                    # "video_grid_thw" and "second_per_grid_ts": one entry per video
                    lengths = list(video_nums)
                dict_to_expand[key] = _tile_per_sample(
                    dict_to_expand[key], lengths=lengths, repeat_times=expand_size
                )
            return dict_to_expand

        def _expand_batched_entries(dict_to_expand):
            for key in dict_to_expand:
                if key == "cache_position" or key in visual_keys:
                    continue
                value = dict_to_expand[key]
                # `None` and non-tensor values are left untouched.
                if isinstance(value, torch.Tensor):
                    dict_to_expand[key] = value.repeat_interleave(expand_size, dim=0)
            return dict_to_expand

        # Visual entries go first because they need the unexpanded `input_ids`.
        model_kwargs = _expand_visual_entries(model_kwargs)

        if input_ids is not None:
            input_ids = input_ids.repeat_interleave(expand_size, dim=0)

        model_kwargs = _expand_batched_entries(model_kwargs)

        if is_encoder_decoder:
            if model_kwargs.get("encoder_outputs") is None:
                raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
            model_kwargs["encoder_outputs"] = _expand_batched_entries(model_kwargs["encoder_outputs"])

        return input_ids, model_kwargs