# coding=utf-8

# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved.



from collections.abc import Callable

from typing import Optional



import torch

import torch_npu

import torch.nn as nn

import torch.nn.functional as F

from torch.distributed.tensor import DTensor



from transformers.activations import ACT2FN

from transformers.cache_utils import Cache

from transformers.modeling_flash_attention_utils import FlashAttentionKwargs

from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update

from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

from transformers.processing_utils import Unpack

from transformers.utils import TransformersKwargs

from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLTextConfig, Qwen3VLVisionConfig



from megatron.core import mpu

from mindspeed_mm.models.common.communications import cal_split_sizes, gather_forward_split_backward, split_forward_gather_backward

from .utils import get_seq_len, gather_seq_scatter_heads_qkv, gather_heads_scatter_seq





class Qwen3VLEmptyModule(nn.Module):
    """Stateless identity module that only serves as an FSDP2 hook anchor.

    When the ``align_fsdp_param_groups`` feature is enabled, normalization
    parameters (e.g. LayerNorm, RMSNorm) and gate-related parameters need a
    module to which FSDP2 hooks can be attached, so that they are picked up by
    FSDP2's parameter grouping and communication logic. This class provides
    that registration point and nothing more: it owns no parameters or
    buffers and does not take part in the forward/backward computation.
    """

    def __init__(self):
        # Nothing to register; the module exists purely as a structural placeholder.
        super().__init__()

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Hand back ``hidden_state`` itself (same object, no copy)."""
        passthrough = hidden_state
        return passthrough

    



class Qwen3VLVisionMLP(nn.Module):
    """Two-layer feed-forward block of the vision tower: fc1 -> activation -> fc2."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # Construction order is kept (fc1 before fc2) so parameter registration
        # and random initialization order are unchanged.
        self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
        self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        """Expand to ``intermediate_size``, apply the activation, project back to ``hidden_size``."""
        expanded = self.linear_fc1(hidden_state)
        activated = self.act_fn(expanded)
        return self.linear_fc2(activated)





class Qwen3VLVisionPatchEmbed(nn.Module):
    """Embed flattened spatio-temporal pixel patches with a non-overlapping Conv3d.

    Each input row is reinterpreted as one
    ``(in_channels, temporal_patch_size, patch_size, patch_size)`` patch. A Conv3d whose
    kernel equals its stride maps every patch to a single ``hidden_size`` vector.
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.patch_size = config.patch_size
        self.temporal_patch_size = config.temporal_patch_size
        self.in_channels = config.in_channels
        self.embed_dim = config.hidden_size

        # kernel == stride, so patches are embedded without overlap
        patch_kernel = [self.temporal_patch_size, self.patch_size, self.patch_size]
        self.proj = nn.Conv3d(
            self.in_channels,
            self.embed_dim,
            kernel_size=patch_kernel,
            stride=patch_kernel,
            bias=True,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return one ``embed_dim`` vector per patch, shaped ``(num_patches, embed_dim)``."""
        patches = hidden_states.view(
            -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
        )
        # Match the conv weight dtype (e.g. fp32 pixels fed into a bf16 model).
        patches = patches.to(dtype=self.proj.weight.dtype)
        return self.proj(patches).view(-1, self.embed_dim)





class Qwen3VLVisionRotaryEmbedding(nn.Module):
    """Standard 1-D rotary frequency table for the vision tower.

    ``inv_freq[i] = 1 / theta ** (2i / dim)`` is stored as a non-persistent buffer,
    so it follows the module's device but is not written to the state dict.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        exponents = torch.arange(0, dim, 2, dtype=torch.float) / dim
        self.register_buffer("inv_freq", 1.0 / (theta ** exponents), persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return the ``(seqlen, dim // 2)`` table of ``position * inv_freq`` angles."""
        table = self.inv_freq
        positions = torch.arange(seqlen, device=table.device, dtype=table.dtype)
        return torch.outer(positions, table)





class Qwen3VLVisionPatchMerger(nn.Module):
    """Merge groups of ``spatial_merge_size**2`` vision tokens into one token and project it.

    The merge is done with ``view(-1, hidden_size)``, where
    ``hidden_size = config.hidden_size * spatial_merge_size**2``. For a contiguous 2-D
    input this concatenates every ``spatial_merge_size**2`` consecutive rows. The merged
    features go through ``linear_fc1 -> GELU -> linear_fc2``, ending in
    ``config.out_hidden_size`` features.

    ``use_postshuffle_norm`` selects where the LayerNorm runs:
    on the merged features (True, norm over ``hidden_size``) or on the per-token
    features before merging (False, norm over ``config.hidden_size``).
    """

    def __init__(self, config: Qwen3VLVisionConfig, use_postshuffle_norm=False) -> None:
        super().__init__()
        # Width of one merged token: spatial_merge_size**2 original tokens concatenated.
        self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
        self.spatial_merge_size = config.spatial_merge_size
        self.use_postshuffle_norm = use_postshuffle_norm
        self.norm = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, eps=1e-6)
        self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size)
        self.act_fn = nn.GELU()
        self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize, merge and project vision tokens.

        Args:
            x: Vision tokens whose last dimension is ``config.hidden_size``. Under context
                parallelism (CP) this is only the local shard along dim 0.

        Returns:
            Merged tokens of width ``config.out_hidden_size``. Under CP, the output is the
            local shard produced by re-splitting the merged sequence with ``cal_split_sizes``.
        """
        if mpu.get_context_parallel_world_size() > 1:
            # Pre-merge length of the full visual sequence. This assumes
            # get_seq_len("visual") reports the merged (post-merge) token count — TODO confirm.
            actual_seq_len = get_seq_len("visual") * self.spatial_merge_size ** 2
            gather_sizes = cal_split_sizes(actual_seq_len, mpu.get_context_parallel_world_size())
            # Rebuild the full sequence on every rank before merging. Presumably this
            # keeps each spatial_merge_size**2 group intact regardless of where
            # the CP split fell — TODO confirm.
            x = gather_forward_split_backward(
                x,
                mpu.get_context_parallel_group(),
                dim=0,
                grad_scale="up",
                gather_sizes=gather_sizes
            )
            x = x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x
            x = self.norm(x).view(-1, self.hidden_size)
            # Re-shard the merged tokens across CP ranks so the MLP runs on the local part only.
            split_sizes = cal_split_sizes(x.shape[0], mpu.get_context_parallel_world_size())
            x = split_forward_gather_backward(
                    x, 
                    mpu.get_context_parallel_group(),
                    dim=0,
                    grad_scale="down",
                    split_sizes=split_sizes
                )
            x = self.linear_fc2(self.act_fn(self.linear_fc1(x)))
        else:
            # Single-rank path: same norm/merge/MLP without any communication.
            x = self.norm(x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size)
            x = self.linear_fc2(self.act_fn(self.linear_fc1(x)))
        return x





def apply_rotary_pos_emb_vision(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate vision queries/keys with the fused NPU rotary kernel.

    The math runs in float32. Results are cast back to the original dtypes of
    ``q`` and ``k``. A leading batch axis of size 1 is added for the kernel and
    removed afterwards. ``cos``/``sin`` get an extra axis at position -2 so they
    broadcast over that axis of ``q``/``k``. This presumably assumes q/k are
    (seq, heads, head_dim) and cos/sin are (seq, head_dim) — TODO confirm.
    """
    q_dtype, k_dtype = q.dtype, k.dtype

    cos_b = cos.unsqueeze(-2).float().unsqueeze(0)
    sin_b = sin.unsqueeze(-2).float().unsqueeze(0)

    q_rot = torch_npu.npu_rotary_mul(q.float().unsqueeze(0), cos_b, sin_b)
    k_rot = torch_npu.npu_rotary_mul(k.float().unsqueeze(0), cos_b, sin_b)

    return q_rot.squeeze(0).to(q_dtype), k_rot.squeeze(0).to(k_dtype)





def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Duplicate each key/value head ``n_rep`` times along dim 1.

    Equivalent to ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``. The input goes
    from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). With ``n_rep == 1`` the
    input object itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then fold it into heads.
    widened = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return widened.reshape(bsz, kv_heads * n_rep, seq_len, dim)





def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Reference (non-fused) scaled dot-product attention.

    K/V are expanded by ``module.num_key_value_groups`` (GQA). When a mask is
    given, it is sliced to the key length and added to the scaled scores. The
    softmax runs in float32 and its result is cast to ``query.dtype``. Dropout
    is applied only while ``module.training``.

    Returns:
        ``(attn_output, attn_weights)``. ``attn_output`` has dims 1 and 2 swapped
        relative to ``softmax(...) @ value`` and is made contiguous.
    """
    n_groups = module.num_key_value_groups
    keys = repeat_kv(key, n_groups)
    values = repeat_kv(value, n_groups)

    scores = torch.matmul(query, keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        kv_len = keys.shape[-2]
        scores = scores + attention_mask[:, :, :, :kv_len]

    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, values).transpose(1, 2).contiguous()
    return context, probs





class Qwen3VLVisionAttention(nn.Module):
    """Bidirectional self-attention for the vision tower over packed variable-length segments.

    All segments share one sequence axis. ``cu_seqlens`` gives their cumulative
    boundaries, so attention is computed per segment:

    - FlashAttention-2 receives the boundaries directly.
    - Other backends receive one split per segment.

    When Megatron context parallelism (CP) is active, q/k/v are redistributed
    with ``gather_seq_scatter_heads_qkv`` before attention, and the output is
    redistributed back with ``gather_heads_scatter_seq``.
    """

    def __init__(self, config: Qwen3VLVisionConfig) -> None:
        super().__init__()
        self.dim = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = self.dim // self.num_heads
        self.num_key_value_groups = 1  # needed for eager attention
        # Fused q/k/v projection; split into three tensors in forward.
        self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
        self.proj = nn.Linear(self.dim, self.dim)
        self.scaling = self.head_dim**-0.5
        self.config = config
        self.attention_dropout = 0.0
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Attend within each packed segment and project back to ``hidden_size``.

        Args:
            hidden_states: Packed tokens; dim 0 is the (local) sequence length. Presumably
                shaped (seq_len, hidden_size) — TODO confirm.
            cu_seqlens: Cumulative segment boundaries. ``cu_seqlens[-1]`` is used as the
                full (CP-gathered) sequence length, and consecutive differences as the
                per-segment lengths.
            rotary_pos_emb: Accepted for interface compatibility; not used here.
            position_embeddings: ``(cos, sin)`` passed to ``apply_rotary_pos_emb_vision``.
                It is unpacked unconditionally, so it must not be None.
            **kwargs: Forwarded to the selected attention implementation.

        Returns:
            Tensor of shape ``(seq_length, hidden_size)``, where ``seq_length`` is
            ``hidden_states.shape[0]``.
        """
        seq_length = hidden_states.shape[0]
        # qkv output -> (seq, 3, heads, head_dim) -> (3, seq, heads, head_dim) -> three (seq, heads, head_dim)
        query_states, key_states, value_states = (
            self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        )
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)

        # (seq, heads, head_dim) -> (1, heads, seq, head_dim) for the attention backends.
        query_states = query_states.transpose(0, 1).unsqueeze(0)
        key_states = key_states.transpose(0, 1).unsqueeze(0)
        value_states = value_states.transpose(0, 1).unsqueeze(0)

        if mpu.get_context_parallel_world_size() > 1:
            # Under CP, redistribute q/k/v: by the helper's name, gather the sequence
            # (dim 2) up to the full length cu_seqlens[-1] and scatter the heads (dim 1).
            total_visual_seqlen = int(cu_seqlens[-1])
            query_states, key_states, value_states = gather_seq_scatter_heads_qkv(
                query_states,
                key_states,
                value_states,
                seq_dim=2,
                head_dim=1,
                gather_size=total_visual_seqlen
            )

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        if self.config._attn_implementation == "flash_attention_2":
            # Flash Attention 2: Use cu_seqlens for variable length attention
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
            attn_output, _ = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask=None,
                scaling=self.scaling,
                dropout=0.0 if not self.training else self.attention_dropout,
                cu_seq_lens_q=cu_seqlens,
                cu_seq_lens_k=cu_seqlens,
                max_length_q=max_seqlen,
                max_length_k=max_seqlen,
                is_causal=False,
                **kwargs,
            )
        else:
            # Other implementations: Process each chunk separately
            # Split q/k/v along the sequence axis (dim 2) into one piece per segment.
            lengths = cu_seqlens[1:] - cu_seqlens[:-1]
            splits = [
                torch.split(tensor, lengths.tolist(), dim=2)
                for tensor in (query_states, key_states, value_states)
            ]

            attn_outputs = [
                attention_interface(
                    self,
                    q,
                    k,
                    v,
                    attention_mask=None,
                    scaling=self.scaling,
                    dropout=0.0 if not self.training else self.attention_dropout,
                    is_causal=False,
                    **kwargs,
                )[0]
                for q, k, v in zip(*splits)
            ]
            # Backend outputs have the sequence on dim 1 (eager_attention_forward returns
            # (batch, seq, heads, head_dim)), so segments are re-joined along dim 1.
            attn_output = torch.cat(attn_outputs, dim=1)

        if mpu.get_context_parallel_world_size() > 1:
            # Inverse of the redistribution above: gather heads (dim 2), scatter sequence (dim 1).
            attn_output = gather_heads_scatter_seq(
                attn_output,
                seq_dim=1,
                head_dim=2,
                gather_size=self.num_heads
            )

        attn_output = attn_output.reshape(seq_length, -1).contiguous()
        attn_output = self.proj(attn_output)
        return attn_output





class Qwen3VLVisionBlock(nn.Module):
    """Pre-norm vision transformer layer: ``x + attn(norm1(x))`` then ``x + mlp(norm2(x))``."""

    def __init__(self, config, attn_implementation: str = "sdpa") -> None:
        super().__init__()
        # ``attn_implementation`` is accepted for signature compatibility; the attention
        # module reads the implementation from ``config`` instead.
        self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6)
        self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6)
        self.attn = Qwen3VLVisionAttention(config=config)
        self.mlp = Qwen3VLVisionMLP(config=config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-layers, each with its own residual connection."""
        attn_delta = self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_delta

        mlp_delta = self.mlp(self.norm2(hidden_states))
        return hidden_states + mlp_delta

    

    

class Qwen3VLTextRotaryEmbedding(nn.Module):
    """Rotary position embedding for the Qwen3-VL text model with interleaved multimodal RoPE.

    Position ids carry three rows (T, H, W). Frequencies are computed for each
    row and then interleaved along the frequency axis according to
    ``mrope_section`` (see :meth:`apply_interleaved_mrope`) before the cos/sin
    tables are built.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: Qwen3VLTextConfig, device=None):
        super().__init__()
        # ``rope_scaling`` may be missing or None; treat both as "nothing configured".
        # The original code guarded this only for ``rope_type`` and then crashed on
        # ``config.rope_scaling.get("mrope_section", ...)`` in the same situation.
        rope_scaling = getattr(config, "rope_scaling", None) or {}
        self.rope_type = rope_scaling.get("rope_type", "default")
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Kept for ``dynamic_rope_update``, which may swap ``inv_freq`` and later restore it.
        self.original_inv_freq = self.inv_freq

        self.mrope_section = rope_scaling.get("mrope_section", [24, 20, 20])

    def apply_interleaved_mrope(self, freqs, mrope_section):
        """Apply interleaved MRoPE to 3D rotary embeddings.

        Reorganizes the frequency layout from chunked [TTT...HHH...WWW] to
        interleaved [THWTHW...TT], preserving frequency continuity.

        The T row of ``freqs`` is used as the output buffer (``freqs[0]`` is a view),
        so ``freqs`` itself is modified in place.

        Args:
            freqs: (3, bs, seq_len, head_dim // 2) frequencies for the T, H and W rows.
            mrope_section: (3,) number of frequency slots assigned to T, H and W.

        Returns:
            freqs_t: (bs, seq_len, head_dim // 2)
        """
        freqs_t = freqs[0]  # just overwrite the first dimension T
        for dim, offset in enumerate((1, 2), start=1):  # H, W
            # Every third slot starting at ``offset``, up to 3 * section size, comes from row ``dim``.
            length = mrope_section[dim] * 3
            idx = slice(offset, length, 3)
            freqs_t[..., idx] = freqs[dim, ..., idx]
        return freqs_t

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Build the ``(cos, sin)`` tables for ``position_ids``.

        Args:
            x: Reference tensor; only its ``device`` and ``dtype`` are used.
            position_ids: Either (bs, seq_len), which is broadcast to all three rope
                rows, or (3, bs, seq_len).

        Returns:
            ``(cos, sin)``, each (bs, seq_len, 2 * len(inv_freq)) in ``x.dtype`` and
            multiplied by ``attention_scaling``.
        """
        # In contrast to other models, Qwen3VL has different position ids for the grids
        # So we expand the inv_freq to shape (3, ...)
        if position_ids.ndim == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
        inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
        position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            # (3, bs, dim/2, 1) @ (3, bs, 1, positions) -> (3, bs, positions, dim/2) after transpose
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
            freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)





class Qwen3VLTextRMSNorm(nn.Module):
    """RMS normalization (equivalent to T5LayerNorm) computed by the fused ``torch_npu`` kernel."""

    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        super().__init__()
        # Learnable per-channel scale, initialised to 1.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the normalized, scaled tensor (first output of ``npu_rms_norm``)."""
        kernel_outputs = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)
        return kernel_outputs[0]

    def extra_repr(self):
        weight_shape = tuple(self.weight.shape)
        return f"{weight_shape}, eps={self.variance_epsilon}"





def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors via ``torch_npu.npu_rotary_mul``.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis at which a size-1 dimension is inserted into ``cos`` and ``sin`` so they
            broadcast against ``q`` and ``k``. For cos/sin of shape
            [batch_size, seq_len, head_dim]:
            - q/k shaped [batch_size, heads, seq_len, head_dim] need ``unsqueeze_dim=1``;
            - q/k shaped [batch_size, seq_len, heads, head_dim] need ``unsqueeze_dim=2``.

    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    return (
        torch_npu.npu_rotary_mul(q, cos_b, sin_b),
        torch_npu.npu_rotary_mul(k, cos_b, sin_b),
    )





class Qwen3VLTextAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Qwen3-VL variant with:
    - per-head RMSNorm on queries and keys;
    - NPU rotary embeddings;
    - grouped key/value heads;
    - optional Megatron context parallelism (CP). Under CP, q/k/v are redistributed
      before attention and the output is redistributed back afterwards.
    """

    def __init__(self, config: Qwen3VLTextConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.num_heads = config.num_attention_heads
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.q_norm = Qwen3VLTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
        self.k_norm = Qwen3VLTextRMSNorm(
            self.head_dim, eps=config.rms_norm_eps
        )  # thus post q_norm does not need reshape

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute attention for ``hidden_states``.

        Args:
            hidden_states: Tensor whose last dimension is ``hidden_size``. The leading
                dimensions (``input_shape``) are preserved in the output.
            position_embeddings: ``(cos, sin)`` for ``apply_rotary_pos_emb`` (default
                ``unsqueeze_dim=1``).
            attention_mask: Forwarded unchanged to the attention implementation.
            past_key_values: Optional cache; when given, K/V are passed through
                ``past_key_values.update`` for this layer.
            cache_position: Forwarded to the cache update.
            **kwargs: Forwarded to the attention implementation.

        Returns:
            ``(attn_output, attn_weights)``. ``attn_output`` has shape
            ``(*input_shape, hidden_size)``; ``attn_weights`` is whatever the
            backend returns.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Project, split into heads, normalize per head, then move heads to dim 1.
        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings  # presumably (batch, seq, head_dim) — TODO confirm

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        # The seq->heads redistribution before attention and the heads->seq
        # redistribution after it must be paired. Both are now gated on the same
        # condition; previously the second one ran even with CP disabled.
        cp_world_size = mpu.get_context_parallel_world_size()
        use_context_parallel = cp_world_size > 1

        if use_context_parallel:
            total_seq_len = get_seq_len("total")
            if cp_world_size > key_states.shape[1]:
                # Too few KV heads to scatter across CP ranks: expand them to
                # num_attention_heads first.
                # NOTE(review): backends that also apply repeat_kv with
                # self.num_key_value_groups (e.g. eager_attention_forward) would expand
                # these already-expanded heads again when num_key_value_groups > 1.
                # Confirm that the attention backend used with CP handles this.
                key_states = repeat_kv(key_states, self.num_key_value_groups)
                value_states = repeat_kv(value_states, self.num_key_value_groups)

            # By the helper's name: gather the sequence (dim 2) to the full length and
            # scatter the heads (dim 1) across CP ranks.
            query_states, key_states, value_states = gather_seq_scatter_heads_qkv(
                query_states,
                key_states,
                value_states,
                seq_dim=2,
                head_dim=1,
                gather_size=total_seq_len
            )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        if use_context_parallel:
            # Inverse redistribution: gather heads (dim 2), scatter sequence (dim 1).
            attn_output = gather_heads_scatter_seq(
                attn_output,
                seq_dim=1,
                head_dim=2,
                gather_size=self.num_heads
            )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights





class Qwen3VLTextMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # Registration/initialization order kept as gate, up, down.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Return the gated projection of ``x`` back to ``hidden_size``."""
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)

    



class Qwen3VLLMHead(nn.Linear):

    def forward(self, hidden_states: torch.Tensor, loss_ctx: callable = None):

        # Handle distributed tensor (DTensor) weights and biases by converting to local tensors.

        if isinstance(self.weight, DTensor):

            w = self.weight.to_local()

            if self.bias is not None:

                if not isinstance(self.bias, DTensor):

                    raise TypeError(

                        f"Expected bias to be a DTensor when weight is a DTensor, "

                        f"but got bias of type {type(self.bias)}."

                    )

                b = self.bias.to_local()

            else:

                b = None

        else:

            w = self.weight

            b = self.bias

        

        if loss_ctx is None:

            # If no loss context is provided, compute and return logits normally.

            logits = F.linear(hidden_states, w, b)

            return logits, None

        else:

            # Otherwise, delegate loss computation to the provided loss context function,

            # which typically enables memory-efficient or chunked loss calculation.

            return None, loss_ctx(hidden_states, w, b)