from dataclasses import dataclass, field
from typing import Any, Dict, Union, Optional, Tuple
import torch
from torch import Tensor
from megatron.core.transformer.identity_op import IdentityFuncOp, IdentityOp
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.utils import WrappedTensor, deprecate_inference_params
from megatron.core.transformer.transformer_layer import TransformerLayerSubmodules
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.utils import make_viewless_tensor
from megatron.training import get_args
from mindspeed_llm.core import TransformerLayer
from megatron.core import parallel_state, tensor_parallel
class LongCatFlashTransformerLayer(TransformerLayer):
"""
Inherited from megatron TransformerLayer, we add post norm and post mlp layer norm.
"""
def __init__(
self,
config: TransformerConfig,
submodules: TransformerLayerSubmodules,
layer_number: int = 1,
hidden_dropout: float = None,
):
super().__init__(config=config,
submodules=submodules,
layer_number=layer_number,
hidden_dropout=hidden_dropout)
self.input_layernorm_0 = build_module(
submodules.input_layernorm_0,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
attention_optional_kwargs = {}
if config.context_parallel_size > 1 and config.cp_comm_type is not None:
if isinstance(config.cp_comm_type, list):
attention_optional_kwargs["cp_comm_type"] = config.cp_comm_type[self.layer_number]
else:
attention_optional_kwargs["cp_comm_type"] = config.cp_comm_type
self.self_attention_0 = build_module(
submodules.self_attention_0,
config=self.config,
layer_number=layer_number,
**attention_optional_kwargs,
)
self.self_attn_bda_0 = build_module(submodules.self_attn_bda_0)
self.pre_mlp_layernorm_0 = build_module(
submodules.pre_mlp_layernorm_0,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
self.mlp = build_module(
submodules.mlp,
config = self.config
)
self.mlps_0 = build_module(
submodules.mlps_0,
config = self.config
)
self.mlps_bda_0 = build_module(submodules.mlps_bda_0)
self.input_layernorm_1 = build_module(
submodules.input_layernorm_1,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
self.self_attention_1 = build_module(
submodules.self_attention_1,
config=self.config,
layer_number=layer_number,
**attention_optional_kwargs,
)
self.self_attn_bda_1 = build_module(submodules.self_attn_bda_1)
self.pre_mlp_layernorm_1 = build_module(
submodules.pre_mlp_layernorm_1,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
self.mlps_1 = build_module(
submodules.mlps_1,
config = self.config
)
self.mlps_bda_1 = build_module(submodules.mlps_bda_1)
def forward(
self,
hidden_states: Tensor,
attention_mask: Optional[Tensor] = None,
context: Optional[Tensor] = None,
context_mask: Optional[Tensor] = None,
rotary_pos_emb: Optional[Tensor] = None,
rotary_pos_cos: Optional[Tensor] = None,
rotary_pos_sin: Optional[Tensor] = None,
attention_bias: Optional[Tensor] = None,
inference_context: Optional[Any] = None,
packed_seq_params: Optional[PackedSeqParams] = None,
sequence_len_offset: Optional[Tensor] = None,
*,
inference_params: Optional[Any] = None,
):
"""
Perform a forward pass through the attention layer and the layernorms before and after
the attention operations.
Args:
hidden_states (Tensor): Input tensor of shape [s, b, h] where s is sequence length,
b is batch size, and h is hidden size.
attention_mask (Tensor): Mask tensor for self-attention.
context (Tensor, optional): Context tensor for cross-attention.
context_mask (Tensor, optional): Mask tensor for cross-attention.
rotary_pos_emb (Tensor, optional): Rotary positional embeddings.
attention_bias (Tensor, optional): Bias tensor for Q * K.T.
inference_context (object, optional): Parameters for inference-time optimizations.
packed_seq_params (object, optional): Parameters for packed sequence processing.
sequence_len_offset (Tensor, optional): Offset along sequence dimension
during inference.
Returns:
Tuple[Tensor, Tensor, Tensor]: A tuple containing:
pre_mlp_layernorm_output (Tensor): Transformed hidden states before the MLP.
residual (Tensor): Residual connection.
context (Tensor): Updated context tensor if cross-attention is used,
otherwise None.
"""
args = get_args()
inference_context = deprecate_inference_params(inference_context, inference_params)
residual = hidden_states
if self.recompute_input_layernorm:
self.input_layernorm_0_checkpoint = tensor_parallel.CheckpointWithoutOutput()
input_layernorm_output = self.input_layernorm_0_checkpoint.checkpoint(
self.input_layernorm_0, hidden_states
)
else:
input_layernorm_output = self.input_layernorm_0(hidden_states)
attention_output_with_bias = self.self_attention_0(
input_layernorm_output,
attention_mask=attention_mask,
inference_context=inference_context,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
attention_bias=attention_bias,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
)
if self.recompute_input_layernorm:
self.input_layernorm_0_checkpoint.discard_output_and_register_recompute(
attention_output_with_bias[0]
)
with self.bias_dropout_add_exec_handler():
hidden_states = self.self_attn_bda_0(self.training, self.config.bias_dropout_fusion)(
attention_output_with_bias, residual, self.hidden_dropout
)
residual = hidden_states
if isinstance(attention_output_with_bias, dict) and "context" in attention_output_with_bias:
context = attention_output_with_bias["context"]
if self.recompute_pre_mlp_layernorm:
self.pre_mlp_layernorm_0_checkpoint = tensor_parallel.CheckpointWithoutOutput()
pre_mlp_layernorm_output = self.pre_mlp_layernorm_0_checkpoint.checkpoint(
self.pre_mlp_layernorm_0, hidden_states
)
else:
pre_mlp_layernorm_output = self.pre_mlp_layernorm_0(hidden_states)
shortcut_mlp_output = self.mlp(pre_mlp_layernorm_output)
if self.recompute_mlp:
mlps_0_output_with_bias = tensor_parallel.checkpoint(
self.mlps_0, False, pre_mlp_layernorm_output
)
else:
mlps_0_output_with_bias = self.mlps_0(pre_mlp_layernorm_output)
if self.recompute_pre_mlp_layernorm:
self.pre_mlp_layernorm_0_checkpoint.discard_output_and_register_recompute(
mlps_0_output_with_bias[0]
)
with self.bias_dropout_add_exec_handler():
hidden_states = self.mlps_bda_0(self.training, self.config.bias_dropout_fusion)(
mlps_0_output_with_bias, residual, self.hidden_dropout
)
residual = hidden_states
if self.recompute_input_layernorm:
self.input_layernorm_1_checkpoint = tensor_parallel.CheckpointWithoutOutput()
input_layernorm_output = self.input_layernorm_1_checkpoint.checkpoint(
self.input_layernorm_1, hidden_states
)
else:
input_layernorm_output = self.input_layernorm_1(hidden_states)
attention_output_with_bias = self.self_attention_1(
input_layernorm_output,
attention_mask=attention_mask,
inference_context=inference_context,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
attention_bias=attention_bias,
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
)
if self.recompute_input_layernorm:
self.input_layernorm_1_checkpoint.discard_output_and_register_recompute(
attention_output_with_bias[0]
)
with self.bias_dropout_add_exec_handler():
hidden_states = self.self_attn_bda_1(self.training, self.config.bias_dropout_fusion)(
attention_output_with_bias, residual, self.hidden_dropout
)
residual = hidden_states
if isinstance(attention_output_with_bias, dict) and "context" in attention_output_with_bias:
context = attention_output_with_bias["context"]
if self.recompute_pre_mlp_layernorm:
self.pre_mlp_layernorm_1_checkpoint = tensor_parallel.CheckpointWithoutOutput()
pre_mlp_layernorm_output = self.pre_mlp_layernorm_1_checkpoint.checkpoint(
self.pre_mlp_layernorm_1, hidden_states
)
else:
pre_mlp_layernorm_output = self.pre_mlp_layernorm_1(hidden_states)
if self.recompute_mlp:
mlps_1_output_with_bias = tensor_parallel.checkpoint(
self.mlps_1, False, pre_mlp_layernorm_output
)
else:
mlps_1_output_with_bias = self.mlps_1(pre_mlp_layernorm_output)
if self.recompute_pre_mlp_layernorm:
self.pre_mlp_layernorm_1_checkpoint.discard_output_and_register_recompute(
mlps_1_output_with_bias[0]
)
with self.bias_dropout_add_exec_handler():
hidden_states = self.mlps_bda_1(self.training, self.config.bias_dropout_fusion)(
mlps_1_output_with_bias, residual, self.hidden_dropout
)
hidden_states = hidden_states + shortcut_mlp_output[0]
output = make_viewless_tensor(
inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
)
return output, context
@dataclass
class LongCatFlashTransformerLayerSubmodules(TransformerLayerSubmodules):
"""
Based on megatron.core.transformer.transformer_layer.TransformerLayerSubmodules,
we add post_attn_norm and post_mlp_layernorm.
"""
input_layernorm_0: Union[ModuleSpec, type] = IdentityOp
self_attention_0: Union[ModuleSpec, type] = IdentityOp
self_attn_bda_0: Union[ModuleSpec, type] = IdentityFuncOp
pre_mlp_layernorm_0: Union[ModuleSpec, type] = IdentityOp
mlp: Union[ModuleSpec, type] = IdentityOp
mlps_0: Union[ModuleSpec, type] = IdentityOp
mlps_bda_0: Union[ModuleSpec, type] = IdentityFuncOp
input_layernorm_1: Union[ModuleSpec, type] = IdentityOp
self_attention_1: Union[ModuleSpec, type] = IdentityOp
self_attn_bda_1: Union[ModuleSpec, type] = IdentityFuncOp
pre_mlp_layernorm_1: Union[ModuleSpec, type] = IdentityOp
mlps_1: Union[ModuleSpec, type] = IdentityOp
mlps_bda_1: Union[ModuleSpec, type] = IdentityFuncOp