# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from contextlib import nullcontext
from typing import Optional, Union

import torch
from torch import Tensor

from megatron.core.inference.contexts import BaseInferenceContext
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.utils import WrappedTensor, deprecate_inference_params, make_viewless_tensor
from mindspeed.te.pytorch.fp8.checkpoint import checkpoint as te_checkpoint


def transformer_block_checkpointed_forward(
    self,
    hidden_states: Tensor,
    attention_mask: Tensor,
    context: Tensor,
    context_mask: Tensor,
    rotary_pos_emb: Tensor,
    attention_bias: Tensor,
    packed_seq_params: PackedSeqParams,
    use_inner_fp8_context,
):
    """Run the block's layers with activation checkpointing.

    The layers are executed in order; which of them are wrapped in a checkpoint
    call is decided by ``self.config.recompute_method``:

    * ``'uniform'``: the layers are processed in consecutive chunks of
      ``self.config.recompute_num_layers`` layers and every chunk is checkpointed.
    * ``'block'``: the layers are processed one at a time; only a window of
      ``self.config.recompute_num_layers`` layers is checkpointed and the
      remaining layers are called directly.

    Args:
        hidden_states: Input activations; replaced by each chunk's output.
        attention_mask: Forwarded unchanged to every layer.
        context: Forwarded to every layer and replaced by each layer's second
            return value.
        context_mask: Forwarded unchanged to every layer.
        rotary_pos_emb: Forwarded unchanged to every layer.
        attention_bias: Forwarded to every layer through the closure (it is not
            part of the positional inputs handed to the checkpoint function).
        packed_seq_params: Forwarded to every layer through the closure, like
            ``attention_bias``.
        use_inner_fp8_context: When truthy, each layer call is wrapped in
            ``get_fp8_context(self.config, layer.layer_number - 1)``; otherwise a
            ``nullcontext()`` is used.

    Returns:
        The hidden states produced by the last layer. The updated ``context``
        is not returned.

    Raises:
        ValueError: If ``self.config.recompute_method`` is neither ``'uniform'``
            nor ``'block'``.
    """
    from megatron.core import parallel_state, tensor_parallel
    from megatron.core.fp8_utils import get_fp8_context

    def make_chunk_runner(first: int, stop: int):
        """Build a callable that runs layers ``first`` .. ``stop - 1`` in order."""

        def run_chunk(chunk_hidden, chunk_mask, chunk_context, chunk_context_mask, chunk_rope):
            for layer_index in range(first, stop):
                block_layer = self._get_layer(layer_index)
                if use_inner_fp8_context:
                    layer_fp8_context = get_fp8_context(
                        self.config, block_layer.layer_number - 1
                    )
                else:
                    layer_fp8_context = nullcontext()
                with layer_fp8_context:
                    chunk_hidden, chunk_context = block_layer(
                        hidden_states=chunk_hidden,
                        attention_mask=chunk_mask,
                        context=chunk_context,
                        context_mask=chunk_context_mask,
                        rotary_pos_emb=chunk_rope,
                        # The next three come from the enclosing scope, not
                        # from the checkpointed positional inputs.
                        attention_bias=attention_bias,
                        inference_context=None,
                        packed_seq_params=packed_seq_params,
                    )
            return chunk_hidden, chunk_context

        return run_chunk

    def run_checkpointed(chunk_runner):
        """Run ``chunk_runner`` via `te_checkpoint` (fp8) or `tensor_parallel.checkpoint`."""
        # Looked up at call time, so every chunk sees the outputs of the
        # previous one.
        chunk_inputs = (hidden_states, attention_mask, context, context_mask, rotary_pos_emb)
        if not self.config.fp8:
            return tensor_parallel.checkpoint(
                chunk_runner,
                self.config.distribute_saved_activations,
                *chunk_inputs,
            )
        return te_checkpoint(
            chunk_runner,
            self.config.distribute_saved_activations,
            tensor_parallel.random.get_cuda_rng_tracker,
            parallel_state.get_tensor_model_parallel_group(),
            *chunk_inputs,
        )

    recompute_method = self.config.recompute_method
    if recompute_method not in ('uniform', 'block'):
        raise ValueError("Invalid activation recompute method.")

    if recompute_method == 'uniform':
        # Split the layers of this pipeline rank into equally sized chunks and
        # checkpoint the input activation of each chunk. Fewer checkpoints
        # means less memory held for the backward pass.
        chunk_start = 0
        while chunk_start < self.num_layers_per_pipeline_rank:
            chunk_stop = chunk_start + self.config.recompute_num_layers
            hidden_states, context = run_checkpointed(
                make_chunk_runner(chunk_start, chunk_stop)
            )
            chunk_start = chunk_stop
    else:
        # 'block': checkpoint the input activation of a limited number of
        # individual layers and run the others normally, so that spare device
        # memory is used instead of redundant recomputation.
        skipped_layers = 0
        for layer_idx in range(self.num_layers_per_pipeline_rank):
            # With fp8, a layer whose input does not require grad is not
            # checkpointed: the re-entrant autograd engine needs at least one
            # input tensor with gradient computation. Each such layer shifts
            # the checkpointed window by one.
            if self.config.fp8 and not hidden_states.requires_grad:
                skipped_layers += 1

            single_layer = make_chunk_runner(layer_idx, layer_idx + 1)
            if skipped_layers <= layer_idx < self.config.recompute_num_layers + skipped_layers:
                hidden_states, context = run_checkpointed(single_layer)
            else:
                hidden_states, context = single_layer(
                    hidden_states, attention_mask, context, context_mask, rotary_pos_emb
                )

    return hidden_states


def transformer_block_forward(
    self,
    hidden_states: Union[Tensor, WrappedTensor],
    attention_mask: Optional[Tensor],
    context: Optional[Tensor] = None,
    context_mask: Optional[Tensor] = None,
    rotary_pos_emb: Optional[Tensor] = None,
    rotary_pos_cos: Optional[Tensor] = None,
    rotary_pos_sin: Optional[Tensor] = None,
    attention_bias: Optional[Tensor] = None,
    inference_context: Optional[BaseInferenceContext] = None,
    packed_seq_params: Optional[PackedSeqParams] = None,
    sequence_len_offset: Optional[Tensor] = None,
    *,
    inference_params: Optional[BaseInferenceContext] = None,
):
    """Run every layer of the block, then the optional final layernorm.

    When ``self.config.recompute_granularity == 'full'`` and the module is in
    training mode, the layers are executed through ``self._checkpointed_forward``;
    otherwise they are called one after another. Note that the checkpointed
    path is not given ``rotary_pos_cos``, ``rotary_pos_sin``,
    ``inference_context`` or ``sequence_len_offset``.

    Args:
        hidden_states: Input activations, optionally wrapped in a
            ``WrappedTensor``. Replaced by ``self.input_tensor`` when
            ``self.pre_process`` is falsy.
        attention_mask: Forwarded to every layer.
        context: Forwarded to every layer; on the non-checkpointed path it is
            replaced by each layer's second return value.
        context_mask: Forwarded to every layer.
        rotary_pos_emb: Forwarded to every layer.
        rotary_pos_cos: Forwarded to every layer on the non-checkpointed path.
        rotary_pos_sin: Forwarded to every layer on the non-checkpointed path.
        attention_bias: Forwarded to every layer.
        inference_context: Forwarded to every layer on the non-checkpointed
            path. If truthy while not training, its ``current_batch_size`` is
            set to ``hidden_states.size(1)``.
        packed_seq_params: Forwarded to every layer.
        sequence_len_offset: Forwarded to every layer on the non-checkpointed
            path.
        inference_params: Legacy keyword; combined with ``inference_context``
            by ``deprecate_inference_params``.

    Returns:
        The hidden states after the last layer, passed through
        ``self.final_layernorm`` when that attribute is not ``None``.
    """
    from megatron.core import tensor_parallel
    from megatron.core.enums import Fp8Recipe
    from megatron.core.fp8_utils import get_fp8_context

    inference_context = deprecate_inference_params(inference_context, inference_params)

    # Unwrap so that no stale reference to the initial input tensor is kept.
    if isinstance(hidden_states, WrappedTensor):
        hidden_states = hidden_states.unwrap()

    # Without pre-processing the input comes from self.input_tensor
    # (see set_input_tensor()).
    if not self.pre_process:
        hidden_states = self.input_tensor

    # The batch size may vary between calls, so refresh it on the inference
    # context for every forward pass.
    if inference_context and not self.training:
        inference_context.current_batch_size = hidden_states.size(1)

    # Make the input viewless. Per the upstream note this is strictly needed
    # only when the incoming tensor is still a view (micro batch size 1, where
    # '.contiguous()' is a pass-through); it is applied unconditionally because
    # make_viewless_tensor() is described as having negligible overhead on an
    # already viewless tensor. For tensors taken from self.input_tensor the
    # call is likely redundant (p2p_communication.py is presumed to create
    # viewless tensors already), but it is kept as a safeguard.
    hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)

    rng_context = (
        tensor_parallel.get_cuda_rng_tracker().fork()
        if self.config.sequence_parallel
        else nullcontext()
    )

    # FP8 context placement:
    # - delayed recipe: one context wraps the whole pass (outer);
    # - any other recipe: entering/leaving the context is cheap, so it is
    #   applied per layer (inner), which allows choosing fp8 or bf16 for each
    #   layer individually.
    fp8_enabled = self.config.fp8
    delayed_recipe = fp8_enabled and self.config.fp8_recipe == Fp8Recipe.delayed
    use_outer_fp8_context = delayed_recipe
    use_inner_fp8_context = fp8_enabled and not delayed_recipe
    if use_outer_fp8_context:
        outer_fp8_context = get_fp8_context(self.config)
    else:
        outer_fp8_context = nullcontext()

    with rng_context, outer_fp8_context:
        # Forward pass.
        if self.config.recompute_granularity == 'full' and self.training:
            hidden_states = self._checkpointed_forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                context=context,
                context_mask=context_mask,
                rotary_pos_emb=rotary_pos_emb,
                attention_bias=attention_bias,
                packed_seq_params=packed_seq_params,
                # Ported from the "130" change: pass the flag through so the
                # per-layer (inner) low-precision context is not lost on the
                # recompute path.
                use_inner_fp8_context=use_inner_fp8_context,
            )
        else:
            for layer in self.layers:
                if use_inner_fp8_context:
                    layer_fp8_context = get_fp8_context(self.config, layer.layer_number - 1)
                else:
                    layer_fp8_context = nullcontext()

                with self.offload_context, layer_fp8_context:
                    hidden_states, context = layer(
                        hidden_states=hidden_states,
                        attention_mask=attention_mask,
                        context=context,
                        context_mask=context_mask,
                        rotary_pos_emb=rotary_pos_emb,
                        rotary_pos_cos=rotary_pos_cos,
                        rotary_pos_sin=rotary_pos_sin,
                        attention_bias=attention_bias,
                        inference_context=inference_context,
                        packed_seq_params=packed_seq_params,
                        sequence_len_offset=sequence_len_offset,
                    )

                # Hand the layer output to the CPU-offload hook when gradients
                # are enabled, offloading is configured and a hook is set.
                commit_offload = (
                    torch.is_grad_enabled()
                    and self.config.cpu_offloading
                    and self.group_prefetch_offload_commit_async is not None
                )
                if commit_offload:
                    hidden_states = self.group_prefetch_offload_commit_async(hidden_states)

    if self.final_layernorm is None:
        return hidden_states

    # Final layer norm.
    normed_states = self.final_layernorm(hidden_states)
    # TENorm is noted to produce a "viewed" tensor, which makes schedule.py's
    # deallocate_output_tensor() throw an error; returning a viewless tensor
    # prevents that.
    return make_viewless_tensor(inp=normed_states, requires_grad=True, keep_graph=True)