"""
Sparse Flash Attention Quantization Module
This module implements sparse flash attention with quantization support for DeepSeek V32.
It performs attention computation on top-k selected key-value pairs from cache,
supporting both standard and flash attention algorithms.
Main Functions:
- sparse_flash_attention_quant_compute: Standard sparse attention computation
- sparse_flash_attention_quant_compute_flash: Flash attention variant with online softmax
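- sparse_flash_attention_quant_d: JIT-compiled decode version
- sparse_flash_attention_quant_d_950: JIT-compiled decode version with an alternate set of pass options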
- sparse_flash_attention_quant_p: JIT-compiled prefill version
Example:
See deepseekv32_sparse_flash_attention_quant.py for usage examples.
"""
import os
import math
from dataclasses import dataclass
import numpy as np
import pypto
from pypto.experimental import gather_in_l1, gather_in_ub
@dataclass
class SaTileShapeConfig:
    """Tiling parameters read by the sparse attention compute functions in this module."""

    # Number of query heads processed per iteration of the innermost head-group loop.
    g_tile: int
    # Number of selected key/value rows processed per iteration of the s2 loop.
    s_kv_tile: int
    # Two values passed to pypto.set_vec_tile_shapes around the gather steps
    # (read only by the standard, non-flash compute function).
    gather_vec_tile_shape: list
    # Six values, passed pairwise to pypto.set_cube_tile_shapes for the Q x K^T matmul.
    c1_tile_shape: list
    # Two values passed to pypto.set_vec_tile_shapes for the softmax stage.
    v1_tile_shape: list
    # Six values, passed pairwise to pypto.set_cube_tile_shapes for the P x V matmul.
    c2_tile_shape: list
    # Two values passed to pypto.set_vec_tile_shapes for the running-statistics update
    # (read only by the flash compute function).
    v2_tile_shape: list
def sparse_flash_attention_quant_compute(query_nope, query_rope, key_nope_2d, key_rope_2d,
                                         k_nope_scales, topk_indices, block_table, kv_act_seqs,
                                         attention_out, nq, n_kv, softmax_scale, topk,
                                         block_size, max_blocknum_perbatch, tile_config):
    """Compute sparse attention over top-k selected cache rows with a per-tile softmax.

    For every (batch, query position, kv head, head-group tile) the function gathers the
    selected key rows from the paged cache, concatenates the non-RoPE and RoPE parts,
    computes ``softmax(softmax_scale * Q @ K^T)`` and multiplies the probabilities by the
    non-RoPE key rows, which are reused as the value matrix. Results are accumulated in a
    2-D scratch tensor that is reshaped into ``attention_out`` at the end.

    Supports both quantized (INT8) and non-quantized ``key_nope_2d``; the INT8 path
    dequantizes the gathered rows with ``k_nope_scales`` before the first matmul.

    Args:
        query_nope: Query tensor without RoPE, shape (t * n_q, kv_lora_rank), dtype BF16
        query_rope: Query tensor with RoPE, shape (t * n_q, rope_dim), dtype BF16
        key_nope_2d: Key tensor without RoPE, shape (block_num * block_size, kv_lora_rank),
            dtype BF16 or INT8
        key_rope_2d: Key tensor with RoPE, shape (block_num * block_size, rope_dim), dtype BF16
        k_nope_scales: Dequantization scales for quantized keys, shape (block_num * block_size, 4),
            dtype FP32. Only used when key_nope_2d is INT8.
        topk_indices: Top-k indices for each query token, shape (t, n_kv * topk), dtype INT32
        block_table: Block mapping table for PagedAttention, shape (b, max_blocknum_perbatch),
            dtype INT32
        kv_act_seqs: Actual sequence lengths for each batch, shape (b,), dtype INT32
        attention_out: Output attention tensor, shape (b, s, n_q, kv_lora_rank), dtype BF16.
            Written in place.
        nq: Number of query heads
        n_kv: Number of key-value heads
        softmax_scale: Scaling factor for attention scores, typically 1/sqrt(head_dim)
        topk: Number of top-k keys to attend to
        block_size: Size of each block in PagedAttention
        max_blocknum_perbatch: Maximum number of blocks per batch
        tile_config: SaTileShapeConfig object containing tiling parameters:
            - g_tile: Group tile size
            - s_kv_tile: Key-value sequence tile size
            - gather_vec_tile_shape: Vector tile shape used around the gather steps
            - c1_tile_shape: Cube tile shape for first matmul
            - v1_tile_shape: Vector tile shape for softmax
            - c2_tile_shape: Cube tile shape for second matmul

    Returns:
        None. The result is stored into ``attention_out``.

    Note:
        The function uses nested loops to process batches, sequences, heads, and groups.
        The number of head-group iterations is ``(nq // n_kv) // g_tile`` (integer
        division), so heads beyond the last full tile are not visited.
        The softmax is normalized independently inside each s_kv tile.
        NOTE(review): every s_kv tile assembles its result to the same output rows, so
        this path looks like it expects the selected keys of one query to fit in a
        single s_kv tile - confirm against the callers' tile configuration.
    """
    dtype = query_nope.dtype
    kn_dtype = key_nope_2d.dtype
    dn = query_nope.shape[1]
    dr = query_rope.shape[1]
    group = nq // n_kv
    gather_vec_tile = tile_config.gather_vec_tile_shape
    group_tile = tile_config.g_tile
    s2_tile = tile_config.s_kv_tile
    c1_tile = tile_config.c1_tile_shape
    v1_tile = tile_config.v1_tile_shape
    c2_tile = tile_config.c2_tile_shape
    n_kv_sym = n_kv
    batch_size_sym = kv_act_seqs.shape[0]
    # Query rows per batch entry (query positions * heads), then query positions per batch.
    s1_n2_gsym = query_nope.shape[0] // batch_size_sym
    s1_sym = s1_n2_gsym // nq
    g_loop_sym = group // group_tile
    # 2-D scratch output with one row per query row; reshaped into attention_out at the end.
    atten_out_2dim = pypto.tensor([batch_size_sym * s1_n2_gsym, dn], dtype, "attenOut2Dim")
    for batch_idx in pypto.loop(0, batch_size_sym, 1, name="LOOP_L0_idx", idx_name="bIdx", parallel=True):
        cur_act_seq = kv_act_seqs[batch_idx]
        for slc_idx in pypto.loop(0, s1_sym, 1, name="LOOP_L1_s1_SA", idx_name="s1Idx"):
            # Number of selected keys that are valid for this query position:
            # (act_seq - s1 + 1 + position), clamped to the range [0, topk].
            cur_seq = (cur_act_seq - s1_sym + 1 + slc_idx).max(0).min(topk)
            cur_seq.as_variable()
            # Ceil-divide the valid key count into s_kv tiles.
            bn_per_batch = (cur_seq + s2_tile - 1) // s2_tile
            for n_kv_idx in pypto.loop(0, n_kv_sym, 1, name="LOOP_L2_n_kv_SA", idx_name="n_kvIdx"):
                for group_idx in pypto.loop(0, g_loop_sym, 1, name="LOOP_L3_g_SA", idx_name="gIdx"):
                    cur_group_tile = group_tile
                    # First query row of this head-group tile in the flattened query tensors.
                    cur_offset = batch_idx * s1_n2_gsym + slc_idx * nq + n_kv_idx * group + group_idx * cur_group_tile
                    for s2_idx, _ in pypto.loop_unroll(0, bn_per_batch, 1,
                                                       name="LOOP_L4_s2_SA", idx_name="s2_idx", unroll_list={1}):
                        cur_s2_tile = s2_tile
                        # Slice of the index row for this query position; the valid width
                        # is the number of keys remaining in the (possibly partial) tile.
                        cur_topk_indices = pypto.view(topk_indices, [1, cur_s2_tile],
                                                      [batch_idx * s1_sym + slc_idx, s2_idx * cur_s2_tile],
                                                      valid_shape=[1, (cur_seq - s2_idx * cur_s2_tile).min(cur_s2_tile)])
                        cur_block_table = pypto.view(block_table, [1, max_blocknum_perbatch], [batch_idx, 0])
                        kn = pypto.tensor([s2_tile, dn], dtype, "kn")
                        if kn_dtype == pypto.DT_INT8:
                            # Quantized path: gather scales and INT8 rows, dequantize, then Q x K^T.
                            pypto.set_semantic_label("Sa_V0")
                            pypto.set_vec_tile_shapes(16, 1024)
                            # Scales are viewed 8 columns wide with only 4 columns valid.
                            k_nope_scale_view = pypto.view(k_nope_scales, [k_nope_scales.shape[0], 8],
                                                           [0, 0], valid_shape=[k_nope_scales.shape[0], 4])
                            # presumably gather_in_ub maps the selected indices through the
                            # block table to cache rows - verify against the helper's docs
                            kn_scale = gather_in_ub(k_nope_scale_view, cur_topk_indices, cur_block_table,
                                                    block_size, -2)
                            k_nope_2d_view = pypto.view(key_nope_2d, [key_nope_2d.shape[0], dn],
                                                        [0, 0], valid_shape=[key_nope_2d.shape[0], dn])
                            kn_quant = gather_in_ub(k_nope_2d_view, cur_topk_indices, cur_block_table, block_size, -2)
                            # INT8 -> FP16 -> FP32 in two casts.
                            kn_quant_fp16 = pypto.cast(kn_quant, pypto.DT_FP16)
                            kn_quant_fp32 = pypto.cast(kn_quant_fp16, pypto.DT_FP32)
                            # Duplicate the row along the last axis so it splits into 8 groups
                            # of 128 values, matching the 8-column scale view. The reshapes
                            # below imply dn * 2 == 8 * 128.
                            kn_quant_fp32 = pypto.concat([kn_quant_fp32, kn_quant_fp32], -1)
                            kn_quant_fp32_tmp = pypto.reshape(kn_quant_fp32, [s2_tile * 8, 128])
                            kn_scale_tmp = pypto.reshape(kn_scale, [s2_tile * 8, 1])
                            pypto.set_vec_tile_shapes(128, 128)
                            # One scale per 128-value group.
                            kn_fp32 = pypto.mul(kn_quant_fp32_tmp, kn_scale_tmp)
                            kn_fp32_reshape = pypto.reshape(kn_fp32, [s2_tile, dn * 2])
                            pypto.set_vec_tile_shapes(16, 512)
                            # Keep only the first dn columns (drops the duplicated half).
                            cur_kn_fp32 = pypto.view(kn_fp32_reshape, [cur_s2_tile, dn], [0, 0],
                                                     valid_shape=[(cur_seq - s2_idx * cur_s2_tile).min(cur_s2_tile), dn])
                            kn = pypto.cast(cur_kn_fp32, dtype)
                            pypto.set_semantic_label("Sa_C1")
                            pypto.set_vec_tile_shapes(gather_vec_tile[0], gather_vec_tile[1])
                            pypto.set_cube_tile_shapes([c1_tile[0],
                                                        c1_tile[1]], [c1_tile[2], c1_tile[3]], [c1_tile[4], c1_tile[5]])
                            kr = gather_in_l1(key_rope_2d, cur_topk_indices, cur_block_table, block_size, dr,
                                              is_b_matrix=True, is_trans=True)
                            # Key tile = [non-RoPE | RoPE] columns.
                            kj = pypto.tensor([cur_s2_tile, dn + dr], dtype, "kj")
                            pypto.assemble(kn, [0, 0], kj)
                            pypto.assemble(kr, [0, dn], kj)
                            kj_view = pypto.view(kj, [cur_s2_tile, dn + dr], [0, 0],
                                                 valid_shape=[(cur_seq - s2_idx * cur_s2_tile).min(cur_s2_tile), dn + dr])
                            qn = pypto.view(query_nope, [cur_group_tile, dn], [cur_offset, 0],
                                            valid_shape=[cur_group_tile, dn])
                            qr = pypto.view(query_rope, [cur_group_tile, dr], [cur_offset, 0],
                                            valid_shape=[cur_group_tile, dr])
                            # Query tile = [non-RoPE | RoPE] columns, same layout as the key tile.
                            qi = pypto.tensor([cur_group_tile, dn + dr], dtype, "qi")
                            pypto.assemble(qn, [0, 0], qi)
                            pypto.assemble(qr, [0, dn], qi)
                            # Scores in FP32: Q x K^T.
                            sij = pypto.matmul(qi, kj_view, pypto.DT_FP32, a_trans=False, b_trans=True)
                        else:
                            # Non-quantized path: gather both key parts directly in the key dtype.
                            # NOTE(review): the sg_set_scope pass option is toggled only on
                            # 'DAV_3510' and restored to -1 after the matmul - meaning of the
                            # value 20001 is not visible here, confirm with the pass docs.
                            if pypto.platform.npuarch == 'DAV_3510':
                                pypto.set_pass_options(sg_set_scope=20001)
                            pypto.set_semantic_label("Sa_V0")
                            pypto.set_vec_tile_shapes(gather_vec_tile[0], gather_vec_tile[1])
                            k_nope_2d_view = pypto.view(key_nope_2d, [key_nope_2d.shape[0], dn],
                                                        [0, 0], valid_shape=[key_nope_2d.shape[0], dn])
                            kn = gather_in_ub(k_nope_2d_view, cur_topk_indices, cur_block_table, block_size, -2)
                            pypto.set_semantic_label("Sa_C1")
                            pypto.set_vec_tile_shapes(gather_vec_tile[0], gather_vec_tile[1])
                            pypto.set_cube_tile_shapes([c1_tile[0],
                                                        c1_tile[1]], [c1_tile[2], c1_tile[3]], [c1_tile[4], c1_tile[5]])
                            key_rope_2d_view = pypto.view(key_rope_2d, [key_rope_2d.shape[0], dr],
                                                          [0, 0], valid_shape=[key_rope_2d.shape[0], dr])
                            kr = gather_in_ub(key_rope_2d_view, cur_topk_indices, cur_block_table, block_size, -2)
                            # Key tile = [non-RoPE | RoPE] columns.
                            kj = pypto.tensor([cur_s2_tile, dn + dr], dtype, "kj")
                            pypto.assemble(kn, [0, 0], kj)
                            pypto.assemble(kr, [0, dn], kj)
                            kj_view = pypto.view(kj, [cur_s2_tile, dn + dr], [0, 0],
                                                 valid_shape=[(cur_seq - s2_idx * cur_s2_tile).min(cur_s2_tile), dn + dr])
                            qn = pypto.view(query_nope, [cur_group_tile, dn], [cur_offset, 0],
                                            valid_shape=[cur_group_tile, dn])
                            qr = pypto.view(query_rope, [cur_group_tile, dr], [cur_offset, 0],
                                            valid_shape=[cur_group_tile, dr])
                            # Query tile = [non-RoPE | RoPE] columns, same layout as the key tile.
                            qi = pypto.tensor([cur_group_tile, dn + dr], dtype, "qi")
                            pypto.assemble(qn, [0, 0], qi)
                            pypto.assemble(qr, [0, dn], qi)
                            # Scores in FP32: Q x K^T.
                            sij = pypto.matmul(qi, kj_view, pypto.DT_FP32, a_trans=False, b_trans=True)
                            if pypto.platform.npuarch == 'DAV_3510':
                                pypto.set_pass_options(sg_set_scope=-1)
                        # Softmax over the keys of this tile, with max-subtraction before exp.
                        pypto.set_semantic_label("Sa_V1")
                        pypto.set_vec_tile_shapes(v1_tile[0], v1_tile[1])
                        sij_scale = pypto.mul(sij, softmax_scale)
                        tilda_mij_reduce = pypto.amax(sij_scale, dim=-1, keepdim=True)
                        t_sub = pypto.sub(sij_scale, tilda_mij_reduce)
                        tilda_pij = pypto.exp(t_sub)
                        tilda_lij_reduce = pypto.sum(tilda_pij, dim=-1, keepdim=True)
                        t_softmax = pypto.div(tilda_pij, tilda_lij_reduce, pypto.PrecisionType.INTRINSIC)
                        # Cast the probabilities back to the query dtype for the second matmul.
                        tilda_pij_f16 = pypto.cast(t_softmax, dtype)
                        pypto.set_semantic_label("Sa_C2")
                        pypto.set_cube_tile_shapes([c2_tile[0],
                                                    c2_tile[1]], [c2_tile[2], c2_tile[3]], [c2_tile[4], c2_tile[5]])
                        pypto.set_matrix_size([tilda_pij_f16.shape[0],
                                               tilda_pij_f16.shape[1], kn.shape[1]])
                        q1 = pypto.tensor([cur_group_tile, dn], dtype)
                        # The non-RoPE key rows double as the value matrix.
                        vj = pypto.view(kn, [cur_s2_tile, dn], [0, 0],
                                        valid_shape=[(cur_seq - s2_idx * cur_s2_tile).min(cur_s2_tile), dn])
                        q1 = pypto.matmul(tilda_pij_f16, vj, dtype)
                        # Store the tile result at this head-group's rows of the scratch output.
                        pypto.assemble(q1, [cur_offset, 0], atten_out_2dim)
    # Expose the 2-D scratch result through the 4-D output tensor.
    attention_out[:] = pypto.reshape(atten_out_2dim,
                                     [attention_out.shape[0], attention_out.shape[1],
                                      attention_out.shape[2], attention_out.shape[3]], inplace=True)
def sparse_flash_attention_quant_compute_flash(query_nope, query_rope, key_nope_2d, key_rope_2d,
                                               k_nope_scales, topk_indices, block_table, kv_act_seqs,
                                               attention_out, nq, n_kv, softmax_scale, topk,
                                               block_size, max_blocknum_perbatch, tile_config):
    """Compute sparse attention with online softmax (flash attention variant).

    Processes the selected keys of each query in s_kv tiles and merges the per-tile
    results with the online-softmax recurrence, so the final output is normalized over
    all tiles. The normalized result of the last tile is assembled directly into
    ``attention_out``.

    Args:
        query_nope: Query tensor without RoPE, shape (t * n_q, kv_lora_rank), dtype BF16
        query_rope: Query tensor with RoPE, shape (t * n_q, rope_dim), dtype BF16
        key_nope_2d: Key tensor without RoPE, shape (block_num * block_size, kv_lora_rank),
            dtype BF16 or INT8
        key_rope_2d: Key tensor with RoPE, shape (block_num * block_size, rope_dim), dtype BF16
        k_nope_scales: Dequantization scales for quantized keys, shape (block_num * block_size, 4),
            dtype FP32. Only used when key_nope_2d is INT8.
        topk_indices: Top-k indices for each query token, shape (t, n_kv * topk), dtype INT32
        block_table: Block mapping table for PagedAttention, shape (b, max_blocknum_perbatch),
            dtype INT32
        kv_act_seqs: Actual sequence lengths for each batch, shape (b,), dtype INT32
        attention_out: Output attention tensor, shape (b, s, n_q, kv_lora_rank), dtype BF16.
            Written in place.
        nq: Number of query heads
        n_kv: Number of key-value heads
        softmax_scale: Scaling factor for attention scores, typically 1/sqrt(head_dim)
        topk: Number of top-k keys to attend to
        block_size: Size of each block in PagedAttention
        max_blocknum_perbatch: Maximum number of blocks per batch
        tile_config: SaTileShapeConfig object containing tiling parameters, including
            v2_tile_shape for flash attention updates

    Returns:
        None. The result is stored into ``attention_out``.

    Note:
        Flash attention algorithm maintains running statistics:
        - oi_update: Running (un-normalized) attention output, FP32
        - li_update: Running normalization factor (sum of exp values), FP32
        - mi_update: Running maximum value, FP32
        These are incrementally updated across key-value blocks using the online softmax
        formula; the division by the normalization factor happens only on the last tile.
    """
    dtype = query_nope.dtype
    kn_dtype = key_nope_2d.dtype
    dn = query_nope.shape[1]
    dr = query_rope.shape[1]
    group = nq // n_kv
    group_tile = tile_config.g_tile
    s2_tile = tile_config.s_kv_tile
    c1_tile = tile_config.c1_tile_shape
    v1_tile = tile_config.v1_tile_shape
    c2_tile = tile_config.c2_tile_shape
    v2_tile = tile_config.v2_tile_shape
    n_kv_sym = n_kv
    batch_size_sym = kv_act_seqs.shape[0]
    # Query rows per batch entry (query positions * heads), then query positions per batch.
    s1_n2_gsym = query_nope.shape[0] // batch_size_sym
    s1_sym = s1_n2_gsym // nq
    g_loop_sym = group // group_tile
    for batch_idx in pypto.loop(0, batch_size_sym, 1, name="FLASH_LOOP_L0_idx", idx_name="bIdx"):
        cur_act_seq = kv_act_seqs[batch_idx]
        for slc_idx in pypto.loop(0, s1_sym, 1, name="FLASH_LOOP_L1_s1_SA", idx_name="s1Idx"):
            # Number of selected keys that are valid for this query position:
            # (act_seq - s1 + 1 + position), clamped to the range [0, topk].
            cur_seq = (cur_act_seq - s1_sym + 1 + slc_idx).max(0).min(topk)
            cur_seq.as_variable()
            # Ceil-divide the valid key count into s_kv tiles.
            bn_per_batch = (cur_seq + s2_tile - 1) // s2_tile
            for n_kv_idx in pypto.loop(0, n_kv_sym, 1, name="FLASH_LOOP_L2_n_kv_SA", idx_name="n_kvIdx"):
                for group_idx in pypto.loop(0, g_loop_sym, 1, name="FLASH_LOOP_L3_g_SA", idx_name="gIdx"):
                    cur_group_tile = group_tile
                    # Running statistics carried across the s2 tiles of this head-group.
                    # li/mi are stored as [1, group_tile] rows and reshaped to columns
                    # whenever they are combined with the [group_tile, dn] accumulator.
                    oi_update = pypto.tensor([cur_group_tile, dn], pypto.DT_FP32, "oi_update")
                    li_update = pypto.tensor([1, cur_group_tile], pypto.DT_FP32, "li_update")
                    mi_update = pypto.tensor([1, cur_group_tile], pypto.DT_FP32, "mi_update")
                    # First query row of this head-group tile in the flattened query tensors.
                    cur_offset = batch_idx * s1_n2_gsym + slc_idx * nq + n_kv_idx * group + group_idx * cur_group_tile
                    # Destination of this head-group tile inside the 4-D output tensor.
                    oi_offset = [batch_idx, slc_idx, n_kv_idx * group + group_idx * cur_group_tile, 0]
                    for s2_idx, _ in pypto.loop_unroll(0, bn_per_batch, 1,
                                                       name="FLASH_LOOP_L4_s2_SA", idx_name="s2_idx", unroll_list={1}):
                        cur_s2_tile = s2_tile
                        pypto.set_semantic_label("Sa_V0")
                        # Slice of the index row for this query position; the valid width
                        # is the number of keys remaining in the (possibly partial) tile.
                        cur_topk_indices = pypto.view(topk_indices, [1, cur_s2_tile],
                                                      [batch_idx * s1_sym + slc_idx, s2_idx * cur_s2_tile],
                                                      valid_shape=[1, (cur_seq - s2_idx * cur_s2_tile).min(cur_s2_tile)])
                        cur_block_table = pypto.view(block_table, [1, max_blocknum_perbatch], [batch_idx, 0])
                        k_nope_2d_view = pypto.view(key_nope_2d, [key_nope_2d.shape[0], dn],
                                                    [0, 0], valid_shape=[key_nope_2d.shape[0], dn])
                        k_nope_scale_view = pypto.view(k_nope_scales, [k_nope_scales.shape[0], 4],
                                                       [0, 0], valid_shape=[k_nope_scales.shape[0], 4])
                        kn = pypto.tensor([s2_tile, dn], dtype, "kn")
                        if kn_dtype == pypto.DT_INT8:
                            # Quantized path: gather scales and INT8 rows, then dequantize.
                            pypto.set_vec_tile_shapes(32, 512)
                            kn_scale = gather_in_ub(k_nope_scale_view, cur_topk_indices,
                                                    cur_block_table, block_size, -2)
                            kn_quant = gather_in_ub(k_nope_2d_view, cur_topk_indices, cur_block_table, block_size, -2)
                            # INT8 -> FP16 -> FP32 in two casts.
                            kn_quant_fp16 = pypto.cast(kn_quant, pypto.DT_FP16)
                            kn_quant_fp32 = pypto.cast(kn_quant_fp16, pypto.DT_FP32)
                            # Split each row into 4 groups of 128 values, one scale per group.
                            # The reshapes imply dn == 4 * 128.
                            kn_quant_fp32_tmp = pypto.reshape(kn_quant_fp32, [s2_tile * 4, 128])
                            kn_scale_tmp = pypto.reshape(kn_scale, [s2_tile * 4, 1])
                            pypto.set_vec_tile_shapes(128, 128)
                            kn_fp32 = pypto.mul(kn_quant_fp32_tmp, kn_scale_tmp)
                            kn_fp32_reshape = pypto.reshape(kn_fp32, [s2_tile, dn])
                            pypto.set_vec_tile_shapes(32, 512)
                            cur_kn_fp32 = pypto.view(kn_fp32_reshape, [cur_s2_tile, dn], [0, 0],
                                                     valid_shape=[(cur_seq - s2_idx * cur_s2_tile).min(cur_s2_tile), dn])
                            kn = pypto.cast(cur_kn_fp32, dtype)
                        else:
                            # Non-quantized path: gather the key rows as the (transposed) B matrix.
                            pypto.set_cube_tile_shapes([c1_tile[0], c1_tile[1]],
                                                       [c1_tile[2], c1_tile[3]], [c1_tile[4], c1_tile[5]])
                            kn = gather_in_l1(key_nope_2d,
                                              cur_topk_indices, cur_block_table, block_size, dn, is_b_matrix=True, is_trans=True)
                        pypto.set_semantic_label("Sa_C1")
                        pypto.set_cube_tile_shapes([c1_tile[0],
                                                    c1_tile[1]], [c1_tile[2], c1_tile[3]], [c1_tile[4], c1_tile[5]])
                        kr = gather_in_l1(key_rope_2d, cur_topk_indices, cur_block_table, block_size, dr,
                                          is_b_matrix=True, is_trans=True)
                        # Key tile = [non-RoPE | RoPE] columns.
                        kj = pypto.tensor([cur_s2_tile, dn + dr], dtype, "kj")
                        pypto.assemble(kn, [0, 0], kj)
                        pypto.assemble(kr, [0, dn], kj)
                        kj_view = pypto.view(kj, [cur_s2_tile, dn + dr], [0, 0],
                                             valid_shape=[(cur_seq - s2_idx * cur_s2_tile).min(cur_s2_tile), dn + dr])
                        qn = pypto.view(query_nope, [cur_group_tile, dn], [cur_offset, 0],
                                        valid_shape=[cur_group_tile, dn])
                        qr = pypto.view(query_rope, [cur_group_tile, dr], [cur_offset, 0],
                                        valid_shape=[cur_group_tile, dr])
                        # Query tile = [non-RoPE | RoPE] columns, same layout as the key tile.
                        qi = pypto.tensor([cur_group_tile, dn + dr], dtype, "qi")
                        pypto.assemble(qn, [0, 0], qi)
                        pypto.assemble(qr, [0, dn], qi)
                        # Scores in FP32: Q x K^T.
                        sij = pypto.matmul(qi, kj_view, pypto.DT_FP32, a_trans=False, b_trans=True)
                        # Per-tile softmax statistics: row max, un-normalized exp, row sum.
                        pypto.set_semantic_label("Sa_V1")
                        pypto.set_vec_tile_shapes(v1_tile[0], v1_tile[1])
                        sij_scale = pypto.mul(sij, softmax_scale)
                        tilda_mij_reduce = pypto.amax(sij_scale, dim=-1, keepdim=True)
                        tilda_mij = pypto.reshape(tilda_mij_reduce, [1, cur_group_tile])
                        t_sub = pypto.sub(sij_scale, tilda_mij_reduce)
                        tilda_pij = pypto.exp(t_sub)
                        # The un-normalized probabilities feed the second matmul; the
                        # normalization is deferred to the last tile.
                        tilda_pij_f16 = pypto.cast(tilda_pij, dtype)
                        tilda_lij_reduce = pypto.sum(tilda_pij, dim=-1, keepdim=True)
                        tilda_lij = pypto.reshape(tilda_lij_reduce, [1, cur_group_tile])
                        pypto.set_semantic_label("Sa_C2")
                        pypto.set_cube_tile_shapes([c2_tile[0],
                                                    c2_tile[1]], [c2_tile[2], c2_tile[3]], [c2_tile[4], c2_tile[5]])
                        pypto.set_matrix_size([tilda_pij_f16.shape[0],
                                               tilda_pij_f16.shape[1], kn.shape[1]])
                        q1 = pypto.tensor([cur_group_tile, dn], dtype)
                        if kn_dtype == pypto.DT_INT8:
                            # Reuse the dequantized non-RoPE key rows as the value matrix.
                            vj = pypto.view(kn, [cur_s2_tile, dn], [0, 0],
                                            valid_shape=[(cur_seq - s2_idx * cur_s2_tile).min(cur_s2_tile), dn])
                            q1 = pypto.matmul(tilda_pij_f16, vj, pypto.DT_FP32)
                        else:
                            # Gather the same rows again, this time without the transpose flag.
                            vj = gather_in_l1(key_nope_2d, cur_topk_indices, cur_block_table, block_size,
                                              dn, is_b_matrix=True, is_trans=False)
                            q1 = pypto.matmul(tilda_pij_f16, vj, pypto.DT_FP32)
                        if pypto.cond(pypto.is_loop_begin(s2_idx)):
                            # First tile: the running statistics are just this tile's values.
                            oi_tmp = q1
                            pypto.set_vec_tile_shapes(v2_tile[0], v2_tile[1])
                            if pypto.cond(pypto.is_loop_end(s2_idx)):
                                # Single-tile case: normalize immediately and write the output.
                                pypto.set_semantic_label("Sa_V2")
                                oi_update[:] = oi_tmp / tilda_lij_reduce
                                # 4-D tile shape for the reshape/cast of the 4-D output slice.
                                pypto.set_vec_tile_shapes(1, 1, v2_tile[0], v2_tile[1])
                                oi_update_4_dim = pypto.cast(pypto.reshape(oi_update,
                                                                           [1, 1, cur_group_tile, dn]), dtype)
                                pypto.assemble(oi_update_4_dim, oi_offset, attention_out)
                            else:
                                oi_update[:] = oi_tmp
                            # Back to the 2-D tile shape before storing the row statistics.
                            pypto.set_vec_tile_shapes(v2_tile[0], v2_tile[1])
                            li_update[:] = tilda_lij
                            mi_update[:] = tilda_mij
                        else:
                            # Later tiles: merge this tile into the running statistics.
                            pypto.set_semantic_label("Sa_UpdateVec2")
                            oi = oi_update
                            li = li_update
                            mi = mi_update
                            pypto.set_vec_tile_shapes(v2_tile[0], v2_tile[1])
                            mi_new = pypto.maximum(mi, tilda_mij)
                            # t2 = exp(m_old - m_new): rescale factor for the previous accumulator.
                            t1 = pypto.sub(mi, mi_new)
                            t2 = pypto.exp(t1)
                            # t4 = exp(m_tile - m_new): rescale factor for the current tile.
                            t3 = pypto.sub(tilda_mij, mi_new)
                            t4 = pypto.exp(t3)
                            t5 = pypto.mul(t4, tilda_lij)
                            t6 = pypto.mul(t2, li)
                            # l_new = t2 * l_old + t4 * l_tile
                            li_new = pypto.add(t6, t5)
                            q3 = pypto.mul(oi, pypto.reshape(t2, [cur_group_tile, 1]))
                            pypto.set_vec_tile_shapes(v2_tile[0], v2_tile[1])
                            q2 = pypto.mul(q1, pypto.reshape(t4, [cur_group_tile, 1]))
                            # o_new = t2 * o_old + t4 * o_tile
                            oi_tmp = pypto.add(q3, q2)
                            if pypto.cond(pypto.is_loop_end(s2_idx)):
                                # Last tile: normalize by the running sum and write the output.
                                oi_update[:] = pypto.div(oi_tmp,
                                                         pypto.reshape(li_new, [cur_group_tile, 1]), pypto.PrecisionType.INTRINSIC)
                                # 4-D tile shape for the reshape/cast of the 4-D output slice.
                                pypto.set_vec_tile_shapes(1, 1, v2_tile[0], v2_tile[1])
                                oi_update_4_dim = pypto.cast(pypto.reshape(oi_update,
                                                                           [1, 1, cur_group_tile, dn]), dtype)
                                pypto.assemble(oi_update_4_dim, oi_offset, attention_out)
                            else:
                                oi_update[:] = oi_tmp
                            li_update[:] = li_new
                            mi_update[:] = mi_new
@pypto.frontend.jit(
    pass_options={
        "vec_nbuffer_setting": {-1: 4, -2: 1},
        "cube_l1_reuse_setting": {-1: 8},
    },
    runtime_options={
        "stitch_function_max_num": 128,
        "device_sched_mode": 3,
        "ready_on_host_tensors": ["block_table", "kv_act_seqs"]
    }
)
def sparse_flash_attention_quant_d_950(
    query_nope: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
    query_rope: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
    key_nope_2d: pypto.Tensor([pypto.STATIC, pypto.STATIC], ),
    key_rope_2d: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
    k_nope_scales: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
    topk_indices: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_INT32),
    block_table: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_INT32),
    kv_act_seqs: pypto.Tensor([pypto.DYNAMIC], pypto.DT_INT32),
    attention_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
    nq, n_kv, softmax_scale, topk, block_size, max_blocknum_perbatch, tile_config
):
    """JIT-compiled sparse attention for decode phase (alternate pass options).

    Enables the ``combine_axis`` operation option and forwards all arguments unchanged
    to ``sparse_flash_attention_quant_compute``, i.e. the standard per-tile softmax
    variant, not ``sparse_flash_attention_quant_compute_flash``.

    Differs from ``sparse_flash_attention_quant_d`` only in the ``pass_options`` given
    to the JIT decorator.
    NOTE(review): the ``_950`` suffix presumably names a hardware target - confirm.

    Args:
        query_nope: Query tensor without RoPE, shape (t * n_q, kv_lora_rank), dtype BF16
        query_rope: Query tensor with RoPE, shape (t * n_q, rope_dim), dtype BF16
        key_nope_2d: Key tensor without RoPE, shape (block_num * block_size, kv_lora_rank),
            dtype BF16 or INT8 (the annotation leaves the dtype unspecified)
        key_rope_2d: Key tensor with RoPE, shape (block_num * block_size, rope_dim), dtype BF16
        k_nope_scales: Dequantization scales for quantized keys, shape (block_num * block_size, 4),
            dtype FP32
        topk_indices: Top-k indices for each query token, shape (t, n_kv * topk), dtype INT32
        block_table: Block mapping table for PagedAttention, shape (b, max_blocknum_perbatch),
            dtype INT32
        kv_act_seqs: Actual sequence lengths for each batch, shape (b,), dtype INT32
        attention_out: Output attention tensor, shape (b, s, n_q, kv_lora_rank), dtype BF16.
            Written in place.
        nq: Number of query heads
        n_kv: Number of key-value heads
        softmax_scale: Scaling factor for attention scores
        topk: Number of top-k keys to attend to
        block_size: Size of each block in PagedAttention
        max_blocknum_perbatch: Maximum number of blocks per batch
        tile_config: SaTileShapeConfig object containing tiling parameters

    Note:
        The decorator's runtime options list ``block_table`` and ``kv_act_seqs`` under
        ``ready_on_host_tensors``.
    """
    pypto.experimental.set_operation_options(combine_axis=True)
    sparse_flash_attention_quant_compute(query_nope, query_rope, key_nope_2d, key_rope_2d,
                                         k_nope_scales, topk_indices, block_table, kv_act_seqs,
                                         attention_out, nq, n_kv, softmax_scale, topk,
                                         block_size, max_blocknum_perbatch, tile_config)
@pypto.frontend.jit(
    pass_options={
        "vec_nbuffer_setting": {-1: 2, 0: 8},
        "cube_l1_reuse_setting": {-1: 2},
    },
    runtime_options={
        "stitch_function_max_num": 128,
        "device_sched_mode": 3,
        "ready_on_host_tensors": ["block_table", "kv_act_seqs"]
    }
)
def sparse_flash_attention_quant_d(
    query_nope: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
    query_rope: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
    key_nope_2d: pypto.Tensor([pypto.STATIC, pypto.STATIC], ),
    key_rope_2d: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
    k_nope_scales: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
    topk_indices: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_INT32),
    block_table: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_INT32),
    kv_act_seqs: pypto.Tensor([pypto.DYNAMIC], pypto.DT_INT32),
    attention_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
    nq, n_kv, softmax_scale, topk, block_size, max_blocknum_perbatch, tile_config
):
    """JIT-compiled sparse attention for decode phase.

    Enables the ``combine_axis`` operation option and forwards all arguments unchanged
    to ``sparse_flash_attention_quant_compute``, i.e. the standard per-tile softmax
    variant, not ``sparse_flash_attention_quant_compute_flash``.

    Args:
        query_nope: Query tensor without RoPE, shape (t * n_q, kv_lora_rank), dtype BF16
        query_rope: Query tensor with RoPE, shape (t * n_q, rope_dim), dtype BF16
        key_nope_2d: Key tensor without RoPE, shape (block_num * block_size, kv_lora_rank),
            dtype BF16 or INT8 (the annotation leaves the dtype unspecified)
        key_rope_2d: Key tensor with RoPE, shape (block_num * block_size, rope_dim), dtype BF16
        k_nope_scales: Dequantization scales for quantized keys, shape (block_num * block_size, 4),
            dtype FP32
        topk_indices: Top-k indices for each query token, shape (t, n_kv * topk), dtype INT32
        block_table: Block mapping table for PagedAttention, shape (b, max_blocknum_perbatch),
            dtype INT32
        kv_act_seqs: Actual sequence lengths for each batch, shape (b,), dtype INT32
        attention_out: Output attention tensor, shape (b, s, n_q, kv_lora_rank), dtype BF16.
            Written in place.
        nq: Number of query heads
        n_kv: Number of key-value heads
        softmax_scale: Scaling factor for attention scores
        topk: Number of top-k keys to attend to
        block_size: Size of each block in PagedAttention
        max_blocknum_perbatch: Maximum number of blocks per batch
        tile_config: SaTileShapeConfig object containing tiling parameters

    Note:
        The decorator's runtime options list ``block_table`` and ``kv_act_seqs`` under
        ``ready_on_host_tensors``.
    """
    pypto.experimental.set_operation_options(combine_axis=True)
    sparse_flash_attention_quant_compute(query_nope, query_rope, key_nope_2d, key_rope_2d,
                                         k_nope_scales, topk_indices, block_table, kv_act_seqs,
                                         attention_out, nq, n_kv, softmax_scale, topk,
                                         block_size, max_blocknum_perbatch, tile_config)
@pypto.frontend.jit(
    pass_options={
        "vec_nbuffer_setting": {-1: 4, 0: 16},
        "cube_l1_reuse_setting": {-1: 4},
    },
    runtime_options={
        "stitch_function_max_num": 128,
        "ready_on_host_tensors": ["block_table", "kv_act_seqs"]
    }
)
def sparse_flash_attention_quant_p(
    query_nope: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
    query_rope: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
    key_nope_2d: pypto.Tensor([pypto.STATIC, pypto.STATIC],),
    key_rope_2d: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
    k_nope_scales: pypto.Tensor([pypto.STATIC, pypto.STATIC], pypto.DT_FP32),
    topk_indices: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_INT32),
    block_table: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_INT32),
    kv_act_seqs: pypto.Tensor([pypto.DYNAMIC], pypto.DT_INT32),
    attention_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC, pypto.STATIC, pypto.STATIC], pypto.DT_BF16),
    nq, n_kv, softmax_scale, topk, block_size, max_blocknum_perbatch, tile_config
):
    """JIT-compiled sparse attention for prefill phase.

    Enables the ``combine_axis`` operation option and forwards all arguments unchanged
    to ``sparse_flash_attention_quant_compute``, i.e. the standard per-tile softmax
    variant, not ``sparse_flash_attention_quant_compute_flash``.

    Unlike the decode wrappers in this module, the decorator's runtime options do not
    set ``device_sched_mode``.

    Args:
        query_nope: Query tensor without RoPE, shape (t * n_q, kv_lora_rank), dtype BF16
        query_rope: Query tensor with RoPE, shape (t * n_q, rope_dim), dtype BF16
        key_nope_2d: Key tensor without RoPE, shape (block_num * block_size, kv_lora_rank),
            dtype BF16 or INT8 (the annotation leaves the dtype unspecified)
        key_rope_2d: Key tensor with RoPE, shape (block_num * block_size, rope_dim), dtype BF16
        k_nope_scales: Dequantization scales for quantized keys, shape (block_num * block_size, 4),
            dtype FP32
        topk_indices: Top-k indices for each query token, shape (t, n_kv * topk), dtype INT32
        block_table: Block mapping table for PagedAttention, shape (b, max_blocknum_perbatch),
            dtype INT32
        kv_act_seqs: Actual sequence lengths for each batch, shape (b,), dtype INT32
        attention_out: Output attention tensor, shape (b, s, n_q, kv_lora_rank), dtype BF16.
            Written in place.
        nq: Number of query heads
        n_kv: Number of key-value heads
        softmax_scale: Scaling factor for attention scores
        topk: Number of top-k keys to attend to
        block_size: Size of each block in PagedAttention
        max_blocknum_perbatch: Maximum number of blocks per batch
        tile_config: SaTileShapeConfig object containing tiling parameters

    Note:
        The decorator's runtime options list ``block_table`` and ``kv_act_seqs`` under
        ``ready_on_host_tensors``.
    """
    pypto.experimental.set_operation_options(combine_axis=True)
    sparse_flash_attention_quant_compute(query_nope, query_rope, key_nope_2d, key_rope_2d,
                                         k_nope_scales, topk_indices, block_table, kv_act_seqs,
                                         attention_out, nq, n_kv, softmax_scale, topk,
                                         block_size, max_blocknum_perbatch, tile_config)