How to use the benchmarking/correctness program benchmark.py
Installation
To install the sparse kernel, go to the root folder of the ops-transformer project and run:
bash build.sh --make_clean --experimental -j96 --pkg --soc=ascend910b --ops=blitz_sparse_attention
./build/cann-ops-transformer-custom_linux-"$(uname -i)".run
(cd experimental/attention/blitz_sparse_attention/torch_interface && bash build.sh custom)
Fast recompilation of only the bf16 kernel variant
bash build.sh --make_clean --experimental -j96 --pkg --soc=ascend910b --ops=blitz_sparse_attention --op-variant=2
./build/cann-ops-transformer-custom_linux-"$(uname -i)".run
One-liner smoke test and benchmark
(cd experimental/attention/blitz_sparse_attention/benchmark && python test_attn.py && python benchmark.py)
Benchmark + plot pipeline (writes benchmark.png summarising all
(BLOCK_SIZE_Q, BLOCK_SIZE_KV, sparsity) combinations):
(cd experimental/attention/blitz_sparse_attention/benchmark && python benchmark.py | tee bench.log /dev/tty | python plot.py)
Parameters
Set the ALL_CAPS constants at the top of the file to control what gets run. Wherever a constant is a list, the sweep covers every combination of all listed values.
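The sketch below shows how list-valued constants expand into runs. It assumes the sweep is a plain Cartesian product; the actual loop order inside benchmark.py may differ. With the Example 1 settings it gives 1 × 2 × 1 × 2 × 1 × 11 = 44 runs, one per row of the Example 1 output.

```python
from itertools import product

# Example 1 settings, copied from the example below.
B_VALS = [1]
H_VALS = [16, 14]
S_VALS = [4096, 10_000]
D_VALS = [128]
SPARSITY_VALS = [0.0, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
BLOCK_SHAPES = [(128, 512)]

# Each combination of all listed values is one benchmark run. The tuple order
# mirrors the output columns: block_shape, H, B, S, D, sparsity.
configs = list(product(BLOCK_SHAPES, H_VALS, B_VALS, S_VALS, D_VALS, SPARSITY_VALS))
print(len(configs))  # 44
```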
- B_VALS: batch sizes to test. For block sparsity, only B_VALS = [1] works for now.
- H_VALS: numbers of heads.
- S_VALS: sequence lengths.
- D_VALS: head dimensions.
- N_REPEATS: number of timed runs used to estimate latency.
- N_WARMUP: number of warm-up runs before the N_REPEATS timed calls.
- SPARSITY_VALS: fraction of blocks to deactivate, so 1 − sparsity of the blocks stay active. The active blocks are chosen randomly, but every Q-block row keeps the same number of blocks. The block-sparse pattern is the only kernel mode the benchmark exercises.
- BLOCK_SHAPES: list of (BLOCK_SIZE_Q, BLOCK_SIZE_KV) pairs to sweep. The kernel accepts every combination of BLOCK_SIZE_Q ∈ {128, 256, 512, 1024} and BLOCK_SIZE_KV ∈ {128, 256, 512, 1024} at runtime via the block_shape op attr, so switching granularity needs no rebuild. The default sweep is the full 4 × 4 = 16-combination grid. The first column of the benchmark output labels every row with its active block shape. The BLOCK_SIZE_Q and BLOCK_SIZE_KV aliases (set to BLOCK_SHAPES[0] for backward compatibility) can still be read from the top of benchmark.py.
- BLOCK_MASK_SEED: seed of the random block sampling, for reproducibility. Set it to any other value to draw a different pattern.
- FRAMES_BY_BLOCK_SHAPE: 2-D dict keyed by (BLOCK_SIZE_Q, BLOCK_SIZE_KV). Each entry is a FrameWidths(left_cols, right_cols, top_rows, bottom_rows) namedtuple giving extra blocks forced active on top of the sparsity budget. This models a typical sink/streaming pattern. The per-pair entries are scaled to keep the forced-token footprint comparable across granularities; see Example 2 for the derivation.
- PRINT_OUTPUTS: print the output tensors for manual inspection.
- PRINT_MASK: print the mask to check visually that it is as expected.
- PRINT_BLOCK_EQUALITY: if the correctness test fails, show which output blocks match the reference and which do not. PRINT_HEIGHT and PRINT_WIDTH set the block granularity of this view.
- RUN_REFERENCE: besides measuring time, also test correctness. This is only possible for small enough masks; otherwise the reference run goes out of memory (OOM). Disable it for very long contexts and enable it for short enough sequences.
- TORCH_REFERENCE: if True, correctness is tested against the torch attention implementation. If False, the block-sparse reference is dense torch_npu attention with the corresponding attention mask (as in sparse_block mode). In that case use only H=1, because torch_npu does not support different masks for different heads. The examples below set it to "npu_fusion_attention".
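The next sketch illustrates the SPARSITY_VALS and BLOCK_MASK_SEED semantics described above: every Q-block row keeps the same number of randomly chosen KV blocks. Several details are assumptions, not the code in benchmark.py:

- the helper name `make_block_mask`;
- the rounding rule for the kept count;
- leaving out the forced frame.

```python
import math
import random


def make_block_mask(s_q, s_kv, block_q, block_kv, sparsity, seed):
    """Hypothetical helper: boolean block mask of shape [n_q_blocks][n_kv_blocks].

    Every Q-block row keeps the same number of KV blocks, chosen at random.
    """
    n_q = math.ceil(s_q / block_q)
    n_kv = math.ceil(s_kv / block_kv)
    # Assumed rounding: keep round((1 - sparsity) * n_kv) blocks, at least one.
    n_keep = max(1, round((1.0 - sparsity) * n_kv))
    rng = random.Random(seed)  # seeded generator, so the pattern is reproducible
    mask = []
    for _ in range(n_q):
        kept = set(rng.sample(range(n_kv), n_keep))
        mask.append([j in kept for j in range(n_kv)])
    return mask


mask = make_block_mask(4096, 4096, 128, 512, sparsity=0.5, seed=0)
print(len(mask), len(mask[0]), sum(mask[0]))  # 32 8 4
```

Reusing the same seed reproduces the same pattern. Changing it draws a new pattern with the same per-row block count.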
Examples
1 - Test correctness and speed with multiple values (and shorter sequence lengths)
Input
Set the following values in the ===== Begin of Parameter Sweep Definitions ==== section of benchmark.py:
DTYPE = torch.bfloat16
INPUT_LAYOUT = "BNSD" # [batch_size, num_hea
B_VALS = [1]
H_VALS = [16, 14]
S_VALS = [4096, 10_000] # S_q = S_kv
D_VALS = [128] # head dimension
N_REPEATS = 20
N_WARMUP = 2
SPARSITY_VALS = [0.0, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
BLOCK_SHAPES = [(128, 512)]
SABI_SORTED = True
FRAMES_BY_BLOCK_SHAPE = {(128, 512):None} # no frame forced to be selected
PRINT_OUTPUTS = False
PRINT_MASK = False
PRINT_BLOCK_EQUALITY = False
PRINT_HEIGHT = 128
PRINT_WIDTH = 8
RUN_REFERENCE = True
TORCH_REFERENCE = "npu_fusion_attention"
Output:
========================================================================================================================
DTYPE=torch.bfloat16 INPUT_LAYOUT='BNSD' SABI_SORTED=True TORCH_REFERENCE='npu_fusion_attention'
========================================================================================================================
block_shape H B s_q s_kv D frame(L,R,T,B) sparsity Outputs_equal Ref_Latency_[usec] Our_Latency_[usec]
------------------------------------------------------------------------------------------------------------------------
128x512 16 1 4096 4096 128 - 0.00 yes 859.30 1619.89
128x512 16 1 4096 4096 128 - 0.05 yes 1000.11 1617.22
128x512 16 1 4096 4096 128 - 0.10 yes 997.50 1593.49
128x512 16 1 4096 4096 128 - 0.20 yes 997.37 1457.39
128x512 16 1 4096 4096 128 - 0.30 yes 997.62 1359.71
128x512 16 1 4096 4096 128 - 0.40 yes 997.41 1221.68
128x512 16 1 4096 4096 128 - 0.50 yes 997.88 1114.68
128x512 16 1 4096 4096 128 - 0.60 yes 996.95 1281.74
128x512 16 1 4096 4096 128 - 0.70 yes 997.40 986.38
128x512 16 1 4096 4096 128 - 0.80 yes 997.62 875.81
128x512 16 1 4096 4096 128 - 0.90 yes 997.56 752.91
128x512 16 1 10000 10000 128 - 0.00 yes 5696.26 6649.24
128x512 16 1 10000 10000 128 - 0.05 yes 7533.60 6319.64
128x512 16 1 10000 10000 128 - 0.10 yes 7498.81 6310.39
128x512 16 1 10000 10000 128 - 0.20 yes 7574.27 5369.41
128x512 16 1 10000 10000 128 - 0.30 yes 7520.82 4952.99
128x512 16 1 10000 10000 128 - 0.40 yes 7572.14 4159.39
128x512 16 1 10000 10000 128 - 0.50 yes 7538.75 3560.27
128x512 16 1 10000 10000 128 - 0.60 yes 7561.45 3079.28
128x512 16 1 10000 10000 128 - 0.70 yes 7482.04 2406.52
128x512 16 1 10000 10000 128 - 0.80 yes 7565.93 1857.96
128x512 16 1 10000 10000 128 - 0.90 yes 7533.11 1266.50
128x512 14 1 4096 4096 128 - 0.00 yes 784.19 1534.82
128x512 14 1 4096 4096 128 - 0.05 yes 906.90 1693.06
128x512 14 1 4096 4096 128 - 0.10 yes 908.12 1515.25
128x512 14 1 4096 4096 128 - 0.20 yes 905.51 1398.79
128x512 14 1 4096 4096 128 - 0.30 yes 908.47 1287.53
128x512 14 1 4096 4096 128 - 0.40 yes 906.55 1173.67
128x512 14 1 4096 4096 128 - 0.50 yes 907.90 1064.68
128x512 14 1 4096 4096 128 - 0.60 yes 906.96 1065.85
128x512 14 1 4096 4096 128 - 0.70 yes 908.62 946.82
128x512 14 1 4096 4096 128 - 0.80 yes 907.40 869.84
128x512 14 1 4096 4096 128 - 0.90 yes 908.16 723.44
128x512 14 1 10000 10000 128 - 0.00 yes 4857.91 5930.11
128x512 14 1 10000 10000 128 - 0.05 yes 6695.51 5614.66
128x512 14 1 10000 10000 128 - 0.10 yes 6680.30 5312.90
128x512 14 1 10000 10000 128 - 0.20 yes 6621.15 4772.31
128x512 14 1 10000 10000 128 - 0.30 yes 6690.78 5062.18
128x512 14 1 10000 10000 128 - 0.40 yes 6622.58 3757.85
128x512 14 1 10000 10000 128 - 0.50 yes 6661.94 3208.64
128x512 14 1 10000 10000 128 - 0.60 yes 6630.15 2696.30
128x512 14 1 10000 10000 128 - 0.70 yes 6641.87 2193.39
128x512 14 1 10000 10000 128 - 0.80 yes 6672.09 1715.01
128x512 14 1 10000 10000 128 - 0.90 yes 6685.34 1186.83
========================================================================================================================
The reference runs out of memory (OOM) at sequence length 10 000 with 24 heads (not shown).
The leftmost columns echo the input parameters, and Outputs_equal reports whether our output matches the reference. The last two columns give the reference latency and our latency in microseconds.
For short sequences the speedups are small or negative, because fixed overheads outweigh the actual computation. At S = 4096, our kernel only beats the reference from sparsity 0.7 (H=16) or 0.8 (H=14) upward.
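The hypothetical parser below reads output like this programmatically. It splits each data row on whitespace, using the column order of the header above, and skips rows whose timings are N/A. It then reports the lowest sparsity at which our latency beats the reference. The function names are illustrative and not part of benchmark.py.

```python
SAMPLE = """\
128x512 14 1 4096 4096 128 - 0.60 yes 906.96 1065.85
128x512 14 1 4096 4096 128 - 0.70 yes 908.62 946.82
128x512 14 1 4096 4096 128 - 0.80 yes 907.40 869.84
128x512 14 1 4096 4096 128 - 0.90 yes 908.16 723.44
"""


def parse_rows(text):
    """Yield (H, s_q, sparsity, ref_us, our_us) from benchmark output rows."""
    for line in text.splitlines():
        cols = line.split()
        # Data rows have 11 columns and start with a block shape such as "128x512".
        if len(cols) != 11 or not cols[0].replace("x", "", 1).isdigit():
            continue
        if "N/A" in (cols[9], cols[10]):
            continue  # no reference timing for this row
        yield int(cols[1]), int(cols[3]), float(cols[7]), float(cols[9]), float(cols[10])


def crossover_sparsity(text, h, s):
    """Lowest sparsity at which Our_Latency < Ref_Latency for one (H, S) pair."""
    wins = [sp for hh, ss, sp, ref, our in parse_rows(text)
            if hh == h and ss == s and our < ref]
    return min(wins) if wins else None


print(crossover_sparsity(SAMPLE, h=14, s=4096))  # 0.8
```

To scan a whole run, pass it the contents of bench.log written by the tee pipeline.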
2 - block_shape granularity sweep (S=118806, H=3, D=128, BF16)
The runtime block_shape attr lets the same kernel run at every combination of
(BLOCK_SIZE_Q, BLOCK_SIZE_KV) with each dimension in {128, 256, 512, 1024}:
16 granularities sharing one kernel binary. A finer sabi (smaller block
sizes) buys more sparsity-pattern resolution per Q-row. The cost is more
sabi-tensor HBM traffic and more sub-block plumbing per cube tile. A coarser sabi
buys cheaper per-row metadata at the cost of resolution. The sweep below
measures the resulting latency/throughput trade-off at S = 118 806.
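As rough arithmetic for this trade-off, the sketch below counts block rows, block columns and total (Q-block, KV-block) cells at S = 118 806. Treating the cell count as a proxy for per-row metadata is an assumption. The actual sabi encoding is not described here, and `block_grid` is a hypothetical helper.

```python
import math


def block_grid(s, bsq, bskv):
    """Number of Q-block rows and KV-block columns covering an S x S score matrix."""
    return math.ceil(s / bsq), math.ceil(s / bskv)


S = 118_806
for shape in [(128, 128), (128, 1024), (1024, 128), (1024, 1024)]:
    n_q, n_kv = block_grid(S, *shape)
    # Finer blocks give a much larger grid, i.e. more entries to describe the pattern.
    print(f"{shape[0]}x{shape[1]}: {n_q} x {n_kv} = {n_q * n_kv} block cells")
```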
Why left_cols and top_rows scale as 29 / 15 / 8 / 4 while right_cols and bottom_rows stay 1
The benchmark forces a frame in the attention matrix that models a typical generative-model sink/streaming pattern: the first ~3600 attention rows and columns plus the last ~6 rows and columns are always selected. Six tokens are fewer than a single block in every dimension, so this tail always rounds up to one full block row / block column at the bottom / right.
FrameWidths.left_cols / right_cols are in BLOCK_SIZE_KV-token units, and
top_rows / bottom_rows are in BLOCK_SIZE_Q-token units, so the per-pair
values fall out of rounding the constant 3600-token target up to whole blocks
in each dimension independently:
| block size (tokens) | ceil(3600 / block size) |
|---|---|
| 128 | 29 |
| 256 | 15 |
| 512 | 8 |
| 1024 | 4 |
That gives a 2-D FRAMES_BY_BLOCK_SHAPE[(BLOCK_SIZE_Q, BLOCK_SIZE_KV)] table
where left_cols ∈ {29, 15, 8, 4} (indexed by BLOCK_SIZE_KV), top_rows ∈ {29, 15, 8, 4} (indexed by BLOCK_SIZE_Q), and right_cols = bottom_rows = 1
everywhere. For example, (BLOCK_SIZE_Q=1024, BLOCK_SIZE_KV=1024) ⇒
FrameWidths(left_cols=4, right_cols=1, top_rows=4, bottom_rows=1).
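The rounding rule above can be written as a small function that rebuilds the whole 4 × 4 table. This is a sketch under stated assumptions:

- FrameWidths is re-declared here as a stand-in with the documented field order.
- The names HEAD_TOKENS, TAIL_TOKENS and frame_for are hypothetical.
- benchmark.py itself spells the table out literally.

```python
import math
from collections import namedtuple

# Stand-in for the FrameWidths namedtuple in benchmark.py (documented field order).
FrameWidths = namedtuple("FrameWidths", ["left_cols", "right_cols", "top_rows", "bottom_rows"])

SIZES = (128, 256, 512, 1024)
HEAD_TOKENS = 3600  # forced rows/columns at the start (sink pattern)
TAIL_TOKENS = 6     # forced rows/columns at the end; always < one block, so rounds up to 1


def frame_for(bsq, bskv, head=HEAD_TOKENS, tail=TAIL_TOKENS):
    """Round the token targets up to whole blocks, independently per dimension."""
    return FrameWidths(
        left_cols=math.ceil(head / bskv),   # columns are in BLOCK_SIZE_KV units
        right_cols=math.ceil(tail / bskv),
        top_rows=math.ceil(head / bsq),     # rows are in BLOCK_SIZE_Q units
        bottom_rows=math.ceil(tail / bsq),
    )


FRAMES_BY_BLOCK_SHAPE = {(q, kv): frame_for(q, kv) for q in SIZES for kv in SIZES}
print(FRAMES_BY_BLOCK_SHAPE[(1024, 1024)])
# FrameWidths(left_cols=4, right_cols=1, top_rows=4, bottom_rows=1)
```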
The test files apply the same scaling rule with a smaller target (~256 tokens)
for their SPARSE_FRAME.
Input
B_VALS = [1]; H_VALS = [3]; S_VALS = [118_806]; D_VALS = [128]
N_REPEATS = 10; N_WARMUP = 2
SPARSITY_VALS = [0.0, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
BLOCK_SHAPES = [(bsq, bskv)
for bsq in (128, 256, 512, 1024)
for bskv in (128, 256, 512, 1024)] # 4 × 4 = 16 pairs
# Per-pair frame: left_cols / top_rows = ceil(3600 / block size) in each dim, so
# the forced-token footprint stays comparable across block shapes.
FRAMES_BY_BLOCK_SHAPE = {
(bsq, bskv): FrameWidths(
left_cols = {128: 29, 256: 15, 512: 8, 1024: 4}[bskv],
right_cols = 1,
top_rows = {128: 29, 256: 15, 512: 8, 1024: 4}[bsq],
bottom_rows = 1,
)
for bsq in (128, 256, 512, 1024)
for bskv in (128, 256, 512, 1024)
}
TORCH_REFERENCE = "npu_fusion_attention"
Output
Run python benchmark.py | tee bench.log /dev/tty | python plot.py to also generate
a speedup-vs-sparsity figure with one curve per (sequence_length, block_shape)
pair. The full raw table below was captured at S=118 806, H=3, D=128, BF16:
Full raw benchmark table (16 block shapes × 11 sparsities)
========================================================================================================================
DTYPE=torch.bfloat16 INPUT_LAYOUT='BNSD' SABI_SORTED=True TORCH_REFERENCE='npu_fusion_attention'
========================================================================================================================
block_shape H B s_q s_kv D frame(L,R,T,B) sparsity Outputs_equal Ref_Latency_[usec] Our_Latency_[usec]
------------------------------------------------------------------------------------------------------------------------
128x128 3 1 118806 118806 128 - 0.00 yes 157210.33 200781.80
128x128 3 1 118806 118806 128 (29,1,29,1) 0.05 N/A N/A 161036.06
128x128 3 1 118806 118806 128 (29,1,29,1) 0.10 N/A N/A 158057.97
128x128 3 1 118806 118806 128 (29,1,29,1) 0.20 N/A N/A 152368.05
128x128 3 1 118806 118806 128 (29,1,29,1) 0.30 N/A N/A 144510.33
128x128 3 1 118806 118806 128 (29,1,29,1) 0.40 N/A N/A 133058.51
128x128 3 1 118806 118806 128 (29,1,29,1) 0.50 N/A N/A 117734.07
128x128 3 1 118806 118806 128 (29,1,29,1) 0.60 N/A N/A 99123.58
128x128 3 1 118806 118806 128 (29,1,29,1) 0.70 N/A N/A 78038.51
128x128 3 1 118806 118806 128 (29,1,29,1) 0.80 N/A N/A 55330.05
128x128 3 1 118806 118806 128 (29,1,29,1) 0.90 N/A N/A 24609.61
128x256 3 1 118806 118806 128 - 0.00 yes 159207.32 199820.03
128x256 3 1 118806 118806 128 (15,1,29,1) 0.05 N/A N/A 158132.10
128x256 3 1 118806 118806 128 (15,1,29,1) 0.10 N/A N/A 151330.90
128x256 3 1 118806 118806 128 (15,1,29,1) 0.20 N/A N/A 137840.45
128x256 3 1 118806 118806 128 (15,1,29,1) 0.30 N/A N/A 123943.16
128x256 3 1 118806 118806 128 (15,1,29,1) 0.40 N/A N/A 107830.05
128x256 3 1 118806 118806 128 (15,1,29,1) 0.50 N/A N/A 91377.56
128x256 3 1 118806 118806 128 (15,1,29,1) 0.60 N/A N/A 74124.60
128x256 3 1 118806 118806 128 (15,1,29,1) 0.70 N/A N/A 56262.56
128x256 3 1 118806 118806 128 (15,1,29,1) 0.80 N/A N/A 37968.08
128x256 3 1 118806 118806 128 (15,1,29,1) 0.90 N/A N/A 19270.95
128x512 3 1 118806 118806 128 - 0.00 yes 161129.31 198749.26
128x512 3 1 118806 118806 128 (8,1,29,1) 0.05 N/A N/A 173208.98
128x512 3 1 118806 118806 128 (8,1,29,1) 0.10 N/A N/A 160075.51
128x512 3 1 118806 118806 128 (8,1,29,1) 0.20 N/A N/A 137263.96
128x512 3 1 118806 118806 128 (8,1,29,1) 0.30 N/A N/A 119119.91
128x512 3 1 118806 118806 128 (8,1,29,1) 0.40 N/A N/A 102260.27
128x512 3 1 118806 118806 128 (8,1,29,1) 0.50 N/A N/A 85403.20
128x512 3 1 118806 118806 128 (8,1,29,1) 0.60 N/A N/A 68388.53
128x512 3 1 118806 118806 128 (8,1,29,1) 0.70 N/A N/A 51450.43
128x512 3 1 118806 118806 128 (8,1,29,1) 0.80 N/A N/A 34669.06
128x512 3 1 118806 118806 128 (8,1,29,1) 0.90 N/A N/A 17906.81
128x1024 3 1 118806 118806 128 - 0.00 yes 162207.12 202698.22
128x1024 3 1 118806 118806 128 (4,1,29,1) 0.05 N/A N/A 174917.09
128x1024 3 1 118806 118806 128 (4,1,29,1) 0.10 N/A N/A 165089.04
128x1024 3 1 118806 118806 128 (4,1,29,1) 0.20 N/A N/A 145535.10
128x1024 3 1 118806 118806 128 (4,1,29,1) 0.30 N/A N/A 125322.22
128x1024 3 1 118806 118806 128 (4,1,29,1) 0.40 N/A N/A 105297.35
128x1024 3 1 118806 118806 128 (4,1,29,1) 0.50 N/A N/A 87966.00
128x1024 3 1 118806 118806 128 (4,1,29,1) 0.60 N/A N/A 69510.69
128x1024 3 1 118806 118806 128 (4,1,29,1) 0.70 N/A N/A 52431.44
128x1024 3 1 118806 118806 128 (4,1,29,1) 0.80 N/A N/A 35000.96
128x1024 3 1 118806 118806 128 (4,1,29,1) 0.90 N/A N/A 18146.70
256x128 3 1 118806 118806 128 - 0.00 yes 163067.77 200198.46
256x128 3 1 118806 118806 128 (29,1,15,1) 0.05 N/A N/A 165123.96
256x128 3 1 118806 118806 128 (29,1,15,1) 0.10 N/A N/A 162855.57
256x128 3 1 118806 118806 128 (29,1,15,1) 0.20 N/A N/A 153739.60
256x128 3 1 118806 118806 128 (29,1,15,1) 0.30 N/A N/A 145058.70
256x128 3 1 118806 118806 128 (29,1,15,1) 0.40 N/A N/A 133083.78
256x128 3 1 118806 118806 128 (29,1,15,1) 0.50 N/A N/A 117827.78
256x128 3 1 118806 118806 128 (29,1,15,1) 0.60 N/A N/A 99380.31
256x128 3 1 118806 118806 128 (29,1,15,1) 0.70 N/A N/A 78359.30
256x128 3 1 118806 118806 128 (29,1,15,1) 0.80 N/A N/A 55515.84
256x128 3 1 118806 118806 128 (29,1,15,1) 0.90 N/A N/A 24680.32
256x256 3 1 118806 118806 128 - 0.00 yes 163126.33 200421.86
256x256 3 1 118806 118806 128 (15,1,15,1) 0.05 N/A N/A 162244.29
256x256 3 1 118806 118806 128 (15,1,15,1) 0.10 N/A N/A 154796.35
256x256 3 1 118806 118806 128 (15,1,15,1) 0.20 N/A N/A 141090.41
256x256 3 1 118806 118806 128 (15,1,15,1) 0.30 N/A N/A 125764.73
256x256 3 1 118806 118806 128 (15,1,15,1) 0.40 N/A N/A 109879.25
256x256 3 1 118806 118806 128 (15,1,15,1) 0.50 N/A N/A 93020.79
256x256 3 1 118806 118806 128 (15,1,15,1) 0.60 N/A N/A 75177.50
256x256 3 1 118806 118806 128 (15,1,15,1) 0.70 N/A N/A 56512.46
256x256 3 1 118806 118806 128 (15,1,15,1) 0.80 N/A N/A 37968.77
256x256 3 1 118806 118806 128 (15,1,15,1) 0.90 N/A N/A 19722.16
256x512 3 1 118806 118806 128 - 0.00 yes 164016.82 202325.94
256x512 3 1 118806 118806 128 (8,1,15,1) 0.05 N/A N/A 175537.23
256x512 3 1 118806 118806 128 (8,1,15,1) 0.10 N/A N/A 161875.20
256x512 3 1 118806 118806 128 (8,1,15,1) 0.20 N/A N/A 144422.60
256x512 3 1 118806 118806 128 (8,1,15,1) 0.30 N/A N/A 123083.90
256x512 3 1 118806 118806 128 (8,1,15,1) 0.40 N/A N/A 105284.77
256x512 3 1 118806 118806 128 (8,1,15,1) 0.50 N/A N/A 88914.20
256x512 3 1 118806 118806 128 (8,1,15,1) 0.60 N/A N/A 71715.88
256x512 3 1 118806 118806 128 (8,1,15,1) 0.70 N/A N/A 53809.10
256x512 3 1 118806 118806 128 (8,1,15,1) 0.80 N/A N/A 34646.38
256x512 3 1 118806 118806 128 (8,1,15,1) 0.90 N/A N/A 17254.83
256x1024 3 1 118806 118806 128 - 0.00 yes 164625.84 203211.69
256x1024 3 1 118806 118806 128 (4,1,15,1) 0.05 N/A N/A 176233.89
256x1024 3 1 118806 118806 128 (4,1,15,1) 0.10 N/A N/A 166313.94
256x1024 3 1 118806 118806 128 (4,1,15,1) 0.20 N/A N/A 148188.48
256x1024 3 1 118806 118806 128 (4,1,15,1) 0.30 N/A N/A 128503.65
256x1024 3 1 118806 118806 128 (4,1,15,1) 0.40 N/A N/A 109221.57
256x1024 3 1 118806 118806 128 (4,1,15,1) 0.50 N/A N/A 91500.32
256x1024 3 1 118806 118806 128 (4,1,15,1) 0.60 N/A N/A 73028.33
256x1024 3 1 118806 118806 128 (4,1,15,1) 0.70 N/A N/A 54234.94
256x1024 3 1 118806 118806 128 (4,1,15,1) 0.80 N/A N/A 35525.95
256x1024 3 1 118806 118806 128 (4,1,15,1) 0.90 N/A N/A 16849.22
512x128 3 1 118806 118806 128 - 0.00 yes 165319.87 201165.34
512x128 3 1 118806 118806 128 (29,1,8,1) 0.05 N/A N/A 167500.90
512x128 3 1 118806 118806 128 (29,1,8,1) 0.10 N/A N/A 163980.60
512x128 3 1 118806 118806 128 (29,1,8,1) 0.20 N/A N/A 155223.03
512x128 3 1 118806 118806 128 (29,1,8,1) 0.30 N/A N/A 145277.70
512x128 3 1 118806 118806 128 (29,1,8,1) 0.40 N/A N/A 133691.13
512x128 3 1 118806 118806 128 (29,1,8,1) 0.50 N/A N/A 118228.98
512x128 3 1 118806 118806 128 (29,1,8,1) 0.60 N/A N/A 99471.91
512x128 3 1 118806 118806 128 (29,1,8,1) 0.70 N/A N/A 78237.73
512x128 3 1 118806 118806 128 (29,1,8,1) 0.80 N/A N/A 55266.32
512x128 3 1 118806 118806 128 (29,1,8,1) 0.90 N/A N/A 24597.66
512x256 3 1 118806 118806 128 - 0.00 yes 165073.32 201308.03
512x256 3 1 118806 118806 128 (15,1,8,1) 0.05 N/A N/A 165942.18
512x256 3 1 118806 118806 128 (15,1,8,1) 0.10 N/A N/A 157005.85
512x256 3 1 118806 118806 128 (15,1,8,1) 0.20 N/A N/A 142381.26
512x256 3 1 118806 118806 128 (15,1,8,1) 0.30 N/A N/A 126352.48
512x256 3 1 118806 118806 128 (15,1,8,1) 0.40 N/A N/A 110357.71
512x256 3 1 118806 118806 128 (15,1,8,1) 0.50 N/A N/A 93363.94
512x256 3 1 118806 118806 128 (15,1,8,1) 0.60 N/A N/A 75421.60
512x256 3 1 118806 118806 128 (15,1,8,1) 0.70 N/A N/A 57303.78
512x256 3 1 118806 118806 128 (15,1,8,1) 0.80 N/A N/A 39200.84
512x256 3 1 118806 118806 128 (15,1,8,1) 0.90 N/A N/A 19915.82
512x512 3 1 118806 118806 128 - 0.00 yes 165493.05 202528.27
512x512 3 1 118806 118806 128 (8,1,8,1) 0.05 N/A N/A 178655.80
512x512 3 1 118806 118806 128 (8,1,8,1) 0.10 N/A N/A 167690.99
512x512 3 1 118806 118806 128 (8,1,8,1) 0.20 N/A N/A 148863.20
512x512 3 1 118806 118806 128 (8,1,8,1) 0.30 N/A N/A 131446.69
512x512 3 1 118806 118806 128 (8,1,8,1) 0.40 N/A N/A 114359.11
512x512 3 1 118806 118806 128 (8,1,8,1) 0.50 N/A N/A 95379.45
512x512 3 1 118806 118806 128 (8,1,8,1) 0.60 N/A N/A 75719.90
512x512 3 1 118806 118806 128 (8,1,8,1) 0.70 N/A N/A 55887.92
512x512 3 1 118806 118806 128 (8,1,8,1) 0.80 N/A N/A 36444.14
512x512 3 1 118806 118806 128 (8,1,8,1) 0.90 N/A N/A 17844.47
512x1024 3 1 118806 118806 128 - 0.00 yes 165981.21 202457.74
512x1024 3 1 118806 118806 128 (4,1,8,1) 0.05 N/A N/A 179652.09
512x1024 3 1 118806 118806 128 (4,1,8,1) 0.10 N/A N/A 169494.82
512x1024 3 1 118806 118806 128 (4,1,8,1) 0.20 N/A N/A 150344.48
512x1024 3 1 118806 118806 128 (4,1,8,1) 0.30 N/A N/A 133085.55
512x1024 3 1 118806 118806 128 (4,1,8,1) 0.40 N/A N/A 115268.55
512x1024 3 1 118806 118806 128 (4,1,8,1) 0.50 N/A N/A 95859.71
512x1024 3 1 118806 118806 128 (4,1,8,1) 0.60 N/A N/A 74771.77
512x1024 3 1 118806 118806 128 (4,1,8,1) 0.70 N/A N/A 55201.18
512x1024 3 1 118806 118806 128 (4,1,8,1) 0.80 N/A N/A 36086.59
512x1024 3 1 118806 118806 128 (4,1,8,1) 0.90 N/A N/A 18220.92
1024x128 3 1 118806 118806 128 - 0.00 yes 166277.64 199614.44
1024x128 3 1 118806 118806 128 (29,1,4,1) 0.05 N/A N/A 168760.72
1024x128 3 1 118806 118806 128 (29,1,4,1) 0.10 N/A N/A 164258.64
1024x128 3 1 118806 118806 128 (29,1,4,1) 0.20 N/A N/A 155011.16
1024x128 3 1 118806 118806 128 (29,1,4,1) 0.30 N/A N/A 145592.77
1024x128 3 1 118806 118806 128 (29,1,4,1) 0.40 N/A N/A 133456.46
1024x128 3 1 118806 118806 128 (29,1,4,1) 0.50 N/A N/A 117487.83
1024x128 3 1 118806 118806 128 (29,1,4,1) 0.60 N/A N/A 98545.56
1024x128 3 1 118806 118806 128 (29,1,4,1) 0.70 N/A N/A 77536.25
1024x128 3 1 118806 118806 128 (29,1,4,1) 0.80 N/A N/A 54579.08
1024x128 3 1 118806 118806 128 (29,1,4,1) 0.90 N/A N/A 23215.96
1024x256 3 1 118806 118806 128 - 0.00 yes 165954.82 201113.94
1024x256 3 1 118806 118806 128 (15,1,4,1) 0.05 N/A N/A 169928.98
1024x256 3 1 118806 118806 128 (15,1,4,1) 0.10 N/A N/A 159145.01
1024x256 3 1 118806 118806 128 (15,1,4,1) 0.20 N/A N/A 142113.40
1024x256 3 1 118806 118806 128 (15,1,4,1) 0.30 N/A N/A 127223.95
1024x256 3 1 118806 118806 128 (15,1,4,1) 0.40 N/A N/A 109795.79
1024x256 3 1 118806 118806 128 (15,1,4,1) 0.50 N/A N/A 92943.02
1024x256 3 1 118806 118806 128 (15,1,4,1) 0.60 N/A N/A 75015.89
1024x256 3 1 118806 118806 128 (15,1,4,1) 0.70 N/A N/A 56391.08
1024x256 3 1 118806 118806 128 (15,1,4,1) 0.80 N/A N/A 38051.85
1024x256 3 1 118806 118806 128 (15,1,4,1) 0.90 N/A N/A 18906.36
1024x512 3 1 118806 118806 128 - 0.00 yes 166466.77 202403.93
1024x512 3 1 118806 118806 128 (8,1,4,1) 0.05 N/A N/A 184474.29
1024x512 3 1 118806 118806 128 (8,1,4,1) 0.10 N/A N/A 175285.35
1024x512 3 1 118806 118806 128 (8,1,4,1) 0.20 N/A N/A 156707.23
1024x512 3 1 118806 118806 128 (8,1,4,1) 0.30 N/A N/A 136099.63
1024x512 3 1 118806 118806 128 (8,1,4,1) 0.40 N/A N/A 116428.54
1024x512 3 1 118806 118806 128 (8,1,4,1) 0.50 N/A N/A 96237.19
1024x512 3 1 118806 118806 128 (8,1,4,1) 0.60 N/A N/A 75197.93
1024x512 3 1 118806 118806 128 (8,1,4,1) 0.70 N/A N/A 55883.06
1024x512 3 1 118806 118806 128 (8,1,4,1) 0.80 N/A N/A 36185.81
1024x512 3 1 118806 118806 128 (8,1,4,1) 0.90 N/A N/A 17193.75
1024x1024 3 1 118806 118806 128 - 0.00 yes 166565.60 203121.20
1024x1024 3 1 118806 118806 128 (4,1,4,1) 0.05 N/A N/A 183960.52
1024x1024 3 1 118806 118806 128 (4,1,4,1) 0.10 N/A N/A 174192.35
1024x1024 3 1 118806 118806 128 (4,1,4,1) 0.20 N/A N/A 156145.59
1024x1024 3 1 118806 118806 128 (4,1,4,1) 0.30 N/A N/A 136221.70
1024x1024 3 1 118806 118806 128 (4,1,4,1) 0.40 N/A N/A 115564.39
1024x1024 3 1 118806 118806 128 (4,1,4,1) 0.50 N/A N/A 94306.43
1024x1024 3 1 118806 118806 128 (4,1,4,1) 0.60 N/A N/A 74998.33
1024x1024 3 1 118806 118806 128 (4,1,4,1) 0.70 N/A N/A 55769.76
1024x1024 3 1 118806 118806 128 (4,1,4,1) 0.80 N/A N/A 36448.80
1024x1024 3 1 118806 118806 128 (4,1,4,1) 0.90 N/A N/A 17044.48
========================================================================================================================
The table lists every sparsity in SPARSITY_VALS for every pair in BLOCK_SHAPES.
The frame(L,R,T,B) column is - at sparsity 0 because the mask is dense, so
the frame has no effect on the kept blocks. Outputs_equal and Ref_Latency_[usec]
are only reported for the dense (sparsity 0) rows; the sparse rows show N/A.
Speedup vs npu_fusion_attention (dense) — full 4 × 4 grid
Each cell is Ref_Latency(sparsity=0) / Our_Latency(sparsity=p). Values
< 1.0 mean BSA is slower than PFA dense; > 1.0 mean BSA is faster.
| Q \ KV | sparsity | 128 | 256 | 512 | 1024 |
|---|---|---|---|---|---|
| 128 | 0.3 | 1.09× | 1.28× | 1.35× | 1.29× |
| | 0.5 | 1.34× | 1.74× | 1.89× | 1.84× |
| | 0.8 | 2.84× | 4.19× | 4.65× | 4.63× |
| 256 | 0.3 | 1.12× | 1.30× | 1.33× | 1.28× |
| | 0.5 | 1.38× | 1.75× | 1.84× | 1.80× |
| | 0.8 | 2.94× | 4.30× | 4.73× | 4.63× |
| 512 | 0.3 | 1.14× | 1.31× | 1.26× | 1.25× |
| | 0.5 | 1.40× | 1.77× | 1.74× | 1.73× |
| | 0.8 | 2.99× | 4.21× | 4.54× | 4.60× |
| 1024 | 0.3 | 1.14× | 1.30× | 1.22× | 1.22× |
| | 0.5 | 1.42× | 1.79× | 1.73× | 1.77× |
| | 0.8 | 3.05× | 4.36× | 4.60× | 4.57× |
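Each cell can be recomputed from the raw table. The sketch below hard-codes a few latencies copied from it, and its dictionary layout is illustrative only. The speedup divides the dense reference latency of a block shape by our latency at sparsity p for the same shape.

```python
# Latencies in usec, copied from the raw table above.
REF_DENSE = {(128, 128): 157210.33, (512, 512): 165493.05}  # Ref_Latency at sparsity 0
OURS = {                                                     # Our_Latency at sparsity p
    ((128, 128), 0.3): 144510.33,
    ((128, 128), 0.8): 55330.05,
    ((512, 512), 0.8): 36444.14,
}


def speedup(shape, sparsity):
    """Ref_Latency(sparsity=0) / Our_Latency(sparsity=p) for one block shape."""
    return REF_DENSE[shape] / OURS[(shape, sparsity)]


for shape, p in OURS:
    print(f"{shape[0]}x{shape[1]} @ sparsity {p}: {speedup(shape, p):.2f}x")
```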
- Speedup grows monotonically with sparsity in every cell. Even at sparsity 0 BSA never drops below ~0.78×, and it reaches 6.4×–9.8× at sparsity 0.9 across the grid.
- BLOCK_SIZE_KV is the dominant lever. At sparsity 0.9, moving KV from 128 → 1024 buys 1.4×–1.5× on top of the existing speedup at every Q. The same trend holds for sparsity 0.8 (≈ 1.5×–1.6× from KV=128 to KV=1024). This is because a wider sabi entry means less sabi-tensor HBM traffic per kept block. Note however that KV=1024 is not always the strict winner at sparsity 0.9: for Q=128 and Q=512, KV=512 narrowly beats KV=1024 (9.00× vs 8.94×, 9.27× vs 9.11×) — at the coarsest KV the forced-frame footprint starts to eat into the kept-block budget.
- BLOCK_SIZE_Q has a much smaller effect. Re-parameterising the Q tile does not change the cube tile width (basicSInnerSize = 512 for D=128 BF16) or the per-Q-row workload. Its only effect is collapsing several 128-token Q groups onto a single sabi row, which is mostly free. At sparsity 0.9 the spread between Q=128 and Q=1024 (same KV) is ≤ 12 %.
- The dense-mode regression (sparsity 0) is essentially flat across KV. The overhead vs PFA sits at ≈ 17–22 % across the whole grid (0.78×–0.83×) with no strong KV trend; the wide-KV packer no longer buys additional bandwidth at sparsity 0 in this configuration.
- 128×128 remains the most expensive baseline: it has the finest sabi resolution but the most metadata per kept block.
- The fastest cells at high sparsity are KV=1024 at Q ∈ {256, 1024}: both reach 9.77× over PFA dense at sparsity 0.9. KV=512 also lands within ~8 % of the top (9.00×–9.68×). Picking a granularity is therefore an application-level trade-off between pattern fidelity and end-to-end latency; the kernel does not force one.
The PFA reference (npu_fusion_attention dense) is shape-invariant — the
small differences in the Ref_Latency_[usec] column across block shapes are
timing noise from re-running the same workload 16 times.
3 - softmax_lse correctness tests (test_lse.py)
test_lse.py verifies the softmax_lse (log-sum-exp) second output of the kernel.
pytest test_lse.py -v
Every test case is exercised at all 16 default block shapes
({128, 256, 512, 1024} × {128, 256, 512, 1024}) via a parametrized
block_shape fixture, so the live test count is ~16× the per-shape count
below (~1600 instances total). A handful of cases at small S or high sparsity
are auto-skipped because the test parameterisation becomes inherently
unsatisfiable (forced-frame density exceeds 1 − sparsity, or the row's
expected kept-KV-block count drops below 1) — see the pytest.skip rules in
the file.
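The two skip conditions can be sketched as a pure function. How test_lse.py actually measures forced-frame density is not stated here. This sketch assumes density is the union of fully forced block rows and block columns over the block grid, and `skip_reason` is a hypothetical name.

```python
import math


def skip_reason(s, bsq, bskv, sparsity, frame):
    """Hypothetical version of the two skip rules described above.

    frame is (left_cols, right_cols, top_rows, bottom_rows) in block units.
    Returns a reason string if the case is unsatisfiable, else None.
    """
    n_q, n_kv = math.ceil(s / bsq), math.ceil(s / bskv)
    left, right, top, bottom = frame
    rows = min(top + bottom, n_q)    # block rows forced fully active
    cols = min(left + right, n_kv)   # block columns forced fully active
    forced = rows * n_kv + cols * n_q - rows * cols  # union, overlap counted once
    density = forced / (n_q * n_kv)
    if density > 1.0 - sparsity:
        return f"frame density {density:.2f} exceeds 1 - sparsity"
    if (1.0 - sparsity) * n_kv < 1.0:
        return "expected kept KV blocks per row < 1"
    return None


print(skip_reason(1024, 128, 128, 0.5, (2, 1, 3, 1)))
```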
What is tested per block shape (B=1 only)
| Test | Reference | Shapes | Dtypes | Notes |
|---|---|---|---|---|
| test_bsa_lse_vs_fias | FIAS dense | 14 shapes, S up to 24 000 | bfloat16, float16 | Dense mode (no sabi) |
| test_bsa_lse_vs_fias_sparse | FIAS sparse_mode=1 | 8 shapes, S up to 24 000 | bfloat16 | H=1, no border frame |
| test_bsa_lse_vs_fias_sparse_framed | FIAS sparse_mode=1 | 2 shapes, S ≥ 16 384 | bfloat16 | H=1, FrameWidths(2,1,3,1) |
| test_bsa_lse_sparse_multihead | Python float32 reference | 4 shapes, S ≤ 1 024 | bfloat16 | H=2,4; Python loops over sabi |
| test_bsa_lse_zero_input | Analytical: log(S) | 5 shapes | bfloat16, float16 | Q=K=V=0 ⟹ uniform softmax |
| test_bsa_lse_shape_and_dtype | n/a | 3 shapes | bfloat16 | Checks shape=[B,H,S], dtype=float32 |
| test_bsa_lse_disabled_returns_empty | n/a | 3 shapes | bfloat16 | flag=False ⟹ numel=0 |
Why FIAS is the LSE reference
FIAS (npu_fused_infer_attention_score) already exposes a softmax_lse output via its softmaxLseFlag bool attr (output index 1, shape [B, N, S, 1], float32). BSA computes the identical quantity — log(Σ exp(q·kᵀ/√d)) over attended tokens — so FIAS with a matching token-level attention mask serves as a drop-in reference for H=1 shapes.
For H>1 shapes FIAS cannot be used directly (its sparse_mode=1 requires a [B,1,S,S] broadcast mask; passing [B,H,S,S] hangs the NPU), so a Python float32 reference that loops over the sabi indices is used instead.
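The sketch below shows the quantity being checked: log(Σ exp(q·kᵀ/√d)) over attended tokens for one head, with an outer loop over per-Q-block lists of kept KV blocks. Those lists are a hypothetical stand-in for the sabi indices, whose real layout is not assumed here. The test file's reference works in float32, whereas plain Python floats are float64. With Q = K = 0 every logit is 0, so the result is log(number of attended tokens), which is log(S) in the dense case.

```python
import math


def lse_reference(q, k, block_q, block_kv, kept_kv_blocks):
    """Per-query-row log-sum-exp over attended keys, for a single head.

    q, k: lists of rows, each of length d.
    kept_kv_blocks[i]: KV-block indices attended by Q-block i (stand-in for sabi).
    """
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    s_kv = len(k)
    out = []
    for qi, row in enumerate(q):
        logits = []
        for kb in kept_kv_blocks[qi // block_q]:
            for kj in range(kb * block_kv, min((kb + 1) * block_kv, s_kv)):
                logits.append(scale * sum(a * b for a, b in zip(row, k[kj])))
        m = max(logits)  # subtract the max before exp for numerical stability
        out.append(m + math.log(sum(math.exp(x - m) for x in logits)))
    return out


zeros = [[0.0] * 4 for _ in range(8)]
print(lse_reference(zeros, zeros, 4, 4, [[0, 1], [0, 1]])[0], math.log(8))
```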
Tolerances
ATOL = 0.001 RTOL = 0.01
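The sketch below assumes the two constants combine the way torch.allclose does, per element |actual − expected| ≤ ATOL + RTOL · |expected|. Whether test_lse.py calls torch.allclose, torch.testing.assert_close or its own check is an assumption, and `within_tol` is a hypothetical helper. For an LSE value of 10.0 the allowed error is 0.001 + 0.01 · 10 = 0.101.

```python
ATOL = 0.001
RTOL = 0.01


def within_tol(actual, expected, atol=ATOL, rtol=RTOL):
    """Elementwise |actual - expected| <= atol + rtol * |expected| (allclose-style)."""
    if len(actual) != len(expected):
        raise ValueError("length mismatch")
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


print(within_tol([10.1], [10.0]), within_tol([10.2], [10.0]))
```

The relative term dominates for large LSE values, and ATOL only matters near zero.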
float16 is excluded from sparse tests: FIAS processes -inf-masked tiles through its flash-attention tiling while BSA skips them entirely; the resulting LSE divergence exceeds any reasonable tolerance at high sparsity.
Test setup
Host CPU: aarch64
Device: Ascend 910B2
Device Driver: 25.3.rc1
Docker image: `docker pull --platform=arm64 swr.cn-south-1.myhuaweicloud.com/ascendhub/cann:8.5.0-910b-ubuntu22.04-py3.10-ops`
OS: Ubuntu 22.04
CANN: 8.5.0-beta.1
Python: 3.11.10
torch: 2.8.0+cpu
torch_npu: 2.8.0