import logging
import torch
from megatron.training import get_args
from megatron.core import tensor_parallel
from megatron.core.transformer.dot_product_attention import DotProductAttention
from mindspeed.model.transformer import get_attention_mask
from mindspeed_llm.core.transformer.custom_dot_product_attention import CustomDotProductAttentionImpl
# Module-level logger.
logger = logging.getLogger(__name__)

# einops is treated as an optional dependency: when it is not installed,
# `rearrange` is bound to None instead of failing at import time.
# NOTE(review): `rearrange` is not referenced by the code in this module as seen
# here — confirm whether anything still imports it from this module before removing.
try:
    from einops import rearrange
except ImportError:
    rearrange = None
class AlibiAttentionImpl(CustomDotProductAttentionImpl):
    """Unfused dot-product attention with an ALiBi positional bias added to the scores.

    Scores are computed explicitly as
    ``beta * alibi_pse + (Q @ K^T) * (1 / norm_factor)``, optionally soft-capped,
    softmaxed, passed through dropout and finally applied to V.

    Attributes read here but not assigned in this class (presumably initialized by
    the base-class constructor — confirm against ``CustomDotProductAttentionImpl``):
    ``alibi``, ``beta``, ``norm_factor``, ``attn_logit_softcapping``,
    ``square_alibi_mask``, ``scale_mask_softmax``, ``attention_dropout``, ``config``,
    ``num_attention_heads_per_partition``, ``num_query_groups_per_partition`` and
    ``hidden_size_per_partition``.
    """

    def __init__(self,
                 config,
                 layer_number,
                 attn_mask_type,
                 attention_type,
                 attention_dropout: float = None,
                 softmax_scale: float = None,
                 cp_comm_type: str = None):
        """Construct the base attention with ``args.use_flash_attn`` forced on.

        The global ``args.use_flash_attn`` flag is set to True only for the duration
        of the base-class constructor. Presumably this makes the base class build the
        state this class relies on (e.g. ``self.alibi``) — confirm against
        ``CustomDotProductAttentionImpl``. The flag is then restored.

        Args:
            config: transformer config forwarded to the base class.
            layer_number: layer index forwarded to the base class.
            attn_mask_type: attention-mask type forwarded to the base class.
            attention_type: attention type forwarded to the base class.
            attention_dropout: optional dropout override forwarded to the base class.
            softmax_scale: optional softmax scale forwarded to the base class.
            cp_comm_type: optional context-parallel comm type forwarded to the base class.
        """
        args = get_args()
        # Restore the caller's value (not a hard-coded False), and do it in `finally`
        # so a failing base-class constructor cannot leave this global flag flipped
        # for every module built afterwards.
        previous_use_flash_attn = getattr(args, "use_flash_attn", False)
        args.use_flash_attn = True
        try:
            super().__init__(config, layer_number, attn_mask_type,
                             attention_type, attention_dropout,
                             softmax_scale, cp_comm_type)
        finally:
            args.use_flash_attn = previous_use_flash_attn

    def forward(
        self,
        query,
        key,
        value,
        attention_mask,
        attn_mask_type=None,
        attention_bias=None,
        packed_seq_params=None,
    ):
        """Compute ALiBi-biased dot-product attention without a fused kernel.

        Args:
            query: 4-D tensor indexed below as (sq, b, np, hn) — sequence, batch,
                heads per partition, head dim (assumed Megatron [s, b, n, d] layout).
            key: 4-D tensor indexed as (sk, b, ng, hn). When there are fewer query
                groups than heads (GQA), it is repeated along dim 2.
            value: 4-D tensor with the same layout and GQA handling as ``key``.
            attention_mask: mask handed to ``scale_mask_softmax`` and to the ALiBi
                bias builder. If None, it is obtained from ``get_attention_mask()``.
            attn_mask_type: accepted for interface compatibility; unused here.
            attention_bias: accepted for interface compatibility; unused here.
            packed_seq_params: accepted for interface compatibility; unused here.
                NOTE(review): packed sequences are not handled by this path.

        Returns:
            Context tensor of shape (sq, b, hidden_size_per_partition).
        """
        if attention_mask is None:
            attention_mask = get_attention_mask()

        # GQA: expand key/value from query groups to one slice per query head.
        heads_per_gqa_group = self.num_attention_heads_per_partition // self.num_query_groups_per_partition
        if heads_per_gqa_group > 1:
            key = key.repeat_interleave(heads_per_gqa_group, dim=2)
            value = value.repeat_interleave(heads_per_gqa_group, dim=2)

        # Score shape: (b, np, sq, sk).
        output_size = (
            query.size(1),
            query.size(2),
            query.size(0),
            key.size(0),
        )
        # Fold batch and heads together: query -> (sq, b*np, hn), key -> (sk, b*np, hn).
        query = query.reshape(output_size[2], output_size[0] * output_size[1], -1)
        key = key.view(output_size[3], output_size[0] * output_size[1], -1)

        # Rebuild the cached ALiBi bias only when it is missing or the score shape changed.
        if self.alibi.alibi_pse is None or self.alibi.output_size != output_size:
            self.alibi.output_size = output_size
            self.alibi.get_alibi_pse(attention_mask, output_size[0], output_size[2], output_size[3])

        # Raw scores: (b*np, sq, hn) @ (b*np, hn, sk) -> (b*np, sq, sk), plus scaled ALiBi bias.
        q_trans = query.transpose(0, 1).contiguous()
        k_trans = key.transpose(0, 1).transpose(1, 2).contiguous()
        matmul_result = self.beta * self.alibi.alibi_pse + torch.bmm(q_trans, k_trans) * (1.0 / self.norm_factor)

        # Optional logit soft-capping: cap * tanh(scores / cap).
        if self.attn_logit_softcapping is not None:
            matmul_result = matmul_result / self.attn_logit_softcapping
            matmul_result = torch.tanh(matmul_result)
            matmul_result = matmul_result * self.attn_logit_softcapping

        attention_scores = matmul_result.view(*output_size)

        if self.square_alibi_mask:
            # attention_mask is not applied on this branch (it was only passed to the
            # ALiBi builder above), so masking is presumably folded into alibi_pse —
            # NOTE(review): confirm. Scores below the dtype's finite minimum (e.g. -inf)
            # are raised to that minimum before a plain softmax.
            attention_scores = torch.max(
                attention_scores, torch.tensor(torch.finfo(attention_scores.dtype).min)
            )
            attention_probs = torch.nn.functional.softmax(attention_scores, -1)
        else:
            attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)

        # Without sequence parallelism, dropout runs under a fork of the
        # tensor-parallel RNG tracker; with it, the current RNG state is used.
        if not self.config.sequence_parallel:
            with tensor_parallel.get_cuda_rng_tracker().fork():
                attention_probs = self.attention_dropout(attention_probs)
        else:
            attention_probs = self.attention_dropout(attention_probs)

        # Context shape: (b, np, sq, hn). `query` is already (sq, b*np, hn) here,
        # so query.size(0) is sq.
        output_size = (
            value.size(1),
            value.size(2),
            query.size(0),
            value.size(3),
        )
        # value -> (sk, b*np, hn); probs -> (b*np, sq, sk); context -> (b*np, sq, hn).
        value = value.view(value.size(0), output_size[0] * output_size[1], -1)
        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
        context = torch.bmm(attention_probs, value.transpose(0, 1))

        # (b*np, sq, hn) -> (b, np, sq, hn) -> (sq, b, np, hn) -> (sq, b, np*hn).
        context = context.view(*output_size)
        context = context.permute(2, 0, 1, 3).contiguous()
        new_context_shape = context.size()[:-2] + (self.hidden_size_per_partition,)
        context = context.view(*new_context_shape)
        return context
class AlibiAttention(AlibiAttentionImpl, DotProductAttention):
    """Megatron-compatible attention module with an ALiBi bias.

    Composed of:
      - ``AlibiAttentionImpl``: the unfused (no FlashAttention), non-context-parallel
        dot-product attention that applies the ALiBi bias.
      - ``DotProductAttention``: Megatron-LM's attention base class, mixed in so the
        module keeps Megatron-LM's attention interface.
    """

    def __init__(self, *args, **kwargs):
        # AlibiAttentionImpl is the next class in this MRO and defines __init__, so
        # super() dispatches to exactly the constructor an explicit call would.
        super().__init__(*args, **kwargs)