import math
from typing import Optional
import torch
from torch import Tensor
import torch_npu
from megatron.core.transformer.dot_product_attention import DotProductAttention as MegatronDotProductAttention
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.packed_seq_params import PackedSeqParams
try:
from einops import rearrange
except ImportError:
rearrange = None
# Side length of the compressed causal mask handed to ``npu_fusion_attention``
# on the remove-padding path (``sparse_mode=3``). Per the torch_npu docs the
# kernel expects this fixed 2048 x 2048 mask regardless of sequence length.
_COMPRESSED_MASK_SIZE = 2048
# Per-device cache of the compressed causal mask; it never changes, so there is
# no reason to allocate it on every forward call.
_COMPRESSED_CAUSAL_MASK_CACHE = {}


def _get_compressed_causal_mask(device):
    """Return the cached strictly-upper-triangular bool mask for ``device``."""
    mask = _COMPRESSED_CAUSAL_MASK_CACHE.get(device)
    if mask is None:
        mask = torch.triu(
            torch.ones(
                [_COMPRESSED_MASK_SIZE, _COMPRESSED_MASK_SIZE],
                dtype=torch.bool,
                device=device,
            ),
            diagonal=1,
        )
        _COMPRESSED_CAUSAL_MASK_CACHE[device] = mask
    return mask


def dot_product_attention_forward_impl(
    self,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attention_mask: Tensor,
    attn_mask_type: AttnMaskType = None,
    attention_bias: Tensor = None,
    packed_seq_params: Optional[PackedSeqParams] = None,
):
    """Compute attention with ``torch_npu.npu_fusion_attention``.

    Three paths are supported:

    * ``config.use_remove_padding``: ``query``/``key``/``value`` are
      ``[s, b, n, d]``; tokens are flattened batch-major into the ``TND``
      layout using the cumulative lengths from
      ``mindspeed.utils.get_actual_seq_len()``, with a causal
      ``sparse_mode=3`` kernel. Returns ``[s, b, n * d]``.
    * ``packed_seq_params`` given: inputs are already packed ``[t, n, d]``
      and are passed in ``TND`` layout with ``cu_seqlens_q``/``cu_seqlens_kv``
      as the actual sequence lengths. The kernel output is returned as-is.
    * otherwise: inputs ``[s, b, n, d]`` are merged to ``[s, b, n * d]`` and
      passed in ``SBH`` layout.

    Args:
        query, key, value: attention inputs (layouts as above).
        attention_mask: mask forwarded to the kernel on the non-remove-padding
            paths.
        attn_mask_type: ``AttnMaskType.no_mask`` forces ``sparse_mode=0``;
            any other value uses ``self.config.sparse_mode``.
        attention_bias: not used by this implementation; accepted so the
            signature matches the patched ``forward``.
        packed_seq_params: packed-sequence metadata selecting the ``TND`` path.

    Returns:
        The attention output tensor (first element of the kernel's result).
    """
    # Head count is the second-to-last dim in both [s, b, n, d] and [t, n, d].
    n_head = query.shape[-2]

    if getattr(self.config, 'use_remove_padding', False):
        # Imported lazily: mindspeed is only required on this path.
        from mindspeed.utils import get_actual_seq_len

        seq_length, bsz, head_dim = query.shape[0], query.shape[1], query.shape[3]
        actual_seq_len = get_actual_seq_len()
        # [s, b, n, d] -> [(b s), n, d]: batch-major, so each sample's tokens are
        # contiguous (presumably matching the cumulative lengths in
        # actual_seq_len — confirm against get_actual_seq_len's producer).
        query, key, value = [
            x.transpose(0, 1).reshape(-1, x.shape[2], x.shape[3])
            for x in (query, key, value)
        ]
        # NOTE(review): this path always uses 1/sqrt(head_dim) and no dropout,
        # ignoring self.softmax_scale and self.attention_dropout — confirm intended.
        output = torch_npu.npu_fusion_attention(
            query, key, value, n_head,
            pse=None,
            padding_mask=None,
            atten_mask=_get_compressed_causal_mask(query.device),
            scale=1.0 / math.sqrt(head_dim),
            keep_prob=1,
            input_layout="TND",
            actual_seq_qlen=actual_seq_len,
            actual_seq_kvlen=actual_seq_len,
            sparse_mode=3,
        )[0]
        # Undo the batch-major flattening: [(b s), n, d] -> [b, s, n*d] -> [s, b, n*d].
        # (A direct reshape to [s, b, -1] would scramble tokens whenever bsz > 1.)
        return output.reshape(bsz, seq_length, -1).transpose(0, 1).contiguous()

    sparse_mode = 0 if attn_mask_type == AttnMaskType.no_mask else self.config.sparse_mode

    # NOTE(review): when scale_mask_softmax.scale (the query-key layer-scaling
    # coefficient) is set, self.softmax_scale is used unchanged; verify this is
    # the intended effective scale for the fused kernel.
    if self.scale_mask_softmax.scale is None:
        scale = 1.0 / math.sqrt(self.hidden_size_per_attention_head)
    else:
        scale = self.softmax_scale

    # Dropout must only be applied while training; in eval mode keep every element.
    keep_prob = 1 - self.attention_dropout.p if self.training else 1.0

    if packed_seq_params is not None:
        cu_seqlens_q = packed_seq_params.cu_seqlens_q
        cu_seqlens_kv = packed_seq_params.cu_seqlens_kv
        if isinstance(cu_seqlens_q, list):
            actual_seq_qlen, actual_seq_kvlen = cu_seqlens_q, cu_seqlens_kv
        else:
            # The kernel takes Python lists of cumulative lengths.
            actual_seq_qlen, actual_seq_kvlen = cu_seqlens_q.tolist(), cu_seqlens_kv.tolist()
        shape_order = 'TND'
    else:
        actual_seq_qlen = None
        actual_seq_kvlen = None
        # [s, b, n, d] -> [s, b, n*d] for the SBH layout.
        query, key, value = [
            x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)
        ]
        shape_order = 'SBH'

    return torch_npu.npu_fusion_attention(
        query, key, value,
        n_head,
        input_layout=shape_order,
        pse=None,
        padding_mask=None,
        atten_mask=attention_mask,
        scale=scale,
        pre_tockens=self.config.pre_tockens,
        next_tockens=self.config.next_tockens,
        keep_prob=keep_prob,
        inner_precise=0,
        sparse_mode=sparse_mode,
        actual_seq_qlen=actual_seq_qlen,
        actual_seq_kvlen=actual_seq_kvlen,
    )[0]