from functools import wraps
import torch
from einops import rearrange
from megatron.core.transformer import TransformerConfig
from megatron.core.transformer.attention import SelfAttentionSubmodules, SelfAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.tensor_parallel.layers import _initialize_affine_weight_gpu
from megatron.core.utils import divide
from megatron.training import get_args
from megatron.core.transformer.spec_utils import build_module
from mindspeed.core.parallel_state import get_tensor_model_parallel_world_size_for_nd1_dim1
from mindspeed.core.tensor_parallel.tp_2d.group_api_2d import TPXCollectiveComm, TPXOverlapCollectiveComm, \
TPYCollectiveComm, TPYOverlapCollectiveComm
from mindspeed.core.tensor_parallel.tp_2d.parallel_linear_2d import ParallelLinear2D
from mindspeed.core.fusions.fused_rope import apply_rotary_pos_emb_bshd, apply_rotary_pos_emb
def self_attention_init_tp2d_wrapper(fn):
    """Wrap a self-attention initializer so it installs 2D tensor-parallel projections.

    The returned wrapper always runs ``fn`` first. When ``get_args().tp_2d`` is truthy it
    then recomputes ``num_attention_heads_per_partition`` and
    ``num_query_groups_per_partition`` by dividing the config values by
    ``get_tensor_model_parallel_world_size_for_nd1_dim1()``, and assigns
    ``ParallelLinear2D`` modules to ``self.linear_qkv`` and ``self.linear_proj``.

    Args:
        fn: The initializer being wrapped. It is called as
            ``fn(self, config, submodules, layer_number, attn_mask_type)``;
            ``cp_comm_type`` is added as a keyword only when the caller supplied one.

    Returns:
        The wrapping initializer (metadata copied from ``fn`` via ``functools.wraps``).
    """

    @wraps(fn)
    def wrapper(self,
                config: TransformerConfig,
                submodules: SelfAttentionSubmodules,
                layer_number: int,
                attn_mask_type=AttnMaskType.padding,
                cp_comm_type: str = None):
        args = get_args()

        # Forward cp_comm_type only when it was actually given, so wrapped
        # initializers that do not declare the parameter keep working.
        extra_kwargs = {} if cp_comm_type is None else {"cp_comm_type": cp_comm_type}
        fn(self, config, submodules, layer_number, attn_mask_type, **extra_kwargs)

        if not args.tp_2d:
            return

        # Heads and query groups are split by the nd1/dim1 tensor-parallel world size.
        attn_heads_split_num = get_tensor_model_parallel_world_size_for_nd1_dim1()
        self.num_attention_heads_per_partition = divide(
            self.config.num_attention_heads, attn_heads_split_num
        )
        self.num_query_groups_per_partition = divide(
            self.config.num_query_groups, attn_heads_split_num
        )

        # QKV projection: all-gather on the TP-X interfaces, reduce-scatter on TP-Y.
        self.linear_qkv = ParallelLinear2D(
            self.config.hidden_size,
            self.query_projection_size + 2 * self.kv_projection_size,
            config=self.config,
            init_method=self.config.init_method,
            add_bias=self.config.add_bias_linear,
            skip_bias_add=True,
            ag_comm_intf=TPXCollectiveComm,
            ag_sd_rcv_overlap_comm_intf=TPXOverlapCollectiveComm,
            rs_comm_intf=TPYCollectiveComm,
            rs_sd_rcv_overlap_comm_intf=TPYOverlapCollectiveComm,
            enable_overlap_ag_with_matmul=False,
            enable_overlap_matmul_with_rs=False,
            partition_dim=0,
            enable_backward_overlap_ag_with_matmul=False,
            _initialize_affine_weight_gpu=_initialize_affine_weight_gpu
        )

        # Output projection: the X/Y interfaces are swapped relative to linear_qkv,
        # and backward all-gather/matmul overlap follows the runtime argument.
        self.linear_proj = ParallelLinear2D(
            self.query_projection_size,
            self.config.hidden_size,
            config=self.config,
            init_method=self.config.output_layer_init_method,
            add_bias=self.config.add_bias_linear,
            skip_bias_add=True,
            ag_comm_intf=TPYCollectiveComm,
            ag_sd_rcv_overlap_comm_intf=TPYOverlapCollectiveComm,
            rs_comm_intf=TPXCollectiveComm,
            rs_sd_rcv_overlap_comm_intf=TPXOverlapCollectiveComm,
            enable_overlap_ag_with_matmul=False,
            enable_overlap_matmul_with_rs=False,
            partition_dim=1,
            enable_backward_overlap_ag_with_matmul=args.enable_backward_overlap_ag_with_matmul,
            _initialize_affine_weight_gpu=_initialize_affine_weight_gpu
        )

    return wrapper
def self_attention_init(
        self,
        config: TransformerConfig,
        submodules: SelfAttentionSubmodules,
        layer_number: int,
        attn_mask_type=AttnMaskType.padding,
        cp_comm_type: str = None):
    """Replacement initializer for ``SelfAttention``.

    Calls the initializer that follows ``SelfAttention`` in the MRO (not
    ``SelfAttention.__init__`` itself), then sets up the projection and optional
    q/k layernorm attributes according to the runtime arguments:

    * ``args.use_g2_attention`` set: ``linear_q`` and ``linear_kv`` are set to ``None``.
    * otherwise, ``args.no_enable_linear_qkv`` set: ``linear_qkv`` is set to ``None``.
    * otherwise: ``linear_qkv`` is built from ``submodules.linear_qkv``.

    Args:
        config: Transformer configuration forwarded to the parent initializer.
        submodules: Specs for ``linear_qkv``, ``q_layernorm`` and ``k_layernorm``.
        layer_number: Forwarded to the parent initializer.
        attn_mask_type: Forwarded to the parent initializer.
        cp_comm_type: Forwarded to the parent initializer.
    """
    runtime_args = get_args()

    super(SelfAttention, self).__init__(
        config=config,
        submodules=submodules,
        layer_number=layer_number,
        attn_mask_type=attn_mask_type,
        attention_type="self",
        cp_comm_type=cp_comm_type,
    )

    if runtime_args.use_g2_attention:
        self.linear_q = None
        self.linear_kv = None
    elif runtime_args.no_enable_linear_qkv:
        self.linear_qkv = None
    else:
        # Single fused projection producing query, key and value together.
        qkv_out_features = self.query_projection_size + 2 * self.kv_projection_size
        use_bias = self.config.add_bias_linear or self.config.add_qkv_bias
        self.linear_qkv = build_module(
            submodules.linear_qkv,
            self.config.hidden_size,
            qkv_out_features,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=use_bias,
            skip_bias_add=False,
            is_expert=False,
            tp_comm_buffer_name='qkv',
        )

    def _build_head_norm(norm_spec):
        # Build a per-head layernorm from its spec, or return None when no spec is given.
        if norm_spec is None:
            return None
        return build_module(
            norm_spec,
            hidden_size=self.hidden_size_per_attention_head,
            config=self.config,
            eps=self.config.layernorm_epsilon,
        )

    self.q_layernorm = _build_head_norm(submodules.q_layernorm)
    self.k_layernorm = _build_head_norm(submodules.k_layernorm)
def attention_forward(
        self,
        hidden_states,
        attention_mask,
        key_value_states=None,
        inference_context=None,
        rotary_pos_emb=None,
        rotary_pos_cos=None,
        rotary_pos_sin=None,
        attention_bias=None,
        packed_seq_params=None,
        sequence_len_offset=None,
        *,
        inference_params=None,
):
    """Run the attention forward pass: QKV projection, rotary embedding, core attention, output projection.

    Args:
        hidden_states: Input passed to ``self.get_query_key_value_tensors``.
        attention_mask: Mask forwarded to the core attention.
        key_value_states: Optional second input for ``get_query_key_value_tensors``.
        inference_context: Inference state forwarded to
            ``self._adjust_key_value_for_inference``.
        rotary_pos_emb: Rotary embedding, or a ``(query, key)`` tuple of embeddings.
            A non-tuple value is used for both query and key.
        rotary_pos_cos: Forwarded to ``_adjust_key_value_for_inference``.
        rotary_pos_sin: Forwarded to ``_adjust_key_value_for_inference``.
        attention_bias: Forwarded to the core attention.
        packed_seq_params: Packed-sequence parameters; forwarded to the core attention
            and, when present, passed as ``cu_seqlens`` to ``apply_rotary_pos_emb``.
        sequence_len_offset: Forwarded to ``_adjust_key_value_for_inference``.
        inference_params: Keyword-only fallback used as ``inference_context`` when
            that argument is ``None``. NOTE(review): this mirrors how Megatron-Core
            treats its deprecated alias — confirm against the targeted version.

    Returns:
        The ``(output, bias)`` pair produced by ``self.linear_proj``.
    """
    # Previously `inference_params` was accepted but never read, so callers using
    # the keyword silently lost their inference state.
    if inference_context is None and inference_params is not None:
        inference_context = inference_params

    # A single embedding is shared between query and key.
    if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple):
        rotary_pos_emb = (rotary_pos_emb,) * 2

    args = get_args()

    query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)
    # Captured before any flattening below so the batch axis can be restored afterwards.
    bsz = query.shape[1]

    query, key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference(
        inference_context, query, key, value, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset
    )

    if rotary_pos_emb is not None:
        q_pos_emb, k_pos_emb = rotary_pos_emb

        # NOTE(review): the packed_seq_params object itself is handed over as
        # `cu_seqlens`; confirm this is what apply_rotary_pos_emb expects here.
        if packed_seq_params is not None:
            cu_seqlens_q = packed_seq_params
            cu_seqlens_kv = packed_seq_params
        else:
            cu_seqlens_q = cu_seqlens_kv = None

        query = apply_rotary_pos_emb(
            query, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q,
        )
        key = apply_rotary_pos_emb(
            key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv,
        )

    is_ulysses_algo = (getattr(self.config, 'context_parallel_algo', None) == 'ulysses_cp_algo')
    # Packed sequences under non-ulysses context parallelism are run with the
    # sequence and batch axes merged; the same condition drives the un-merge below.
    flatten_packed = (
        packed_seq_params is not None
        and not is_ulysses_algo
        and args.context_parallel_size > 1
    )

    if flatten_packed:
        query, key, value = [rearrange(x, 's b h d -> (b s) h d') for x in (query, key, value)]

    # Use the checkpointed path only while training with checkpointing enabled.
    if self.checkpoint_core_attention and self.training:
        attention_fn = self._checkpointed_attention_forward
    else:
        attention_fn = self.core_attention

    core_attn_out = attention_fn(
        query,
        key,
        value,
        attention_mask,
        attn_mask_type=attn_mask_type,
        attention_bias=attention_bias,
        packed_seq_params=packed_seq_params,
    )

    if flatten_packed:
        core_attn_out = rearrange(core_attn_out, '(b s) h d -> s b (h d)', b=bsz)

    output, bias = self.linear_proj(core_attn_out)
    return output, bias