# Copyright (c) 2025, HUAWEI CORPORATION.  All rights reserved.

# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.



from contextlib import nullcontext

from typing import Optional, Union

import torch

import torch.distributed

from torch import Tensor



from megatron.core import parallel_state, tensor_parallel

from megatron.core.enums import Fp8Recipe

from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules

from megatron.core.transformer.transformer_config import TransformerConfig

from megatron.core.transformer.transformer_block import TransformerBlock, TransformerBlockSubmodules

from megatron.core.utils import WrappedTensor, deprecate_inference_params, make_viewless_tensor

from megatron.core.transformer.spec_utils import ModuleSpec

from megatron.core.packed_seq_params import PackedSeqParams

from megatron.core.inference.contexts import BaseInferenceContext

from megatron.core.fp8_utils import get_fp8_context

from megatron.core.extensions.transformer_engine import te_checkpoint





class Qwen3vlTransformerBlock(TransformerBlock):
    """Transformer block for Qwen3-VL with DeepStack visual-feature injection.

    Identical to ``TransformerBlock`` except that ``forward`` accepts an optional
    list ``deepstack_visual_embeds``: entry ``i`` is added (in place) to the hidden
    states produced by layer ``i`` of this block, for every ``i`` smaller than the
    list length. The injection is applied on the plain forward path and on both
    full-recompute methods ('uniform' and 'block').
    """

    def __init__(
        self,
        config: TransformerConfig,
        spec: Union[TransformerBlockSubmodules, ModuleSpec],
        post_layer_norm: bool = True,
        pre_process: bool = True,
        post_process: bool = True,
    ):
        # Pure pass-through: layers, norms and pipeline bookkeeping are all built
        # by TransformerBlock.
        super().__init__(
            config=config,
            spec=spec,
            post_layer_norm=post_layer_norm,
            pre_process=pre_process,
            post_process=post_process,
        )

    def _deepstack_process(
        self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, visual_embeds: torch.Tensor
    ) -> torch.Tensor:
        """Add ``visual_embeds`` to the rows of ``hidden_states`` selected by a mask.

        Args:
            hidden_states: Tensor whose leading dimensions are indexed by
                ``visual_pos_masks``; it is modified in place.
            visual_pos_masks: Mask selecting the visual-token positions
                (presumably boolean — TODO confirm against callers).
            visual_embeds: Features for the selected rows; cast to the device and
                dtype of ``hidden_states`` before the add.

        Returns:
            ``hidden_states`` itself, after the in-place update.

        NOTE(review): this helper is not called anywhere in this class; the forward
        paths add each deepstack tensor directly to ``hidden_states`` instead. It is
        kept unchanged in case external code relies on it.
        """
        visual_pos_masks = visual_pos_masks.to(hidden_states.device)
        visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)
        local_this = hidden_states[visual_pos_masks, :].clone() + visual_embeds
        hidden_states[visual_pos_masks, :] = local_this
        return hidden_states

    @staticmethod
    def _add_deepstack_embeds(
        hidden_states: Tensor,
        layer_idx: int,
        deepstack_visual_embeds: Optional[list],
    ) -> Tensor:
        """Add ``deepstack_visual_embeds[layer_idx]`` to ``hidden_states`` in place.

        No-op when ``deepstack_visual_embeds`` is None or ``layer_idx`` is outside
        ``[0, len(deepstack_visual_embeds))``. Returns ``hidden_states``.
        """
        if deepstack_visual_embeds is not None and 0 <= layer_idx < len(deepstack_visual_embeds):
            # In-place add keeps the dtype of hidden_states (the embedding is cast
            # into it) and avoids an extra activation-sized allocation.
            hidden_states += deepstack_visual_embeds[layer_idx]
        return hidden_states

    def forward(self,
        hidden_states: Union[Tensor, WrappedTensor],
        attention_mask: Optional[Tensor],
        context: Optional[Tensor] = None,
        context_mask: Optional[Tensor] = None,
        rotary_pos_emb: Optional[Tensor] = None,
        rotary_pos_cos: Optional[Tensor] = None,
        rotary_pos_sin: Optional[Tensor] = None,
        attention_bias: Optional[Tensor] = None,
        inference_context: Optional[BaseInferenceContext] = None,
        packed_seq_params: Optional[PackedSeqParams] = None,
        sequence_len_offset: Optional[Tensor] = None,
        # args for deepstack
        deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
        *,
        inference_params: Optional[BaseInferenceContext] = None,
    ):
        """Run all layers of this block, injecting DeepStack visual features.

        Args:
            hidden_states: Block input; unwrapped if it is a ``WrappedTensor`` and
                replaced by ``self.input_tensor`` when ``pre_process`` is False.
                Dimension 1 is treated as the batch dimension for inference.
            attention_mask, context, context_mask, rotary_pos_emb, rotary_pos_cos,
            rotary_pos_sin, attention_bias, inference_context, packed_seq_params,
            sequence_len_offset: Forwarded to each layer as in ``TransformerBlock``.
            deepstack_visual_embeds: Optional list; entry ``i`` is added in place to
                the output of layer ``i`` of this block. Each entry must be
                addable to ``hidden_states`` (presumably already scattered to its
                full shape — TODO confirm with the caller).
                NOTE(review): ``i`` is the layer index local to this block/pipeline
                stage, not the global layer number; with pipeline parallelism the
                caller must only pass these to the stage that owns the first layers.
            inference_params: Deprecated alias of ``inference_context``.

        Returns:
            The hidden states after the last layer and, if present, the final
            layer norm (made viewless for the pipeline schedule).
        """
        inference_context = deprecate_inference_params(inference_context, inference_params)

        # Delete the obsolete reference to the initial input tensor if necessary
        if isinstance(hidden_states, WrappedTensor):
            hidden_states = hidden_states.unwrap()

        if not self.pre_process:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        # Update the inference parameters with the current batch size in case it is variable
        if inference_context and not self.training:
            inference_context.current_batch_size = hidden_states.size(1)

        hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)

        if self.config.sequence_parallel:
            rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
        else:
            rng_context = nullcontext()

        # If fp8_recipe is delayed, wrap the entire pass with get_fp8_context(),
        # otherwise do nothing extra at the outer level.
        # For other fp8 recipes the context manager enter/exit is free, so the
        # fp8 context is entered per layer inside the loop; this allows choosing
        # fp8 or bf16 per layer.
        use_outer_fp8_context = self.config.fp8 and self.config.fp8_recipe == Fp8Recipe.delayed
        use_inner_fp8_context = self.config.fp8 and self.config.fp8_recipe != Fp8Recipe.delayed
        outer_fp8_context = get_fp8_context(self.config) if use_outer_fp8_context else nullcontext()

        with rng_context, outer_fp8_context:
            # Forward pass.
            if self.config.recompute_granularity == 'full' and self.training:
                hidden_states = self._checkpointed_forward(
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    context=context,
                    context_mask=context_mask,
                    rotary_pos_emb=rotary_pos_emb,
                    attention_bias=attention_bias,
                    packed_seq_params=packed_seq_params,
                    deepstack_visual_embeds=deepstack_visual_embeds,
                    use_inner_fp8_context=use_inner_fp8_context,
                )
            else:
                for layer_idx, layer in enumerate(self.layers):
                    inner_fp8_context = (
                        get_fp8_context(self.config, layer.layer_number - 1)
                        if use_inner_fp8_context
                        else nullcontext()
                    )
                    with self.offload_context, inner_fp8_context:
                        hidden_states, context = layer(
                            hidden_states=hidden_states,
                            attention_mask=attention_mask,
                            context=context,
                            context_mask=context_mask,
                            rotary_pos_emb=rotary_pos_emb,
                            rotary_pos_cos=rotary_pos_cos,
                            rotary_pos_sin=rotary_pos_sin,
                            attention_bias=attention_bias,
                            inference_context=inference_context,
                            packed_seq_params=packed_seq_params,
                            sequence_len_offset=sequence_len_offset,
                        )

                        # add visual features to the hidden states of first several layers
                        hidden_states = self._add_deepstack_embeds(
                            hidden_states, layer_idx, deepstack_visual_embeds
                        )

                    if (
                        torch.is_grad_enabled()
                        and self.config.cpu_offloading
                        and self.group_prefetch_offload_commit_async is not None
                    ):
                        hidden_states = self.group_prefetch_offload_commit_async(hidden_states)

        # Final layer norm.
        if self.final_layernorm is not None:
            hidden_states = self.final_layernorm(hidden_states)
            # TENorm produces a "viewed" tensor. This will result in schedule.py's
            # deallocate_output_tensor() throwing an error, so a viewless tensor is
            # created to prevent this.
            hidden_states = make_viewless_tensor(
                inp=hidden_states, requires_grad=True, keep_graph=True
            )

        return hidden_states

    # custom checkpoint_forward
    def _checkpointed_forward(
        self,
        hidden_states: Tensor,
        attention_mask: Tensor,
        context: Tensor,
        context_mask: Tensor,
        rotary_pos_emb: Tensor,
        attention_bias: Tensor,
        packed_seq_params: PackedSeqParams,
        deepstack_visual_embeds: list,
        use_inner_fp8_context: bool,
    ):
        """Forward method with activation checkpointing.

        DeepStack features are always added *outside* the checkpointed functions:
        they are applied to the returned activations rather than captured by the
        recomputed closure. Layers that receive a DeepStack feature therefore
        always end a checkpointed chunk.

        Raises:
            ValueError: if ``config.recompute_method`` is neither 'uniform' nor 'block'.
        """

        def custom(start: int, end: int):
            def custom_forward(
                hidden_states, attention_mask, context, context_mask, rotary_pos_emb
            ):
                for index in range(start, end):
                    layer = self._get_layer(index)
                    inner_fp8_context = (
                        get_fp8_context(self.config, layer.layer_number - 1)
                        if use_inner_fp8_context
                        else nullcontext()
                    )
                    with inner_fp8_context:
                        hidden_states, context = layer(
                            hidden_states=hidden_states,
                            attention_mask=attention_mask,
                            context=context,
                            context_mask=context_mask,
                            rotary_pos_emb=rotary_pos_emb,
                            attention_bias=attention_bias,
                            inference_context=None,
                            packed_seq_params=packed_seq_params,
                        )
                return hidden_states, context

            return custom_forward

        def checkpoint_handler(forward_func):
            """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`"""
            # NOTE: hidden_states/context are read from the enclosing scope at call
            # time, so each call checkpoints the latest activations.
            if self.config.fp8:
                return te_checkpoint(
                    forward_func,
                    self.config.distribute_saved_activations,
                    tensor_parallel.random.get_cuda_rng_tracker,
                    parallel_state.get_tensor_model_parallel_group(),
                    hidden_states,
                    attention_mask,
                    context,
                    context_mask,
                    rotary_pos_emb,
                )
            else:
                return tensor_parallel.checkpoint(
                    forward_func,
                    self.config.distribute_saved_activations,
                    hidden_states,
                    attention_mask,
                    context,
                    context_mask,
                    rotary_pos_emb,
                )

        num_layers = self.num_layers_per_pipeline_rank
        num_deepstack = len(deepstack_visual_embeds) if deepstack_visual_embeds is not None else 0

        if self.config.recompute_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and checkpoint
            # the input activation of each divided chunk.
            # A method to further reduce memory usage reducing checkpoints.
            chunk_size = self.config.recompute_num_layers
            layer_idx = 0
            while layer_idx < num_layers:
                if layer_idx < num_deepstack:
                    # A DeepStack feature is added right after this layer, outside
                    # the checkpointed region, so the layer gets a chunk of its own.
                    chunk_end = layer_idx + 1
                else:
                    # Run to the next multiple of chunk_size so later chunks keep the
                    # plain uniform alignment; clamp so a trailing partial chunk never
                    # indexes past the last local layer.
                    chunk_end = min((layer_idx // chunk_size + 1) * chunk_size, num_layers)

                hidden_states, context = checkpoint_handler(custom(layer_idx, chunk_end))

                # Only the last layer of a chunk can carry a DeepStack feature.
                hidden_states = self._add_deepstack_embeds(
                    hidden_states, chunk_end - 1, deepstack_visual_embeds
                )
                layer_idx = chunk_end

        elif self.config.recompute_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # A method fully use the device memory removing redundant re-computation.
            recompute_skip_num_layers = 0
            for layer_idx in range(num_layers):
                # Skip recomputation when input grad computation is not needed.
                # Need to have at least one input tensor with gradient computation
                # for re-enterant autograd engine.
                if self.config.fp8 and not hidden_states.requires_grad:
                    recompute_skip_num_layers += 1
                if (
                    layer_idx >= recompute_skip_num_layers
                    and layer_idx < self.config.recompute_num_layers + recompute_skip_num_layers
                ):
                    hidden_states, context = checkpoint_handler(custom(layer_idx, layer_idx + 1))
                else:
                    hidden_states, context = custom(layer_idx, layer_idx + 1)(
                        hidden_states, attention_mask, context, context_mask, rotary_pos_emb
                    )

                # process deepstack visual embeds feature for qwen3vl
                hidden_states = self._add_deepstack_embeds(
                    hidden_states, layer_idx, deepstack_visual_embeds
                )

        else:
            raise ValueError("Invalid activation recompute method.")

        return hidden_states