Select Attention Operators
High-performance Ascend-910B kernels for predicting sparse attention patterns, enabling efficient LLM decoding. The kernels are launched through the provided Python interface; see the experiments directory for kernel usage examples.
Repo structure
.
|-- experiments - per kernel: test (functional correctness) and benchmark (time and bandwidth)
|   |-- 2_quest_prefill_metadata - constructs the metadata after prefill
|   |-- 3_quest_block_select_paged - Quest sparse-mask predictor using the metadata
|   |-- 4_quest_block_select_paged_w - Quest sparse-mask predictor using the metadata, with extra sink+window features
|-- kernels - Python packages, each containing one or more AscendC kernels and a single torch interface
|   |-- select_attn_ops - predictor kernels (Quest predictors of the sparse pattern during LLM decoding)
`-- scripts
    |-- build_kernels.sh - builds all kernels
    `-- init_cann.sh - initializes the environment and sets the Ascend device version
Requirements
Tested to work with:
- Ascend910B2, Ascend910B4
- CANN versions 8.0.RC3.beta1, 8.2.RC2, 8.3.RC1
- Python 3.11.10
- torch-npu versions 2.4.0, 2.5.1.post1
- see requirements.txt for all other requirements
Creating conda environment
create conda environment
conda create -n sa python=3.11.10 -y
conda activate sa
pip install -r requirements.txt
Running (in conda environment)
Activate the conda and CANN environments, then compile the operators and build their Python API (as Python packages):
source scripts/init_cann.sh Ascend910B4 # change Ascend910B4 to your card model
bash scripts/build_kernels.sh
Run all tests found in the experiments subdirectory:
pytest -v experiments
Usage
The current best practice (proven in vllm-ascend) is to use the quest_prefill_metadata() kernel to create the metadata after prefill and to update it every 128 tokens, and to use quest_block_select_paged_in_out_w() to predict the important KV block indices from the query vector of the token currently being decoded. For detailed usage examples, refer to the experiments directory.
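For intuition, here is a minimal pure-Python reference sketch of block selection, assuming the standard Quest criterion: each KV block is summarized by per-channel max and min key vectors, and its score is an upper bound on q.k over all keys in the block. The function names are hypothetical, and summing the scores of the H / N query heads that share one KV head is an assumption; the kernel's actual GQA aggregation and tie-breaking are not documented here.

```python
from typing import List, Sequence

Vector = Sequence[float]


def quest_block_scores(q: Vector, maxes: Sequence[Vector], mins: Sequence[Vector]) -> List[float]:
    """Quest upper bound on q . k over all keys of each KV block (hypothetical helper)."""
    scores = []
    for mx, mn in zip(maxes, mins):
        # Per channel, max(q_d * max_d, q_d * min_d) >= q_d * k_d for every key k in the block.
        scores.append(sum(max(qd * hi, qd * lo) for qd, hi, lo in zip(q, mx, mn)))
    return scores


def select_top_k_blocks(group_queries: Sequence[Vector], maxes: Sequence[Vector],
                        mins: Sequence[Vector], k: int) -> List[int]:
    """Top-k logical block ids (0 .. num_blocks - 1) for one KV head.

    Assumption: scores of the query heads sharing this KV head are summed.
    """
    total = [0.0] * len(maxes)
    for q in group_queries:
        for i, s in enumerate(quest_block_scores(q, maxes, mins)):
            total[i] += s
    # Highest score first; ties broken by the lower block id (an arbitrary choice).
    order = sorted(range(len(total)), key=lambda i: (-total[i], i))
    return order[:k]


# Toy example with D = 2 and three blocks: block scores are 1.0, 5.0 and 3.5.
maxes = [[1.0, 1.0], [3.0, 0.0], [0.5, -1.0]]
mins = [[0.0, 0.0], [2.0, -2.0], [-0.5, -3.0]]
print(select_top_k_blocks([[1.0, -1.0]], maxes, mins, k=2))  # -> [1, 2]
```

The returned ids are logical block positions within the sequence, matching how the kernel documents its output, rather than physical KV-cache slots.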
The kernels ship with built-in documentation, available through Python's help():
import torch_npu
from select_attn_ops import quest_block_select_paged_in_out_w
help(quest_block_select_paged_in_out_w)
Prints:
Help on built-in function quest_block_select_paged_in_out_w in module select_attn_ops:
quest_block_select_paged_in_out_w(...) method of builtins.PyCapsule instance
quest_block_select_paged_in_out_w(query: torch.Tensor, maxblocks: torch.Tensor, minblocks: torch.Tensor, metadata_block_tables: torch.Tensor, seq_lens: torch.Tensor, tokens_since_metadata_update: int, selected_indices: torch.Tensor) -> None
Alternative interface to the `quest_block_select_paged` kernel which predicts
the sparsity mask during decoding in the form of top-k important kv-block
indices for every KV-head in every request. The returned KV block ids
are not indices into the KV-cache; instead they enumerate the blocks of
the sequence being decoded, from 0 up to the number of blocks in its sequence length.
FEATURE 1) WITH PREALLOCATED OUTPUT TENSOR (selected_indices)
FEATURE 2) "w" stands for "window", i.e. the kernel decides whether to add the ids
of local window blocks to the selected indices, based on the number of tokens
since the last metadata update and on the sequence length
Args:
query (torch.Tensor): Query vector of shape [B, H, D] (fp16 or bf16)
maxblocks (torch.Tensor): Quest metadata with maximum vectors of
every key-cache block of shape
[num_meta_blocks, BLOCK_SIZE, N, D] (fp16 or bf16)
important: zeroes must be in place of metadata of non-existing kv blocks
minblocks (torch.Tensor): Quest metadata with minimum vectors of
every key-cache block of shape
[num_meta_blocks, BLOCK_SIZE, N, D] (fp16 or bf16)
important: zeroes must be in place of metadata of non-existing kv blocks
metadata_block_tables (torch.Tensor): Metadata block tables of
shape [B, MMBPR] (int32)
seq_lens (torch.Tensor): Sequence length of each request in the batch
of shape [B] (int32)
tokens_since_metadata_update (int): number of tokens decoded
since the last metadata update (note: a metadata update covers
only the largest multiple of BLOCK_SIZE tokens that is less than
or equal to the sequence length at the moment of the update);
set to -1 to disable selection of KV blocks for which
no metadata exists.
selected_indices (torch.Tensor): Preallocated output tensor of shape [B, N, k] (int32),
where k is the number of highest-scoring block indices returned for every KV head
Returns:
<fills out the selected_indices tensor>
Limitations (due to the kernel's internal buffer design on 910B):
D = 128
BLOCK_SIZE = 128
H / N <= BLOCK_SIZE
MMBPR <= 6
k % 8 == 0
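A caller may want to reject unsupported shapes before launching the kernel. The following is a hypothetical pre-launch check (not part of select_attn_ops) that encodes the limits listed in the help text; the extra requirement that H be a whole multiple of N (equal-sized GQA groups) is an assumption.

```python
# Limits copied from the kernel's documented Limitations list.
REQUIRED_HEAD_DIM = 128  # D
BLOCK_SIZE = 128
MAX_MMBPR = 6


def check_select_limits(num_q_heads: int, num_kv_heads: int, head_dim: int,
                        mmbpr: int, k: int) -> None:
    """Hypothetical helper: raise ValueError if a call would violate a documented limit."""
    if head_dim != REQUIRED_HEAD_DIM:
        raise ValueError(f"D must be {REQUIRED_HEAD_DIM}, got {head_dim}")
    # Assumption: H is a whole multiple of N (equal-sized GQA groups).
    if num_kv_heads <= 0 or num_q_heads % num_kv_heads != 0:
        raise ValueError(f"H={num_q_heads} must be a multiple of N={num_kv_heads}")
    if num_q_heads // num_kv_heads > BLOCK_SIZE:
        raise ValueError(f"H / N must be <= {BLOCK_SIZE}")
    if mmbpr > MAX_MMBPR:
        raise ValueError(f"MMBPR must be <= {MAX_MMBPR}, got {mmbpr}")
    if k % 8 != 0:
        raise ValueError(f"k must be a multiple of 8, got {k}")


check_select_limits(num_q_heads=32, num_kv_heads=8, head_dim=128, mmbpr=4, k=16)  # passes
```

Given the documented shapes, the arguments can be read off the tensors: H and D from query [B, H, D], N from maxblocks [num_meta_blocks, BLOCK_SIZE, N, D], MMBPR from metadata_block_tables [B, MMBPR], and k from selected_indices [B, N, k].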
Development Workflow for a new kernel "OP"
- Add the new kernel implementation to the kernels/ directory in one of two ways:
  - under an existing Python package, e.g. kernels/select_attn_ops/: add your kernel code as a new OP.cpp, add a compilation line to compile.sh, and add a torch interface inside torch_interface.cpp;
  - as a new Python package kernels/OP/, containing an OP.cpp kernel implementation, torch_interface.cpp, compile.sh and build.sh.
- Create a dedicated experiment directory experiments/5_OP and implement the following programs in it:
  - ref_OP.py - start by implementing a reference Python model for correctness checking.
  - gen_data_OP.py - a function that produces a set of input tensors for your kernel.
  - test_OP.py - a smoke test (a single run first) that validates correctness on one focused input, then extended to automated pytest runs across a wide range of input shapes and data types.
  - benchmark_OP.py - measures performance (time, bandwidth).
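For top-k selection kernels, an exact comparison against the reference in test_OP.py can fail spuriously: the order of the returned indices, tie-breaking, and fp16/bf16 rounding may all differ from the Python model. One option, sketched here with hypothetical names and an assumed tolerance, is to accept any output that is a valid top-k set under the reference scores. In a real test, `selected` would be one (request, KV head) row of selected_indices and `scores` would come from ref_OP.py.

```python
import random
from typing import List, Sequence


def is_valid_topk(selected: Sequence[int], scores: Sequence[float], k: int,
                  atol: float = 1e-3) -> bool:
    """Order-insensitive check that `selected` is a top-k set, tolerating near-ties."""
    if k > len(scores) or len(selected) != k or len(set(selected)) != k:
        return False
    if any(not 0 <= i < len(scores) for i in selected):
        return False
    kth = sorted(scores, reverse=True)[k - 1]
    # Every chosen block must score at least the k-th best, up to the tolerance (an assumed value).
    return all(scores[i] >= kth - atol for i in selected)


def gen_scores(num_blocks: int, seed: int = 0) -> List[float]:
    """Hypothetical stand-in for gen_data_OP.py: deterministic random block scores."""
    rng = random.Random(seed)
    return [rng.uniform(-1.0, 1.0) for _ in range(num_blocks)]


def reference_topk(scores: Sequence[float], k: int) -> List[int]:
    """Hypothetical stand-in for ref_OP.py: indices of the k largest scores."""
    return sorted(range(len(scores)), key=lambda i: -scores[i])[:k]


def test_reference_is_valid_topk() -> None:
    for num_blocks, k in [(8, 8), (64, 16), (768, 32)]:
        s = gen_scores(num_blocks)
        assert is_valid_topk(reference_topk(s, k), s, k)
```

With pytest, the shape list would typically move into `@pytest.mark.parametrize` so each case is reported separately.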