import torch
from mindspeed.op_builder import FusionAttentionV2OpBuilder
__all__ = ["npu_fusion_attention"]
class FusionAttentionV2Function(torch.autograd.Function):
    """Autograd bridge to the MindSpeed fusion-attention v2 op pair.

    ``forward`` dispatches to ``npu_fusion_attention_v2`` and ``backward`` to
    ``npu_fusion_attention_grad_v2``. Both ops are looked up on the module
    returned by ``FusionAttentionV2OpBuilder().load()``.
    """

    @staticmethod
    def forward(ctx, query, key, value, head_num, input_layout, pse, padding_mask, atten_mask, scale, keep_prob,
                pre_tokens, next_tokens, inner_precise, prefix, actual_seq_qlen, actual_seq_kvlen, sparse_mode,
                gen_mask_parallel, sync, pse_type, q_start_idx, kv_start_idx):
        """Run the forward op and stash everything ``backward`` reads.

        All arguments after ``ctx`` are handed to ``npu_fusion_attention_v2``
        positionally, in declaration order. Their meaning is defined by that
        op and is not visible in this file.

        Returns:
            The op's result object, unchanged. It must unpack into exactly
            seven values: attention_in, softmax_max, softmax_sum, softmax_in,
            seed, offset, numels.
        """
        ops = FusionAttentionV2OpBuilder().load()
        outputs = ops.npu_fusion_attention_v2(
            query, key, value, head_num, input_layout,
            pse, padding_mask, atten_mask,
            scale, keep_prob, pre_tokens, next_tokens, inner_precise,
            prefix, actual_seq_qlen, actual_seq_kvlen,
            sparse_mode, gen_mask_parallel, sync,
            pse_type, q_start_idx, kv_start_idx,
        )
        attn_out, sm_max, sm_sum, sm_in, out_seed, out_offset, out_numels = outputs

        # Order matters: backward unpacks ctx.saved_tensors in this sequence.
        ctx.save_for_backward(
            query, key, value, pse, padding_mask, atten_mask,
            attn_out, sm_max, sm_sum, sm_in,
        )

        # Everything else backward needs is kept as a plain attribute on ctx.
        stashed = {
            "scale": scale,
            "input_layout": input_layout,
            "head_num": head_num,
            "pre_tokens": pre_tokens,
            "next_tokens": next_tokens,
            "inner_precise": inner_precise,
            "gen_mask_parallel": gen_mask_parallel,
            "sync": sync,
            # seed/offset/numels come from the op's outputs, not from the inputs.
            "seed": out_seed,
            "offset": out_offset,
            "numels": out_numels,
            "prefix": prefix,
            "keep_prob": keep_prob,
            "actual_seq_qlen": actual_seq_qlen,
            "actual_seq_kvlen": actual_seq_kvlen,
            "sparse_mode": sparse_mode,
            "pse_type": pse_type,
            "q_start_idx": q_start_idx,
            "kv_start_idx": kv_start_idx,
        }
        for attr_name, attr_value in stashed.items():
            setattr(ctx, attr_name, attr_value)

        return outputs

    @staticmethod
    def backward(ctx, grad_outputs, dq=None, dk=None, dv=None, seed=0, offset=0, numels=0):
        """Compute input gradients via ``npu_fusion_attention_grad_v2``.

        Only ``grad_outputs`` (the gradient for the first forward output) is
        used. The remaining parameters line up with the other six forward
        outputs and are ignored.

        Returns:
            A 22-tuple, one entry per forward input after ``ctx``. The first
            four results of the grad op fill the query, key, value and pse
            slots; every other slot is ``None``.
        """
        ops = FusionAttentionV2OpBuilder().load()
        (query, key, value, pse, padding_mask, atten_mask,
         attn_out, sm_max, sm_sum, sm_in) = ctx.saved_tensors

        grads = ops.npu_fusion_attention_grad_v2(
            query, key, value, grad_outputs, ctx.head_num, ctx.input_layout,
            pse, padding_mask, atten_mask,
            sm_max, sm_sum, sm_in, attn_out,
            ctx.scale, ctx.keep_prob, ctx.pre_tokens, ctx.next_tokens, ctx.inner_precise,
            ctx.seed, ctx.offset, ctx.numels,
            ctx.prefix, ctx.actual_seq_qlen, ctx.actual_seq_kvlen,
            ctx.sparse_mode, ctx.gen_mask_parallel, ctx.sync,
            ctx.pse_type, ctx.q_start_idx, ctx.kv_start_idx,
        )

        # Slots: query, key, value, head_num, input_layout, pse, then the 16
        # remaining forward inputs, none of which receive a gradient.
        leading = (grads[0], grads[1], grads[2], None, None, grads[3])
        return leading + (None,) * 16
def npu_fusion_attention(query, key, value, head_num,
                         input_layout, *, pse=None,
                         padding_mask=None, atten_mask=None,
                         scale=1., keep_prob=1., pre_tokens=2147483647,
                         next_tokens=2147483647, inner_precise=0, prefix=None,
                         actual_seq_qlen=None, actual_seq_kvlen=None,
                         sparse_mode=0, gen_mask_parallel=True,
                         sync=False, pse_type=1, q_start_idx=None,
                         kv_start_idx=None):
    """Keyword-friendly entry point for ``FusionAttentionV2Function``.

    Every argument is forwarded positionally, in declaration order, to
    ``FusionAttentionV2Function.apply``. Everything after ``input_layout``
    is keyword-only here. Parameter semantics are defined by the underlying
    op and are not restated in this file.

    Returns:
        Whatever ``FusionAttentionV2Function.apply`` returns.
    """
    # apply() takes positional arguments only, so flatten in forward()'s order.
    forward_args = (
        query, key, value, head_num, input_layout,
        pse, padding_mask, atten_mask,
        scale, keep_prob, pre_tokens, next_tokens, inner_precise,
        prefix, actual_seq_qlen, actual_seq_kvlen,
        sparse_mode, gen_mask_parallel, sync,
        pse_type, q_start_idx, kv_start_idx,
    )
    return FusionAttentionV2Function.apply(*forward_args)
def npu_fusion_attention_grad(query, key, value, grad_outputs,
                              head_num, input_layout, *, pse=None,
                              padding_mask=None, atten_mask=None,
                              softmax_max=None, softmax_sum=None, softmax_in=None, attention_in=None,
                              scale=1., keep_prob=1., pre_tokens=2147483647,
                              next_tokens=2147483647, inner_precise=0,
                              seed=1234, offset=0, numels=0, prefix=None,
                              actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0,
                              gen_mask_parallel=True, sync=False, pse_type=1, q_start_idx=None,
                              kv_start_idx=None):
    """Call the ``npu_fusion_attention_grad_v2`` op directly, outside autograd.

    The op is looked up on the module returned by
    ``FusionAttentionV2OpBuilder().load()`` and receives every argument
    positionally, in declaration order. Everything after ``input_layout``
    is keyword-only here. Parameter semantics are defined by the underlying
    op and are not restated in this file.

    Returns:
        Whatever ``npu_fusion_attention_grad_v2`` returns.
    """
    grad_op = FusionAttentionV2OpBuilder().load().npu_fusion_attention_grad_v2
    # The op takes positional arguments, so flatten in its expected order.
    grad_args = (
        query, key, value, grad_outputs, head_num, input_layout,
        pse, padding_mask, atten_mask,
        softmax_max, softmax_sum, softmax_in, attention_in,
        scale, keep_prob, pre_tokens, next_tokens, inner_precise,
        seed, offset, numels,
        prefix, actual_seq_qlen, actual_seq_kvlen,
        sparse_mode, gen_mask_parallel, sync,
        pse_type, q_start_idx, kv_start_idx,
    )
    return grad_op(*grad_args)