import torch
from mindspeed.op_builder.npu_lightning_indexer_builder import NPULightningIndexerOpBuilder
__all__ = ["npu_lightning_indexer"]
op_builder = NPULightningIndexerOpBuilder()
class LightningIndexer(torch.autograd.Function):
@staticmethod
def forward(ctx, query, key, weights, cu_seq_lens_q, cu_seq_lens_k, layout,
sparse_count, sparse_mode, pre_tokens, next_tokens, cmp_ratio):
op = op_builder.load()
sparse_indices, sparse_values = op.npu_lightning_indexer(
query,
key,
weights,
cu_seq_lens_q,
cu_seq_lens_k,
None,
layout,
layout,
sparse_count,
sparse_mode,
pre_tokens,
next_tokens,
cmp_ratio,
True,
)
ctx.save_for_backward(query, key, weights, cu_seq_lens_q, cu_seq_lens_k, sparse_indices)
ctx.layout = layout
ctx.sparse_mode = sparse_mode
ctx.pre_tokens = pre_tokens
ctx.next_tokens = next_tokens
ctx.cmp_ratio = cmp_ratio
return sparse_indices, sparse_values
@staticmethod
def backward(ctx, _, grad_output):
op = op_builder.load()
query, key, weights, cu_seq_lens_q, cu_seq_lens_k, sparse_indices = ctx.saved_tensors
query_grad, k_grad, weights_grad = op.npu_lightning_indexer_grad(
query,
key,
grad_output,
sparse_indices,
weights,
cu_seq_lens_q,
cu_seq_lens_k,
ctx.layout,
ctx.sparse_mode,
ctx.pre_tokens,
ctx.next_tokens,
ctx.cmp_ratio
)
return query_grad, k_grad, weights_grad, None, None, None, None, None, None, None, None
def npu_lightning_indexer(query, key, weights,
layout="BSND", cu_seq_lens_q=None, cu_seq_lens_k=None,
sparse_count=2048, sparse_mode=3,
pre_tokens=2**63-1, next_tokens=2**63-1, cmp_ratio=1):
cu_seq_lens_q = cu_seq_lens_k = None
return LightningIndexer.apply(query, key, weights, cu_seq_lens_q, cu_seq_lens_k, layout,
sparse_count, sparse_mode, pre_tokens, next_tokens, cmp_ratio)