from contextlib import nullcontext
from typing import Optional, Union
import torch
from torch import Tensor
from megatron.core.inference.contexts import BaseInferenceContext
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.utils import WrappedTensor, deprecate_inference_params, make_viewless_tensor
from mindspeed.te.pytorch.fp8.checkpoint import checkpoint as te_checkpoint
def transformer_block_checkpointed_forward(
    self,
    hidden_states: Tensor,
    attention_mask: Tensor,
    context: Tensor,
    context_mask: Tensor,
    rotary_pos_emb: Tensor,
    attention_bias: Tensor,
    packed_seq_params: PackedSeqParams,
    use_inner_fp8_context,
):
    """Run the block's layers with activation checkpointing.

    Layers are fetched through ``self._get_layer(index)`` and executed in
    order. Which layers are wrapped in a checkpoint call is decided by
    ``self.config.recompute_method``:

    * ``'uniform'``: the layers are split into consecutive chunks of
      ``self.config.recompute_num_layers`` layers and every chunk is
      checkpointed. The last chunk is shortened when the layer count is not a
      multiple of the chunk size.
    * ``'block'``: layers are run one at a time; a window of
      ``self.config.recompute_num_layers`` layers is checkpointed and the
      remaining layers are called directly.

    Args:
        hidden_states: Input to the first layer.
        attention_mask: Passed unchanged to every layer.
        context: Passed to the first layer; each layer returns the value
            handed to the next one.
        context_mask: Passed unchanged to every layer.
        rotary_pos_emb: Passed unchanged to every layer.
        attention_bias: Passed unchanged to every layer (captured by the
            closure, not routed through the checkpoint call).
        packed_seq_params: Passed unchanged to every layer (captured by the
            closure, not routed through the checkpoint call).
        use_inner_fp8_context: When truthy, each layer call is wrapped in
            ``get_fp8_context(self.config, layer.layer_number - 1)``;
            otherwise a ``nullcontext`` is used.

    Returns:
        The hidden states produced by the last layer. The final ``context``
        value is discarded.

    Raises:
        ValueError: If ``self.config.recompute_method`` is neither
            ``'uniform'`` nor ``'block'``, or if the method is ``'uniform'``
            and ``self.config.recompute_num_layers`` is ``None`` or smaller
            than 1.
    """
    from megatron.core import parallel_state, tensor_parallel
    from megatron.core.fp8_utils import get_fp8_context

    def custom(start: int, end: int):
        """Return a callable that runs layers ``start`` .. ``end - 1`` in order."""

        def custom_forward(
            hidden_states, attention_mask, context, context_mask, rotary_pos_emb
        ):
            for index in range(start, end):
                layer = self._get_layer(index)
                inner_fp8_context = (
                    get_fp8_context(self.config, layer.layer_number - 1)
                    if use_inner_fp8_context
                    else nullcontext()
                )
                with inner_fp8_context:
                    hidden_states, context = layer(
                        hidden_states=hidden_states,
                        attention_mask=attention_mask,
                        context=context,
                        context_mask=context_mask,
                        rotary_pos_emb=rotary_pos_emb,
                        attention_bias=attention_bias,
                        inference_context=None,
                        packed_seq_params=packed_seq_params,
                    )
            return hidden_states, context

        return custom_forward

    def checkpoint_handler(forward_func):
        """Checkpoint ``forward_func`` via `te_checkpoint` (FP8) or `tensor_parallel.checkpoint`."""
        # The tensor arguments are read from the enclosing scope at call time,
        # so each call sees the hidden_states/context left by the previous one.
        if self.config.fp8:
            return te_checkpoint(
                forward_func,
                self.config.distribute_saved_activations,
                tensor_parallel.random.get_cuda_rng_tracker,
                parallel_state.get_tensor_model_parallel_group(),
                hidden_states,
                attention_mask,
                context,
                context_mask,
                rotary_pos_emb,
            )
        return tensor_parallel.checkpoint(
            forward_func,
            self.config.distribute_saved_activations,
            hidden_states,
            attention_mask,
            context,
            context_mask,
            rotary_pos_emb,
        )

    if self.config.recompute_method == 'uniform':
        num_layers = self.num_layers_per_pipeline_rank
        chunk_size = self.config.recompute_num_layers
        # A chunk size of 0 would never advance through the layers (endless
        # loop); fail loudly instead of hanging.
        if chunk_size is None or chunk_size < 1:
            raise ValueError(
                "recompute_num_layers must be a positive integer when "
                "recompute_method is 'uniform'."
            )
        for chunk_start in range(0, num_layers, chunk_size):
            # Clamp the final chunk so that a layer count that is not a
            # multiple of chunk_size does not request layers past the end.
            chunk_end = min(chunk_start + chunk_size, num_layers)
            hidden_states, context = checkpoint_handler(custom(chunk_start, chunk_end))
    elif self.config.recompute_method == 'block':
        recompute_skip_num_layers = 0
        for layer_idx in range(self.num_layers_per_pipeline_rank):
            # With FP8 enabled, a layer whose input does not require grad is
            # not checkpointed; the counter shifts the checkpointed window by
            # one layer each time that happens.
            if self.config.fp8 and not hidden_states.requires_grad:
                recompute_skip_num_layers += 1
            if (
                recompute_skip_num_layers
                <= layer_idx
                < self.config.recompute_num_layers + recompute_skip_num_layers
            ):
                hidden_states, context = checkpoint_handler(custom(layer_idx, layer_idx + 1))
            else:
                hidden_states, context = custom(layer_idx, layer_idx + 1)(
                    hidden_states, attention_mask, context, context_mask, rotary_pos_emb
                )
    else:
        raise ValueError("Invalid activation recompute method.")

    return hidden_states
def transformer_block_forward(
    self,
    hidden_states: Union[Tensor, WrappedTensor],
    attention_mask: Optional[Tensor],
    context: Optional[Tensor] = None,
    context_mask: Optional[Tensor] = None,
    rotary_pos_emb: Optional[Tensor] = None,
    rotary_pos_cos: Optional[Tensor] = None,
    rotary_pos_sin: Optional[Tensor] = None,
    attention_bias: Optional[Tensor] = None,
    inference_context: Optional[BaseInferenceContext] = None,
    packed_seq_params: Optional[PackedSeqParams] = None,
    sequence_len_offset: Optional[Tensor] = None,
    *,
    inference_params: Optional[BaseInferenceContext] = None,
):
    """Run the transformer block over ``hidden_states``.

    The input is unwrapped if it is a ``WrappedTensor``, replaced by
    ``self.input_tensor`` when ``self.pre_process`` is false, and made
    viewless. The layers then run either through ``self._checkpointed_forward``
    (when ``self.config.recompute_granularity == 'full'`` and the module is in
    training mode) or one by one over ``self.layers``. If
    ``self.final_layernorm`` is set it is applied to the result.

    Args:
        hidden_states: Block input, optionally wrapped in a ``WrappedTensor``.
        attention_mask: Forwarded to every layer.
        context: Forwarded to the first layer; each layer returns the value
            used for the next one.
        context_mask: Forwarded to every layer.
        rotary_pos_emb: Forwarded to every layer.
        rotary_pos_cos: Forwarded to every layer on the non-checkpointed path
            only.
        rotary_pos_sin: Forwarded to every layer on the non-checkpointed path
            only.
        attention_bias: Forwarded to every layer.
        inference_context: Forwarded to every layer on the non-checkpointed
            path. When set and the module is not training, its
            ``current_batch_size`` is set to ``hidden_states.size(1)``.
        packed_seq_params: Forwarded to every layer.
        sequence_len_offset: Forwarded to every layer on the non-checkpointed
            path only.
        inference_params: Legacy keyword, merged into ``inference_context``
            by ``deprecate_inference_params``.

    Returns:
        The hidden states after the last layer and, if present, the final
        layernorm.
    """
    from megatron.core import tensor_parallel
    from megatron.core.enums import Fp8Recipe
    from megatron.core.fp8_utils import get_fp8_context

    inference_context = deprecate_inference_params(inference_context, inference_params)

    if isinstance(hidden_states, WrappedTensor):
        hidden_states = hidden_states.unwrap()
    if not self.pre_process:
        # Use the tensor stored on the module instead of the argument.
        hidden_states = self.input_tensor

    if inference_context and not self.training:
        inference_context.current_batch_size = hidden_states.size(1)

    hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)

    rng_context = (
        tensor_parallel.get_cuda_rng_tracker().fork()
        if self.config.sequence_parallel
        else nullcontext()
    )

    # Delayed recipe: a single FP8 context around the whole block.
    # Any other recipe: a separate FP8 context around each layer.
    use_outer_fp8_context = self.config.fp8 and self.config.fp8_recipe == Fp8Recipe.delayed
    use_inner_fp8_context = self.config.fp8 and self.config.fp8_recipe != Fp8Recipe.delayed
    if use_outer_fp8_context:
        outer_fp8_context = get_fp8_context(self.config)
    else:
        outer_fp8_context = nullcontext()

    with rng_context, outer_fp8_context:
        if self.config.recompute_granularity == 'full' and self.training:
            hidden_states = self._checkpointed_forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                context=context,
                context_mask=context_mask,
                rotary_pos_emb=rotary_pos_emb,
                attention_bias=attention_bias,
                packed_seq_params=packed_seq_params,
                use_inner_fp8_context=use_inner_fp8_context,
            )
        else:
            # Arguments that stay the same for every layer call.
            shared_layer_kwargs = {
                "attention_mask": attention_mask,
                "context_mask": context_mask,
                "rotary_pos_emb": rotary_pos_emb,
                "rotary_pos_cos": rotary_pos_cos,
                "rotary_pos_sin": rotary_pos_sin,
                "attention_bias": attention_bias,
                "inference_context": inference_context,
                "packed_seq_params": packed_seq_params,
                "sequence_len_offset": sequence_len_offset,
            }
            for layer in self.layers:
                if use_inner_fp8_context:
                    inner_fp8_context = get_fp8_context(self.config, layer.layer_number - 1)
                else:
                    inner_fp8_context = nullcontext()

                with self.offload_context, inner_fp8_context:
                    hidden_states, context = layer(
                        hidden_states=hidden_states,
                        context=context,
                        **shared_layer_kwargs,
                    )

                if (
                    torch.is_grad_enabled()
                    and self.config.cpu_offloading
                    and self.group_prefetch_offload_commit_async is not None
                ):
                    hidden_states = self.group_prefetch_offload_commit_async(hidden_states)

    if self.final_layernorm is not None:
        normed = self.final_layernorm(hidden_states)
        hidden_states = make_viewless_tensor(inp=normed, requires_grad=True, keep_graph=True)

    return hidden_states