import torch
import torch_npu
from megatron.core.transformer.enums import AttnMaskType
class ScaledUpperTriangMaskedSoftmax:
@staticmethod
def apply(input_, scale):
size = input_.size()
input_ = input_.view((1,) + tuple(size)).contiguous()
dummy_mask = torch.zeros(input_.size()).bool().npu()
output = torch_npu.npu_scaled_masked_softmax(input_, dummy_mask, scale, True)
return output.view(size).contiguous()
class ScaledMaskedSoftmax:
@staticmethod
def apply(input_, mask, scale):
return torch_npu.npu_scaled_masked_softmax(input_, mask, scale, False)
class ScaledSoftmax:
@staticmethod
def apply(input_, scale):
dummy_mask = torch.zeros(input_.size()).bool().npu()
return torch_npu.npu_scaled_masked_softmax(input_, dummy_mask, scale, False)
def is_kernel_available(self, mask, b, np, sq, sk):
return (
self.scaled_masked_softmax_fusion
and self.input_in_float16
and 32 < sk <= 4096
and sq % 16 == 0
and sk % 16 == 0
)
def forward_fused_softmax(self, input_, mask):
b, np, sq, sk = input_.size()
scale = self.scale if self.scale is not None else 1.0
if self.attn_mask_type == AttnMaskType.causal:
assert sq == sk, "causal mask is only for self attention"
return torch_npu.npu_scaled_masked_softmax(input_, mask, scale, True)
else:
if mask is not None:
return torch_npu.npu_scaled_masked_softmax(input_, mask, scale, False)
else:
return ScaledSoftmax.apply(input_, scale)