KvQuantSparseAttnSharedkv

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品 ×
Atlas A2 训练系列产品/Atlas A2 推理系列产品 ×
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列加速卡产品 ×
Atlas 训练系列产品 ×

功能说明

  • API功能:KvQuantSparseAttnSharedKv 算子旨在完成以下公式描述的Attention计算,支持Sliding Window Attention、Compressed Attention以及Sparse Compressed Attention:

  • 计算公式:

    O=softmax(Q@K~T⋅softmax_scale)@V~O = \text{softmax}(Q@\tilde{K}^T \cdot \text{softmax\_scale})@\tilde{V}

    其中K~=V~\tilde{K}=\tilde{V}为基于入参控制的实际参与计算的KVKV

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
q 输入 公式中的QQ,不支持非连续,layout_q为BSND时shape为[B, S1, N1, D],当layout_q为TND时shape为[T1, N1, D]。 BFLOAT16 ND
kv_quant_mode 属性 K、V nope的量化模式,仅支持1,表示K、V nope为per_tile量化,量化后的KV数据类型为FLOAT8_E4M3FN或HIFLOAT8 INT32 -
ori_kv 可选输入 公式中的K~\tilde{K}V~\tilde{V}的一部分,为原始不经压缩的KV,支持非连续,layout_kv为PA_ND时shape为[block_num1, block_size1, KV_N, D] FLOAT8_E4M3FN或HIFLOAT8 ND
cmp_kv 可选输入 公式中的K~\tilde{K}V~\tilde{V}的一部分,为经过压缩的KV,支持非连续,layout_kv为PA_ND时shape为[block_num2, block_size2, KV_N, D] FLOAT8_E4M3FN或HIFLOAT8 ND
ori_sparse_indices 可选输入 预留参数,当前不生效,代表离散取oriKvCache的索引,不支持非连续,layout_q为BSND时shape为[B, Q_S, KV_N, K1],layout_q为TND时shape为[T1, KV_N, K1] INT32 ND
cmp_sparse_indices 可选输入 代表离散取cmpKvCache的索引,不支持非连续,layout_q为BSND时shape为[B, Q_S, KV_N, K2],layout_q为TND时shape为[T1, KV_N, K2] INT32 ND
ori_block_table 可选输入 PageAttention中oriKvCache存储使用的block映射表,shape约束见下方约束说明 INT32 ND
cmp_block_table 可选输入 PageAttention中cmpKvCache存储使用的block映射表,shape约束见下方约束说明 INT32 ND
cu_seqlens_q 可选输入 表示当前Batch及前序Batch中q的有效token数的累加和,维度为B+1,仅layout_q为TND场景需传入 INT32 ND
cu_seqlens_ori_kv 可选输入 预留参数,当前不生效,表示当前Batch及前序Batch中ori_kv的有效token数的累加和,维度为B+1,仅layout_kv为TND场景需传入 INT32 ND
cu_seqlens_cmp_kv 可选输入 预留参数,当前不生效,表示当前Batch及前序Batch中cmp_kv的有效token数的累加和,维度为B+1,仅layout_kv为TND场景需传入 INT32 ND
seqused_q 可选输入 预留参数,当前不生效,表示不同Batch中q的有效token数,维度为B INT32 ND
seqused_kv 可选输入 表示不同Batch中ori_kv的有效token数,维度为B INT32 ND
sinks 可选输入 注意力下沉tensor,当前必须传入 FLOAT32 ND
metadata 可选输入 aicpu算子(npu_kv_quant_sparse_attn_sharedkv_metadata)的分核结果,shape固定为[1024] INT32 ND
tile_size 可选属性 表示量化粒度,必须满足nope_head_dim能被tile_size整除,默认值为None,当前仅支持64 INT32 -
rope_head_dim 可选属性 默认值为None,当前仅支持64 INT32 -
softmax_scale 可选属性 默认值为None,当前为必传,代表缩放系数,作为q与ori_kv和cmp_kv矩阵乘后Muls的scalar值 FLOAT32 -
cmp_ratio 可选属性 表示对ori_kv的压缩率,默认值为None,kv压缩场景支持4/128,非kv压缩场景仅支持传1 INT32 -
ori_mask_mode 可选属性 表示q和ori_kv计算的mask模式,仅支持输入默认值4,代表band模式的mask INT32 -
cmp_mask_mode 可选属性 表示q和cmp_kv计算的mask模式,仅支持输入默认值3,代表rightDownCausal模式的mask,对应以右顶点为划分的下三角场景 INT32 -
ori_win_left 可选属性 表示q和ori_kv计算中q对过去token计算的数量,仅支持默认值127 INT32 -
ori_win_right 可选属性 表示q和ori_kv计算中q对未来token计算的数量,仅支持默认值0 INT32 -
layout_q 可选属性 用于标识输入q的数据排布格式,支持BSND和TND,默认值为BSND STRING -
layout_kv 可选属性 用于标识输入ori_kv和cmp_kv的数据排布格式,仅支持传入默认值PA_ND(PageAttention) STRING -
return_softmax_lse 可选属性 预留参数,当前暂不支持,表示是否返回softmax_lse。True表示返回,False表示不返回,默认值为False BOOL -
attention_out 输出 当layout_q为BSND时shape为[B, S1, N1, D],当layout_q为TND时shape为[T1, N1, D] BFLOAT16 ND
softmax_lse 输出 输出q乘k的结果先取max得到softmax_max,q乘k的结果减去softmax_max,再取exp,最后取sum,得到softmax_sum,最后对softmax_sum取log,再加上softmax_max得到的结果。当layout_q为BSND时shape为[B, N2, S1, N1/N2],当layout_q为TND时shape为[N2, T1, N1/N2]。目前softmax_lse输出为无效值 FLOAT32 ND

约束说明

  • 该接口支持推理场景下使用。
  • 该接口支持aclgraph模式。
  • 该接口当前支持三种计算场景:场景一,仅传入ori_kv时为Sliding Window Attention计算;场景二,传入ori_kv及cmp_kv时为Sliding Window Attention + Compressed Attention计算;场景三,传入ori_kv、cmp_kv及cmp_sparse_indices时为Sliding Window Attention + Sparse Compressed Attention计算。
  • 参数q中的D仅支持512。ori_kv、cmp_kv的D值仅支持640,按kv_rope、kv_nope及nope_quant_scale顺序拼接,尾部pad 128B对齐至640。其中kv_rope数据类型为bfloat16,rope_head_dim为64;kv_nope数据类型为float8_e4m3fnhifloat8,nope_head_dim为448;nope_quant_scale数据类型为float8_e8m0fnu,nope_quant_scale_dim = nope_head_dim / tile_size = 7,整体封装为float8_e4m3fnhifloat8
  • 参数ori_kv、cmp_kv的数据类型必须保持一致。
  • 参数q中的N1当前支持64/128,ori_kv、cmp_kv中的KV_N仅支持1。
  • 参数ori_kv和cmp_kv中的block_size1和block_size2需为16的倍数,最大支持1024;block_num1及block_num2为PageAttention时block总数。
  • 参数ori_sparse_indices与cmp_sparse_indices中的K1与K2为一次离散选取的block数,需要保证每行有效值均在前半部分,无效值均在后半部分,当前不支持传入ori_sparse_indices,cmp_sparse_indices中K2仅支持512/1024。
  • 参数cu_seqlens_q、cu_seqlens_ori_kv及cu_seqlens_cmp_kv维度为B + 1,要求其值为当前Batch与前序Batch有效token数的累加值,后一个元素的值必须大于等于前一个元素的值。
  • 参数seqused_q及seqused_kv维度为B,要求其值表示每个Batch中的有效token数。
  • 参数ori_block_table的shape为2维,其中第一维长度为B,第二维长度不小于所有Batch中最大的S2对应的block数量,即S2_max / block_size1向上取整。
  • 参数cmp_block_table的shape为2维,其中第一维长度为B,第二维长度不小于floor(S2_max / cmp_ratio)对应的block数量,即floor(S2_max / cmp_ratio) / block_size2向上取整。
  • ori_mask_mode及cmp_mask_mode所表示的mask模式的详细介绍见sparse_mode参数说明
  • q、ori_kv、cmp_kv参数维度含义:B(Batch Size)表示输入样本批量大小、S(Sequence Length)表示输入样本序列长度、H(Hidden Size)表示hidden层的大小、N(Head Num)表示多头数、D(Head Dim)表示hidden层最小的单元尺寸,且满足D=H/N、T表示所有Batch输入样本序列长度的累加和。
  • Q_S和S1表示q shape中的S,S2表示ori_kv shape中的S,Q_N和N1表示num_q_heads,KV_N和N2表示num_ori_kv_heads和num_cmp_kv_heads;T1表示q shape中的T。