import torch
from mindspeed.op_builder.npu_sparse_attn_shared_kv_builder import NPUSparseAttnSharedKVOpBuilder
# Module-level builder shared by every call below; `op_builder.load()` is invoked inside
# forward/backward and the returned object provides the npu_sparse_attn_shared_kv* ops.
op_builder = NPUSparseAttnSharedKVOpBuilder()
class SparseAttnSharedKV(torch.autograd.Function):
    """Autograd bridge to the ``npu_sparse_attn_shared_kv`` forward/grad kernels.

    ``forward`` builds kernel metadata, runs the fused forward kernel and stores
    what ``backward`` needs. ``backward`` dispatches to
    ``npu_sparse_attn_shared_kv_grad``. Only ``query``, ``ori_kv``, ``cmp_kv`` and
    ``sinks`` can receive gradients; all other inputs are non-differentiable.
    """

    @staticmethod
    def forward(ctx, query, ori_kv, cmp_kv, cu_seq_lens_q, cu_seq_lens_ori_kv, cu_seq_lens_cmp_kv, ori_sparse_indices,
                cmp_sparse_indices, sinks, softmax_scale, cmp_ratio, ori_mask_mode, cmp_mask_mode, ori_win_left,
                ori_win_right, num_heads_q, num_heads_kv, head_dim, batch_size, max_seq_len_q, max_seq_len_kv, topk,
                layout_q, layout_kv):
        """Run the fused sparse shared-KV attention forward kernel.

        Args (positional, in ``apply`` order):
            query, ori_kv, cmp_kv: attention inputs; ``cmp_kv`` may be None.
            cu_seq_lens_q, cu_seq_lens_ori_kv, cu_seq_lens_cmp_kv: optional cumulative
                sequence lengths (all None when called from ``npu_sparse_attn_shared_kv``).
            ori_sparse_indices, cmp_sparse_indices: optional sparse index tensors.
            sinks: forwarded to both kernels; receives ``sinks_grad`` in backward.
            softmax_scale, cmp_ratio, ori_mask_mode, cmp_mask_mode, ori_win_left,
                ori_win_right: kernel attributes, forwarded unchanged to the forward
                and grad kernels.
            num_heads_q, num_heads_kv, head_dim, batch_size, max_seq_len_q,
                max_seq_len_kv, topk: size hints used only to build kernel metadata.
            layout_q, layout_kv: layout strings (e.g. ``'BSND'``).

        Returns:
            The attention output from the kernel. The accompanying softmax LSE is
            saved for backward but not returned.
        """
        op = op_builder.load()
        metadata = op.npu_sparse_attn_shared_kv_metadata(
            # An empty NPU tensor stands in for an absent cu_seq_lens_q.
            cu_seq_lens_q if cu_seq_lens_q is not None else torch.tensor([]).npu(),
            # NOTE(review): the next four tensor slots are always empty here, even when
            # cu_seq_lens_ori_kv / cu_seq_lens_cmp_kv are given — confirm against the op signature.
            torch.tensor([]).npu(),
            torch.tensor([]).npu(),
            torch.tensor([]).npu(),
            torch.tensor([]).npu(),
            num_heads_q,
            num_heads_kv,
            head_dim,
            batch_size,
            max_seq_len_q,
            max_seq_len_kv,
            # The same topk fills two consecutive slots (presumably ori/cmp top-k) — confirm.
            topk,
            topk,
            cmp_ratio,
            ori_mask_mode,
            cmp_mask_mode,
            ori_win_left,
            ori_win_right,
            layout_q,
            layout_kv,
            ori_kv is not None,
            cmp_kv is not None,
        )
        result, softmax_lse = op.npu_sparse_attn_shared_kv(
            query,
            ori_kv,
            cmp_kv,
            ori_sparse_indices,
            cmp_sparse_indices,
            None,
            None,
            cu_seq_lens_q,
            cu_seq_lens_ori_kv,
            cu_seq_lens_cmp_kv,
            None,
            None,
            sinks,
            metadata,
            softmax_scale,
            cmp_ratio,
            ori_mask_mode,
            cmp_mask_mode,
            ori_win_left,
            ori_win_right,
            layout_q,
            layout_kv,
            # NOTE(review): meaning of this flag is not visible here; the kernel returns
            # softmax_lse alongside the output, so it may request that — confirm.
            True,
        )
        # save_for_backward accepts None entries, so optional inputs can be saved as-is.
        ctx.save_for_backward(query, ori_kv, cmp_kv, result, softmax_lse, ori_sparse_indices, cmp_sparse_indices,
                              cu_seq_lens_q, cu_seq_lens_ori_kv, cu_seq_lens_cmp_kv, sinks)
        ctx.softmax_scale = softmax_scale
        ctx.cmp_ratio = cmp_ratio
        ctx.ori_mask_mode = ori_mask_mode
        ctx.cmp_mask_mode = cmp_mask_mode
        ctx.ori_win_left = ori_win_left
        ctx.ori_win_right = ori_win_right
        ctx.layout_q = layout_q
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Compute gradients for query, ori_kv, cmp_kv and sinks.

        Returns one entry per ``forward`` input (24 in total). Non-differentiable
        inputs, and any input autograd did not request a gradient for, get None.
        """
        op = op_builder.load()
        (query, ori_kv, cmp_kv, result, softmax_lse, ori_sparse_indices, cmp_sparse_indices,
         cu_seq_lens_q, cu_seq_lens_ori_kv, cu_seq_lens_cmp_kv, sinks) = ctx.saved_tensors
        # Callers such as npu_sparse_attn_shared_kv transpose the forward output, so the
        # incoming gradient is usually a strided (non-contiguous) view. Every forward
        # kernel input is made contiguous by that caller; give the grad kernel the same.
        # No-op when already contiguous.
        grad_output = grad_output.contiguous()
        query_grad, ori_kv_grad, cmp_kv_grad, sinks_grad = op.npu_sparse_attn_shared_kv_grad(
            query,
            ori_kv,
            cmp_kv,
            grad_output,
            result,
            softmax_lse,
            ori_sparse_indices,
            cmp_sparse_indices,
            cu_seq_lens_q,
            cu_seq_lens_ori_kv,
            cu_seq_lens_cmp_kv,
            sinks,
            ctx.softmax_scale,
            ctx.cmp_ratio,
            ctx.ori_mask_mode,
            ctx.cmp_mask_mode,
            ctx.ori_win_left,
            ctx.ori_win_right,
            ctx.layout_q
        )
        # Autograd rejects a non-None gradient for an input that was not a tensor
        # (e.g. cmp_kv=None). needs_input_grad is False for such inputs, so masking
        # by it guarantees None there while keeping every requested gradient.
        needs = ctx.needs_input_grad
        query_grad = query_grad if needs[0] else None
        ori_kv_grad = ori_kv_grad if needs[1] else None
        cmp_kv_grad = cmp_kv_grad if needs[2] else None
        sinks_grad = sinks_grad if needs[8] else None
        # Slots 3-7: cu_seq_lens_* and sparse indices; slots 9-23: scalar/str attributes.
        return (query_grad, ori_kv_grad, cmp_kv_grad, None, None, None, None, None, sinks_grad) + (None,) * 15
def npu_sparse_attn_shared_kv(query, ori_kv, cmp_kv, cmp_sparse_indices, sinks, softmax_scale, cmp_ratio,
                              ori_mask_mode=4, cmp_mask_mode=3, ori_win_left=127, ori_win_right=0):
    """Sparse shared-KV attention on sequence-first inputs.

    Moves the batch axis to the front, adds a single KV-head axis, runs
    :class:`SparseAttnSharedKV` in ``'BSND'`` layout, and swaps axes 0/1 of the
    result back before returning it.

    Args:
        query: 4-D tensor unpacked as (max_seq_len_q, batch_size, num_heads_q, head_dim).
        ori_kv: 3-D tensor whose axis 0 is the KV sequence length; axes 0 and 1 are
            swapped (last axis presumably head_dim — confirm).
        cmp_kv: optional 3-D tensor treated like ``ori_kv``; None is passed through.
        cmp_sparse_indices: consulted only when ``cmp_ratio == 4``; its last axis
            gives ``topk``. It gets a new axis 2 but is NOT transposed, so it
            presumably arrives batch-first already — confirm with callers.
        sinks, softmax_scale, cmp_ratio, ori_mask_mode, cmp_mask_mode, ori_win_left,
            ori_win_right: forwarded unchanged to the kernel.

    Returns:
        The kernel output with axes 0 and 1 swapped, made contiguous.
    """
    # Size hints come from the original, sequence-first tensors.
    max_seq_len_q, batch_size, num_heads_q, head_dim = query.size()
    max_seq_len_kv = ori_kv.size(0)
    if cmp_ratio != 4:
        topk = 0
    else:
        topk = cmp_sparse_indices.size(-1)

    layout = 'BSND'

    # Batch-first copies; K and V share one tensor with a single head (axis 2 of size 1).
    query_bsnd = query.permute(1, 0, 2, 3).contiguous()
    ori_kv_bsnd = ori_kv.permute(1, 0, 2).unsqueeze(2).contiguous()
    if cmp_kv is None:
        cmp_kv_bsnd = None
    else:
        cmp_kv_bsnd = cmp_kv.permute(1, 0, 2).unsqueeze(2).contiguous()
    if cmp_ratio != 4:
        cmp_indices = None
    else:
        cmp_indices = cmp_sparse_indices.unsqueeze(2).contiguous()

    attn_args = (
        query_bsnd, ori_kv_bsnd, cmp_kv_bsnd,
        None, None, None,  # cu_seq_lens_q / cu_seq_lens_ori_kv / cu_seq_lens_cmp_kv
        None,  # ori_sparse_indices
        cmp_indices, sinks,
        softmax_scale, cmp_ratio, ori_mask_mode, cmp_mask_mode, ori_win_left, ori_win_right,
        num_heads_q,
        1,  # num_heads_kv: the shared KV has exactly one head after unsqueeze(2)
        head_dim, batch_size, max_seq_len_q, max_seq_len_kv, topk,
        layout, layout,  # layout_q, layout_kv
    )
    out_bsnd = SparseAttnSharedKV.apply(*attn_args)
    return out_bsnd.transpose(0, 1).contiguous()