from dataclasses import dataclass
from functools import wraps
from typing import Union
import torch
import torch.nn.functional as F
from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb
from megatron.core.transformer import TransformerConfig, ModuleSpec, build_module
from megatron.core.transformer.attention import SelfAttentionSubmodules, CrossAttentionSubmodules, \
Attention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core import mpu, parallel_state
from megatron.core.utils import divide
from megatron.training import get_args
from megatron.core.tensor_parallel.layers import _initialize_affine_weight_gpu
from mindspeed.auto_settings.module.black.patch.hccl_operator import AttentionEndOp
from mindspeed.core.context_parallel.ulysses_context_parallel.ulysses_context_parallel import UlyssesContextAttention
from mindspeed.core.parallel_state import get_context_parallel_group_for_hybrid_ulysses, \
get_tensor_model_parallel_world_size_for_nd1_dim1
from mindspeed.core.tensor_parallel.comm_group_api import TPXCollectiveComm, TPXOverlapCollectiveComm, \
TPYCollectiveComm, TPYOverlapCollectiveComm
from mindspeed.core.tensor_parallel_y_union_cp import TensorParallelYUnionCP
from mindspeed.core.tensor_parallel.tp_2d.parallel_linear_2d import ParallelLinear2D
@dataclass
class SelfAttentionSubmodules:
linear_qkv: Union[ModuleSpec, type] = None
core_attention: Union[ModuleSpec, type] = None
linear_proj: Union[ModuleSpec, type] = None
q_layernorm: Union[ModuleSpec, type] = None
k_layernorm: Union[ModuleSpec, type] = None
linear_qb: Union[ModuleSpec, type] = None
linear_kvb: Union[ModuleSpec, type] = None
def attention_init(
self,
config: TransformerConfig,
submodules: Union[SelfAttentionSubmodules, CrossAttentionSubmodules],
layer_number: int,
attn_mask_type: AttnMaskType,
attention_type: str,
cp_comm_type: str = None,
):
super(Attention, self).__init__(config=config)
self.config = config
self.layer_number = layer_number
self.attn_mask_type = attn_mask_type
self.attention_type = attention_type
self.query_projection_size = self.config.kv_channels * self.config.num_attention_heads
self.kv_projection_size = self.config.kv_channels * self.config.num_query_groups
args = get_args()
world_size = args.tp_x if args.tp_2d else parallel_state.get_tensor_model_parallel_world_size()
self.hidden_size_per_attention_head = divide(
self.query_projection_size, self.config.num_attention_heads
)
self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size)
self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size)
self.key_hidden_size = self.hidden_size_per_attention_head
self.val_hidden_size = self.hidden_size_per_attention_head
self.core_attention = build_module(
submodules.core_attention,
config=self.config,
layer_number=self.layer_number,
attn_mask_type=self.attn_mask_type,
attention_type=self.attention_type,
cp_comm_type=cp_comm_type,
)
self.checkpoint_core_attention = self.config.recompute_granularity == 'selective'
self.linear_proj = build_module(
submodules.linear_proj,
self.query_projection_size,
self.config.hidden_size,
config=self.config,
init_method=self.config.output_layer_init_method,
bias=self.config.add_bias_linear,
input_is_parallel=True,
skip_bias_add=True,
is_expert=False,
tp_comm_buffer_name='proj',
)
cp = config.context_parallel_size
if args.tp_2d:
tp_y_cp_sz = cp * args.tp_y
else:
tp_y_cp_sz = cp
if tp_y_cp_sz > 1 and args.context_parallel_algo in ['ulysses_cp_algo', 'hybrid_cp_algo',
'hybrid_adaptive_cp_algo']:
if args.tp_2d:
tp_y_cp = TensorParallelYUnionCP()
ulysses_group = tp_y_cp.group
else:
ulysses_group = mpu.get_context_parallel_group()
if args.context_parallel_algo in ['hybrid_cp_algo', 'hybrid_adaptive_cp_algo']:
ulysses_group = get_context_parallel_group_for_hybrid_ulysses()
self.core_attention = UlyssesContextAttention(self.core_attention, ulysses_group)
def attention_init_wrapper(fn):
@wraps(fn)
def wrapper(self, *args, **kwargs):
fn(self, *args, **kwargs)
if self.config.num_query_groups is None:
self.config.num_query_groups = self.config.num_attention_heads
self.num_attention_heads_per_partition = self.config.num_attention_heads * self.num_query_groups_per_partition // self.config.num_query_groups
return wrapper
def self_attention_init_wrapper(fn):
@wraps(fn)
def wrapper(self,
config: TransformerConfig,
submodules: SelfAttentionSubmodules,
layer_number: int,
attn_mask_type=AttnMaskType.padding,
**attention_optional_kwargs):
args = get_args()
if args.overlap_param_gather:
config.reset_attention_order = True
fn(self, config, submodules, layer_number, attn_mask_type, **attention_optional_kwargs)
if args.multi_head_latent_attention:
self.use_flash_attn = args.use_flash_attn
self.shape_order = args.shape_order
self.qk_rope_head_dim = args.qk_rope_head_dim
self.qk_nope_head_dim = args.qk_nope_head_dim
self.q_lora_rank = args.q_lora_rank
self.kv_lora_rank = args.kv_lora_rank
self.v_head_dim = args.v_head_dim
query_projection_size = self.config.num_attention_heads * self.v_head_dim
self.q_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
if self.q_lora_rank is None:
self.q_rank = self.config.num_attention_heads * self.q_head_dim
self.q_layernorm = None
else:
self.q_rank = self.q_lora_rank
if submodules.q_layernorm is not None:
self.q_layernorm = build_module(
submodules.q_layernorm,
hidden_size=self.q_lora_rank,
config=self.config,
eps=self.config.layernorm_epsilon,
)
else:
self.q_layernorm = None
self.linear_qb = build_module(
submodules.linear_qb,
self.q_lora_rank,
self.config.num_attention_heads * self.q_head_dim,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=self.config.add_bias_linear or self.config.add_qkv_bias,
skip_bias_add=False,
is_expert=False,
tp_comm_buffer_name='qb',
)
self.linear_qkv = build_module(
submodules.linear_qkv,
self.config.hidden_size,
self.q_rank + self.kv_lora_rank + self.qk_rope_head_dim,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=self.config.add_bias_linear or self.config.add_qkv_bias,
skip_bias_add=False,
is_expert=False,
tp_comm_buffer_name='qkv',
)
if submodules.k_layernorm is not None:
self.k_layernorm = build_module(
submodules.k_layernorm,
hidden_size=self.kv_lora_rank,
config=self.config,
eps=self.config.layernorm_epsilon,
)
else:
self.k_layernorm = None
self.linear_kvb = build_module(
submodules.linear_kvb,
self.kv_lora_rank,
self.config.num_attention_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=self.config.add_bias_linear or self.config.add_qkv_bias,
skip_bias_add=False,
is_expert=False,
tp_comm_buffer_name='kvb',
)
self.linear_proj = build_module(
submodules.linear_proj,
query_projection_size,
self.config.hidden_size,
config=self.config,
init_method=self.config.output_layer_init_method,
bias=self.config.add_bias_linear,
input_is_parallel=True,
skip_bias_add=True,
is_expert=False,
tp_comm_buffer_name='proj',
)
if args.tp_2d:
attn_heads_split_num = get_tensor_model_parallel_world_size_for_nd1_dim1()
self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, attn_heads_split_num)
self.num_query_groups_per_partition = divide(self.config.num_query_groups, attn_heads_split_num)
self.linear_qkv = ParallelLinear2D(
self.config.hidden_size,
self.query_projection_size + 2 * self.kv_projection_size,
config=self.config,
init_method=self.config.init_method,
add_bias=self.config.add_bias_linear,
skip_bias_add=True,
ag_comm_intf=TPXCollectiveComm,
ag_sd_rcv_overlap_comm_intf=TPXOverlapCollectiveComm,
rs_comm_intf=TPYCollectiveComm,
rs_sd_rcv_overlap_comm_intf=TPYOverlapCollectiveComm,
enable_overlap_ag_with_matmul=False,
enable_overlap_matmul_with_rs=False,
partition_dim=0,
enable_backward_overlap_ag_with_matmul=False,
_initialize_affine_weight_gpu=_initialize_affine_weight_gpu
)
self.linear_proj = ParallelLinear2D(
self.query_projection_size,
self.config.hidden_size,
config=self.config,
init_method=self.config.output_layer_init_method,
add_bias=self.config.add_bias_linear,
skip_bias_add=True,
ag_comm_intf=TPYCollectiveComm,
ag_sd_rcv_overlap_comm_intf=TPYOverlapCollectiveComm,
rs_comm_intf=TPXCollectiveComm,
rs_sd_rcv_overlap_comm_intf=TPXOverlapCollectiveComm,
enable_overlap_ag_with_matmul=False,
enable_overlap_matmul_with_rs=False,
partition_dim=1,
enable_backward_overlap_ag_with_matmul=args.enable_backward_overlap_ag_with_matmul,
_initialize_affine_weight_gpu=_initialize_affine_weight_gpu
)
return wrapper
def attention_forward_wrapper(fn):
@wraps(fn)
def wrapper(
self,
hidden_states,
attention_mask,
key_value_states=None,
inference_context=None,
rotary_pos_emb=None,
rotary_pos_cos=None,
rotary_pos_sin=None,
attention_bias=None,
packed_seq_params=None,
sequence_len_offset=None,
*,
inference_params=None,
):
args = get_args()
if args.prof_file:
from mindspeed.auto_settings.module.black.patch.hccl_operator import AttentionStartOp
hidden_states = AttentionStartOp.apply(hidden_states)
activation_func_1 = torch.nn.Softplus()
hidden_states = activation_func_1(hidden_states)
if args.multi_head_latent_attention:
if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple):
rotary_pos_emb = (rotary_pos_emb,) * 2
q_len, bsz, _ = hidden_states.shape
mixed_x_layer, _ = self.linear_qkv(hidden_states)
q_a, compressed_kv, k_pe = torch.split(
mixed_x_layer,
[
self.q_rank, self.kv_lora_rank, self.qk_rope_head_dim,
],
dim=-1)
if self.q_layernorm is None:
q = q_a
else:
q, _ = self.linear_qb(self.q_layernorm(q_a))
q = q.view(q_len, bsz, self.config.num_attention_heads, -1)
q_nope, q_pe = torch.split(
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
k_pe = k_pe.view(q_len, bsz, 1, self.qk_rope_head_dim)
kv, _ = self.linear_kvb(self.k_layernorm(compressed_kv))
kv = kv.view(q_len, bsz, self.config.num_attention_heads, self.qk_nope_head_dim +
self.v_head_dim)
k_nope, value = torch.split(
kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
)
if rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb
b, h, s, d = q_pe.shape
q_pe = q_pe.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
b, h, s, d = k_pe.shape
k_pe = k_pe.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
if packed_seq_params is not None:
cu_seqlens_q = packed_seq_params.cu_seqlens_q
cu_seqlens_kv = packed_seq_params.cu_seqlens_kv
else:
cu_seqlens_q = cu_seqlens_kv = None
q_pe = apply_rotary_pos_emb(q_pe, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q)
k_pe = apply_rotary_pos_emb(k_pe, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv)
query = torch.cat([q_nope, q_pe], dim=-1)
k_pe = k_pe.repeat(1, 1, query.shape[2], 1)
key = torch.cat([k_nope, k_pe], dim=-1)
if self.use_flash_attn and self.q_head_dim != self.v_head_dim:
if self.shape_order == "BNSD":
value = F.pad(value, [0, self.q_head_dim - self.v_head_dim])
else:
query = F.pad(query, [0, 256 - self.q_head_dim])
key = F.pad(key, [0, 256 - self.q_head_dim])
value = F.pad(value, [0, 256 - self.v_head_dim])
attn_mask_type = AttnMaskType.causal
if self.checkpoint_core_attention and self.training:
core_attn_out = self._checkpointed_attention_forward(
query,
key,
value,
attention_mask,
attn_mask_type=attn_mask_type,
attention_bias=attention_bias,
packed_seq_params=packed_seq_params,
)
else:
core_attn_out = self.core_attention(
query,
key,
value,
attention_mask,
attn_mask_type=attn_mask_type,
attention_bias=attention_bias,
packed_seq_params=packed_seq_params,
)
if packed_seq_params is not None:
core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1)
if self.use_flash_attn:
core_attn_out = core_attn_out.view(q_len, bsz, self.config.num_attention_heads, -1)
core_attn_out = core_attn_out[:, :, :, : self.v_head_dim]
core_attn_out = core_attn_out.reshape(q_len, bsz, self.config.num_attention_heads * self.v_head_dim)
output, bias = self.linear_proj(core_attn_out)
else:
output, bias = fn(
self,
hidden_states,
attention_mask,
key_value_states,
inference_context,
rotary_pos_emb,
rotary_pos_cos,
rotary_pos_sin,
attention_bias,
packed_seq_params,
sequence_len_offset,
inference_params=inference_params,
)
if args.prof_file:
output = AttentionEndOp.apply(output)
activation_func_2 = torch.nn.Softshrink()
output = activation_func_2(output)
return output, bias
return wrapper
def attention_forward(
self,
hidden_states,
attention_mask,
key_value_states=None,
inference_context=None,
rotary_pos_emb=None,
rotary_pos_cos=None,
rotary_pos_sin=None,
attention_bias=None,
packed_seq_params=None,
sequence_len_offset=None,
*,
inference_params=None,
):
if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple):
rotary_pos_emb = (rotary_pos_emb,) * 2
query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)
query, key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference(
inference_params, query, key, value, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin
)
if rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb
if packed_seq_params is not None:
cu_seqlens_q = packed_seq_params
cu_seqlens_kv = packed_seq_params
else:
cu_seqlens_q = cu_seqlens_kv = None
query = apply_rotary_pos_emb(
query, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q,
)
key = apply_rotary_pos_emb(
key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv,
)
if self.checkpoint_core_attention and self.training:
core_attn_out = self._checkpointed_attention_forward(
query,
key,
value,
attention_mask,
attn_mask_type=attn_mask_type,
attention_bias=attention_bias,
packed_seq_params=packed_seq_params,
)
else:
core_attn_out = self.core_attention(
query,
key,
value,
attention_mask,
attn_mask_type=attn_mask_type,
attention_bias=attention_bias,
packed_seq_params=packed_seq_params,
)
output, bias = self.linear_proj(core_attn_out)
return output, bias