"""
Sparse compress Flash Attention Module
This module implements sparse flash attention with quantization support for DeepSeek V4.
It performs attention computation on top-k selected key-value pairs from cache,
supporting both standard and flash attention algorithms.
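Main Functions:
- sparse_compress_flash_attention_compute: Sparse attention over the recent-window and top-k compressed KV rows
- sparse_compress_flash_attention_kernel: JIT-compiled kernel entry point wrapping the compute routine
- npu_sparse_compress_flash_attention: Eager entry point that validates inputs and launches the kernel
- sparse_compress_flash_attention_graph: Graph-mode entry point dispatching through torch.ops.pypto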
Example:
See test_sparse_compress_flash_attention.py for usage examples.
"""
from dataclasses import dataclass
import pypto
import torch
from pypto.experimental import gather_in_ub
from torch._dynamo import allow_in_graph
from torch._subclasses.fake_tensor import FakeTensor
# 2**17 = 131072. The compute routine views MAX_S2 // block_size columns of the
# compressed block table per batch entry; presumably this is the largest
# supported KV sequence length -- TODO confirm.
MAX_S2 = 1 << 17
@dataclass
class SCFATileShapeConfig:
    """Tile-shape settings read by sparse_compress_flash_attention_compute."""
    # Number of query rows viewed per token for the score matmul.
    g_tile: int
    # Six values, passed pairwise to pypto.set_cube_tile_shapes before the
    # query x key matmul.
    c1_tile_shape: list
    # Two values, passed to pypto.set_vec_tile_shapes for the softmax stage.
    v1_tile_shape: list
    # Six values, passed pairwise to pypto.set_cube_tile_shapes before the
    # probability x value matmul.
    c2_tile_shape: list
def sparse_compress_flash_attention_compute(query, actual_seq_q, ori_kv, cmp_kv, ori_block_table,
                                            cmp_block_table, atten_sink,
                                            seqused_kv, cmp_sparse_indices,
                                            attention_out, nq, n_kv, softmax_scale, topk,
                                            block_size, win_size, cmp_ratio, tile_config):
    """Compute sparse compress attention for every query token of every batch entry.

    For each batch entry and each of its query tokens the routine:
      1. copies the most recent window of rows (at most ``win_size``) from
         ``ori_kv`` into a scratch tensor, resolving the one or two paged
         blocks that hold them through ``ori_block_table``;
      2. gathers up to ``topk`` rows of ``cmp_kv`` with ``gather_in_ub`` using
         ``cmp_sparse_indices`` and ``cmp_block_table`` and appends them after
         the window rows;
      3. computes ``query @ kv^T`` in FP32, scales it by ``softmax_scale`` and
         normalises it with a softmax whose denominator additionally contains
         ``exp(atten_sink - row_max)``;
      4. multiplies the probabilities with the same gathered rows and writes
         the result into ``attention_out``.

    Args:
        query: 2-D tensor; ``query.shape[1]`` is used as the head dimension and
            rows are addressed as ``token_index * nq``.
        actual_seq_q: Offsets into the token axis; the difference of two
            consecutive entries is the number of query tokens of a batch entry.
        ori_kv: 2-D paged KV rows addressed as ``physical_block * block_size + offset``.
        cmp_kv: 2-D compressed KV rows, read through ``gather_in_ub``.
        ori_block_table: ``[batch, logical_block]`` -> physical block id for ``ori_kv``.
        cmp_block_table: Block table handed to ``gather_in_ub`` for ``cmp_kv``.
        atten_sink: 1-D tensor reshaped to a column and added to the softmax denominator.
        seqused_kv: Per-batch KV length; its first dimension defines the batch loop bound.
        cmp_sparse_indices: Per-token indices of the selected compressed rows.
        attention_out: Output tensor, written in place.
        nq: Row stride per token in ``query`` / ``attention_out``.
        n_kv: Not used by this routine.
        softmax_scale: Multiplier applied to the raw scores.
        topk: Maximum number of compressed rows per token.
        block_size: Number of rows per paged block.
        win_size: Maximum number of recent ``ori_kv`` rows per token.
        cmp_ratio: Divisor applied to the KV length when bounding the
            number of compressed rows.
        tile_config: Object providing ``g_tile``, ``c1_tile_shape``,
            ``v1_tile_shape`` and ``c2_tile_shape``.

    Returns:
        None. The result is assembled into ``attention_out``.
    """
    dtype = query.dtype
    d = query.shape[1]
    group_tile = tile_config.g_tile
    c1_tile = tile_config.c1_tile_shape
    v1_tile = tile_config.v1_tile_shape
    c2_tile = tile_config.c2_tile_shape
    batch_size_sym = seqused_kv.shape[0]
    topk_tile = topk
    # Scratch rows: two window-sized views may be assembled back to back
    # (offsets 0 and cur_kv_block_0_size), followed by the top-k rows.
    sel_tile = win_size * 2 + topk_tile
    # Rows actually fed to the matmuls: one window plus the top-k rows.
    kv_tile = win_size + topk_tile
    for batch_idx in pypto.loop(0, batch_size_sym, 1, name="LOOP_L0_idx", idx_name="bIdx"):
        # Number of query tokens belonging to this batch entry.
        cur_s1 = actual_seq_q[batch_idx + 1] - actual_seq_q[batch_idx]
        ori_act_seq = seqused_kv[batch_idx]
        for slc_idx in pypto.loop(0, cur_s1, 1, name="LOOP_L1_s1_SA", idx_name="s1Idx"):
            # KV length seen by this query token: later tokens of the same
            # batch entry see more rows; clamped so it never goes negative.
            cur_len = pypto.max(ori_act_seq - cur_s1 + 1 + slc_idx, 0)
            cur_win_size = pypto.min(cur_len, win_size)
            # First and last KV position of the window, then the logical
            # blocks (and in-block offset) that contain them.
            cur_valid_start_pos = cur_len - cur_win_size
            cur_valid_end_pos = cur_len - 1
            start_block = cur_valid_start_pos // block_size
            start_offset = cur_valid_start_pos % block_size
            end_block = cur_valid_end_pos // block_size
            physical_block_id_0 = ori_block_table[batch_idx, start_block]
            physical_block_id_1 = ori_block_table[batch_idx, end_block]
            # Number of compressed rows used by this token, clamped to [0, topk].
            cur_topk_size = (ori_act_seq // cmp_ratio - cur_s1 + 1 + slc_idx).max(0).min(topk)
            cur_s2_tile = cur_win_size + cur_topk_size
            cur_group_tile = group_tile
            # Global token index and the first query/output row of that token.
            t_idx = actual_seq_q[batch_idx] + slc_idx
            cur_offset = t_idx * nq
            pypto.set_semantic_label("Sa_V0")
            pypto.set_vec_tile_shapes(128, 512)
            kj = pypto.tensor([sel_tile, d], dtype, "kj")
            # NOTE(review): this equals the number of window rows located in the
            # first block only when win_size == block_size -- confirm that
            # callers always satisfy this.
            cur_kv_block_0_size = cur_win_size - start_offset
            kv_block_0 = pypto.view(ori_kv, [win_size, d], [physical_block_id_0 * block_size + start_offset, 0], \
                                    valid_shape=[cur_kv_block_0_size, d])
            pypto.assemble(pypto.clone(kv_block_0), [0, 0], kj)
            # The window crosses a block boundary: append the rows that live at
            # the start of the last block right after the first part.
            if pypto.cond(start_block < end_block):
                pypto.set_vec_tile_shapes(128, 512)
                cur_kv_block_1_size = cur_win_size - cur_kv_block_0_size
                kv_block_1 = pypto.view(ori_kv, [win_size, d], [physical_block_id_1 * block_size, 0], \
                                        valid_shape=[cur_kv_block_1_size, d])
                pypto.assemble(pypto.clone(kv_block_1), [cur_kv_block_0_size, 0], kj)
            pypto.set_vec_tile_shapes(128, 512)
            # Gather the selected compressed rows for this token and place them
            # directly behind the window rows.
            cur_cmp_sparse_indices = pypto.view(cmp_sparse_indices, [1, topk_tile], [t_idx, 0], \
                                                valid_shape=[1, cur_topk_size])
            cur_block_table = pypto.view(cmp_block_table, [1, MAX_S2 // block_size], [batch_idx, 0])
            cmp_kv_view = pypto.view(cmp_kv, [topk_tile, d], [0, 0], valid_shape=[cur_topk_size, d])
            compress_kv = gather_in_ub(cmp_kv_view, cur_cmp_sparse_indices, cur_block_table, block_size, -2)
            pypto.assemble(compress_kv, [cur_win_size, 0], kj)
            pypto.set_semantic_label("Sa_C1")
            pypto.set_cube_tile_shapes([c1_tile[0], c1_tile[1]], [c1_tile[2], c1_tile[3]],
                                       [c1_tile[4], c1_tile[5]])
            # Scores: query rows of this token times the gathered rows
            # (transposed), accumulated in FP32.
            qv = pypto.view(query, [cur_group_tile, d], [cur_offset, 0], valid_shape=[cur_group_tile, d])
            kv_after_gather = pypto.view(kj, [kv_tile, d], [0, 0], valid_shape=[cur_s2_tile, d])
            sij = pypto.matmul(qv, kv_after_gather, pypto.DT_FP32, a_trans=False, b_trans=True)
            pypto.set_semantic_label("Sa_V1")
            pypto.set_vec_tile_shapes(v1_tile[0], v1_tile[1])
            # Softmax with the row maximum subtracted before exponentiation.
            sij_scale = pypto.mul(sij, softmax_scale)
            tilda_mij_reduce = pypto.amax(sij_scale, dim=-1, keepdim=True)
            t_sub = pypto.sub(sij_scale, tilda_mij_reduce)
            tilda_pij = pypto.exp(t_sub)
            tilda_lij_reduce = pypto.sum(tilda_pij, dim=-1, keepdim=True)
            # The sink contributes exp(sink - row_max) to the denominator only;
            # it has no matching value row.
            atten_sink_2d = pypto.reshape(atten_sink, [atten_sink.shape[0], 1], inplace=True)
            sink_sub_res = pypto.sub(atten_sink_2d, tilda_mij_reduce)
            sink_exp_res = pypto.exp(sink_sub_res)
            tilda_lij_reduce = pypto.add(tilda_lij_reduce, sink_exp_res)
            t_softmax = pypto.div(tilda_pij, tilda_lij_reduce)
            # Back to the input dtype for the second matmul.
            tilda_pij_f16 = pypto.cast(t_softmax, dtype)
            pypto.set_semantic_label("Sa_C2")
            pypto.set_cube_tile_shapes([c2_tile[0], c2_tile[1]], [c2_tile[2], c2_tile[3]],
                                       [c2_tile[4], c2_tile[5]])
            pypto.set_matrix_size([tilda_pij_f16.shape[0], tilda_pij_f16.shape[1], kj.shape[1]])
            # The gathered rows are reused as values.
            vj = pypto.view(kj, [kv_tile, d], [0, 0], valid_shape=[cur_s2_tile, d])
            q1 = pypto.matmul(tilda_pij_f16, vj, dtype)
            # Write the rows of this token into the output tensor.
            pypto.assemble(q1, [cur_offset, 0], attention_out)
@pypto.frontend.jit(
    pass_options={
        "cube_l1_reuse_setting": {-1: 2, 0: 8},
    },
    runtime_options={
        "stitch_function_max_num": 128,
        "device_sched_mode": 1
    },
)
def sparse_compress_flash_attention_kernel(
        query: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
        actual_seq_q: pypto.Tensor([pypto.DYNAMIC], pypto.DT_INT32),
        ori_kv: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
        cmp_kv: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
        ori_block_table: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_INT32),
        cmp_block_table: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_INT32),
        atten_sink: pypto.Tensor([pypto.STATIC], pypto.DT_FP32),
        seqused_kv: pypto.Tensor([pypto.DYNAMIC], pypto.DT_INT32),
        cmp_sparse_indices: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_INT32),
        attention_out: pypto.Tensor([pypto.DYNAMIC, pypto.STATIC], pypto.DT_BF16),
        nq, n_kv, softmax_scale, topk, block_size, win_size, cmp_ratio, tile_config):
    """JIT-compiled entry point of the sparse compress flash attention kernel.

    The tensor annotations declare, per argument, which dimensions are
    DYNAMIC or STATIC and the element type (BF16 data, INT32 index tensors,
    FP32 attention sink). The remaining arguments are forwarded unchanged.

    The body enables the ``combine_axis`` operation option and delegates to
    ``sparse_compress_flash_attention_compute``, which writes its result into
    ``attention_out``; nothing is returned.

    NOTE(review): originally described as the decode-phase variant; the
    delegated routine loops over all query tokens of each batch entry --
    TODO confirm the intended phase.
    """
    pypto.experimental.set_operation_options(combine_axis=True)
    sparse_compress_flash_attention_compute(query, actual_seq_q, ori_kv, cmp_kv, ori_block_table,
                                            cmp_block_table, atten_sink,
                                            seqused_kv, cmp_sparse_indices,
                                            attention_out, nq, n_kv, softmax_scale, topk,
                                            block_size, win_size, cmp_ratio, tile_config)
def check_input_output_shape_dtype(query_npu, q_act_seqs_npu, ori_kv_npu, cmp_kv_npu, atten_sink_npu,
                                   cmp_sparse_indices_npu):
    """Validate rank, hidden size and dtype of the attention inputs.

    Checks performed:
      * ``q_act_seqs_npu`` is not None and is 1-D;
      * ``query_npu``, ``ori_kv_npu`` and ``cmp_kv_npu`` are 2-D, have 512
        columns and are ``torch.bfloat16``;
      * ``atten_sink_npu`` is 1-D ``torch.float32``;
      * ``cmp_sparse_indices_npu`` is 2-D.

    Raises:
        AssertionError: on the first violated check. The exception is raised
            explicitly (not via ``assert``) so the validation also runs under
            ``python -O``, and the message is built without touching
            attributes that may not exist (a ``None`` input, or axis 1 of a
            tensor with fewer than two dimensions).
    """
    hidden_size = 512

    if q_act_seqs_npu is None:
        raise AssertionError("q_act_seqs_npu is None, expected a tensor with dim num 1")
    if q_act_seqs_npu.dim() != 1:
        raise AssertionError(f"q_act_seqs_npu dim num is {q_act_seqs_npu.dim()}, expected 1")

    # The three bf16 matrices share one rule; the label is the name used in
    # the error message.
    for label, tensor in (("query", query_npu), ("ori_kv_npu", ori_kv_npu), ("cmp_kv_npu", cmp_kv_npu)):
        rank = tensor.dim()
        # Axis 1 only exists for rank >= 2; report None otherwise instead of
        # raising IndexError while formatting the message.
        axis1 = tensor.size(1) if rank >= 2 else None
        if rank != 2 or axis1 != hidden_size or tensor.dtype != torch.bfloat16:
            raise AssertionError(
                f"{label} dim num is {rank}, {label} axis 1 is {axis1}, "
                f"{label} dtype is {tensor.dtype}, expected 2, 512, torch.bfloat16"
            )

    if atten_sink_npu.dim() != 1 or atten_sink_npu.dtype != torch.float32:
        raise AssertionError(
            f"atten_sink_npu dim num is {atten_sink_npu.dim()}, "
            f"atten_sink_npu dtype is {atten_sink_npu.dtype}, expected 1, torch.float32"
        )

    if cmp_sparse_indices_npu.dim() != 2:
        raise AssertionError(
            f"cmp_sparse_indices_npu dim num is {cmp_sparse_indices_npu.dim()}, expected 2"
        )
@allow_in_graph
def npu_sparse_compress_flash_attention(query_npu, q_act_seqs_npu, ori_kv_npu, cmp_kv_npu, ori_block_table_npu,
                                        cmp_block_table_npu, atten_sink_npu,
                                        seqused_kv_npu, cmp_sparse_indices_npu, softmax_scale, win_size, cmp_ratio):
    """Validate the inputs, allocate the output and launch the JIT kernel.

    The output has the same 2-D shape, dtype and device as ``query_npu`` and
    is zero-initialised before the kernel writes into it. ``nq`` is derived as
    ``query rows // rows of cmp_sparse_indices_npu``, ``topk`` as the number
    of columns of ``cmp_sparse_indices_npu``; ``n_kv`` (1), ``block_size``
    (128) and the tile configuration are fixed here.

    Returns:
        The output tensor. When there are no tokens (``query_npu`` or
        ``cmp_sparse_indices_npu`` has zero rows) the kernel is not launched
        and the empty/zero output is returned directly, instead of failing
        with ZeroDivisionError while deriving ``nq``.

    Raises:
        AssertionError: if ``query_npu`` is a FakeTensor or an input fails
            ``check_input_output_shape_dtype``.
    """
    # Raised explicitly so the guard is not stripped under ``python -O``.
    if isinstance(query_npu, FakeTensor):
        raise AssertionError("query_npu is FakeTensor")
    check_input_output_shape_dtype(query_npu, q_act_seqs_npu, ori_kv_npu, cmp_kv_npu,
                                   atten_sink_npu, cmp_sparse_indices_npu)

    attention_out_npu = torch.zeros([query_npu.size(0), query_npu.size(1)],
                                    dtype=query_npu.dtype, device=query_npu.device)

    num_tokens = cmp_sparse_indices_npu.size(0)
    if num_tokens == 0 or query_npu.size(0) == 0:
        # Nothing to attend over; avoids dividing by zero below.
        return attention_out_npu

    tile_config = SCFATileShapeConfig(
        g_tile=64,
        c1_tile_shape=[64, 64, 128, 512, 128, 128],
        v1_tile_shape=[32, 640],
        c2_tile_shape=[64, 64, 128, 640, 256, 256],
    )
    nq = query_npu.size(0) // num_tokens
    n_kv = 1
    topk = cmp_sparse_indices_npu.size(1)
    block_size = 128
    tensors = [query_npu, q_act_seqs_npu, ori_kv_npu, cmp_kv_npu, ori_block_table_npu, cmp_block_table_npu,
               atten_sink_npu, seqused_kv_npu, cmp_sparse_indices_npu, attention_out_npu]
    sparse_compress_flash_attention_kernel(*tensors, nq, n_kv, softmax_scale, topk, block_size,
                                           win_size, cmp_ratio, tile_config)
    return attention_out_npu
# Open the "pypto" operator namespace as a FRAGMENT library and declare the
# schema of sparse_compress_flash_attention: nine tensors followed by
# softmax_scale (float), win_size (int) and cmp_ratio (int), returning one
# tensor. The Meta and NPU implementations below are attached to this schema.
pyptolib = torch.library.Library("pypto", "FRAGMENT")
pyptolib.define("sparse_compress_flash_attention(Tensor query_npu, Tensor q_act_seqs_npu, Tensor ori_kv_npu, \
Tensor cmp_kv_npu, Tensor ori_block_table_npu,\
Tensor cmp_block_table_npu, Tensor atten_sink_npu, Tensor seqused_kv_npu, Tensor cmp_sparse_indices_npu,\
float softmax_scale, int win_size, int cmp_ratio) -> (Tensor)")
@torch.library.impl(pyptolib, "sparse_compress_flash_attention", "Meta")
def sparse_compress_flash_attention(query_npu, q_act_seqs_npu, ori_kv_npu, cmp_kv_npu, ori_block_table_npu,
                                    cmp_block_table_npu, atten_sink_npu,
                                    seqused_kv_npu, cmp_sparse_indices_npu, softmax_scale, win_size, cmp_ratio):
    """Meta (shape-inference) implementation of the operator.

    Returns an uninitialised tensor with the same metadata the NPU
    implementation produces: a 2-D tensor of shape
    ``[query_npu.size(0), query_npu.size(1)]`` with the dtype and device of
    ``query_npu``.

    The previous 4-D shape ``[batch, tokens // batch, nq, d]`` disagreed with
    the real output, assumed every batch entry has the same number of query
    tokens, and divided by sizes that can be zero. Only ``query_npu`` is
    needed to describe the output; the other arguments are accepted to match
    the operator schema.
    """
    return torch.empty([query_npu.size(0), query_npu.size(1)],
                       dtype=query_npu.dtype, device=query_npu.device)
# Attach the NPU implementation. Registration fails when the "NPU" dispatch
# key is unknown to this torch build; in that case (and for any other error)
# the module still imports and only a message is printed.
try:
    @torch.library.impl(pyptolib, "sparse_compress_flash_attention", "NPU")
    def sparse_compress_flash_attention(query_npu, q_act_seqs_npu, ori_kv_npu, cmp_kv_npu, ori_block_table_npu,
                                        cmp_block_table_npu, atten_sink_npu,
                                        seqused_kv_npu, cmp_sparse_indices_npu, softmax_scale, win_size, cmp_ratio):
        """NPU dispatch entry: forward all arguments to the eager launcher."""
        forwarded = (
            query_npu, q_act_seqs_npu, ori_kv_npu, cmp_kv_npu,
            ori_block_table_npu, cmp_block_table_npu, atten_sink_npu,
            seqused_kv_npu, cmp_sparse_indices_npu, softmax_scale, win_size, cmp_ratio,
        )
        return npu_sparse_compress_flash_attention(*forwarded)
except Exception as exc:
    # Best effort by design: never let a failed registration break the import.
    if "could not parse dispatch key: NPU" not in str(exc):
        print(f"Skip: Unexpected error : {exc}")
    else:
        print("Skip: torchair not installed, skip NPU registration for operator 'sparse_compress_flash_attention'")
def sparse_compress_flash_attention_graph(query_npu, q_act_seqs_npu, ori_kv_npu, cmp_kv_npu,
                                          ori_block_table_npu, cmp_block_table_npu, atten_sink_npu,
                                          seqused_kv_npu, cmp_sparse_indices_npu, softmax_scale, win_size, cmp_ratio):
    """Call the operator through the torch dispatcher.

    Looks up ``torch.ops.pypto.sparse_compress_flash_attention`` and invokes
    it with all arguments in their original order, returning its result.
    """
    registered_op = torch.ops.pypto.sparse_compress_flash_attention
    op_args = (
        query_npu, q_act_seqs_npu, ori_kv_npu, cmp_kv_npu,
        ori_block_table_npu, cmp_block_table_npu, atten_sink_npu,
        seqused_kv_npu, cmp_sparse_indices_npu,
        softmax_scale, win_size, cmp_ratio,
    )
    return registered_op(*op_args)