KvQuantSparseAttnSharedkv
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | × |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | × |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列加速卡产品 | × |
| Atlas 训练系列产品 | × |
功能说明
-
API功能:KvQuantSparseAttnSharedKv 算子旨在完成以下公式描述的Attention计算,支持Sliding Window Attention、Compressed Attention以及Sparse Compressed Attention:
-
计算公式:
O=softmax(Q@K~T⋅softmax_scale)@V~O = \text{softmax}(Q@\tilde{K}^T \cdot \text{softmax\_scale})@\tilde{V}
其中K~=V~\tilde{K}=\tilde{V}为基于入参控制的实际参与计算的KVKV。
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| q | 输入 | 公式中的QQ,不支持非连续,layout_q为BSND时shape为[B, S1, N1, D],当layout_q为TND时shape为[T1, N1, D]。 | BFLOAT16 | ND |
| kv_quant_mode | 属性 | K、V nope的量化模式,仅支持1,表示K、V nope为per_tile量化,量化后的KV数据类型为FLOAT8_E4M3FN或HIFLOAT8 | INT32 | - |
| ori_kv | 可选输入 | 公式中的K~\tilde{K}和V~\tilde{V}的一部分,为原始不经压缩的KV,支持非连续,layout_kv为PA_ND时shape为[block_num1, block_size1, KV_N, D] | FLOAT8_E4M3FN或HIFLOAT8 | ND |
| cmp_kv | 可选输入 | 公式中的K~\tilde{K}和V~\tilde{V}的一部分,为经过压缩的KV,支持非连续,layout_kv为PA_ND时shape为[block_num2, block_size2, KV_N, D] | FLOAT8_E4M3FN或HIFLOAT8 | ND |
| ori_sparse_indices | 可选输入 | 预留参数,当前不生效,代表离散取oriKvCache的索引,不支持非连续,layout_q为BSND时shape为[B, Q_S, KV_N, K1],layout_q为TND时shape为[T1, KV_N, K1] | INT32 | ND |
| cmp_sparse_indices | 可选输入 | 代表离散取cmpKvCache的索引,不支持非连续,layout_q为BSND时shape为[B, Q_S, KV_N, K2],layout_q为TND时shape为[T1, KV_N, K2] | INT32 | ND |
| ori_block_table | 可选输入 | PageAttention中oriKvCache存储使用的block映射表,shape约束见下方约束说明 | INT32 | ND |
| cmp_block_table | 可选输入 | PageAttention中cmpKvCache存储使用的block映射表,shape约束见下方约束说明 | INT32 | ND |
| cu_seqlens_q | 可选输入 | 表示当前Batch及前序Batch中q的有效token数的累加和,维度为B+1,仅layout_q为TND场景需传入 | INT32 | ND |
| cu_seqlens_ori_kv | 可选输入 | 预留参数,当前不生效,表示当前Batch及前序Batch中ori_kv的有效token数的累加和,维度为B+1,仅layout_kv为TND场景需传入 | INT32 | ND |
| cu_seqlens_cmp_kv | 可选输入 | 预留参数,当前不生效,表示当前Batch及前序Batch中cmp_kv的有效token数的累加和,维度为B+1,仅layout_kv为TND场景需传入 | INT32 | ND |
| seqused_q | 可选输入 | 预留参数,当前不生效,表示不同Batch中q的有效token数,维度为B | INT32 | ND |
| seqused_kv | 可选输入 | 表示不同Batch中ori_kv的有效token数,维度为B | INT32 | ND |
| sinks | 可选输入 | 注意力下沉tensor,当前必须传入 | FLOAT32 | ND |
| metadata | 可选输入 | aicpu算子(npu_kv_quant_sparse_attn_sharedkv_metadata)的分核结果,shape固定为[1024] | INT32 | ND |
| tile_size | 可选属性 | 表示量化粒度,必须满足nope_head_dim能被tile_size整除,默认值为None,当前仅支持64 | INT32 | - |
| rope_head_dim | 可选属性 | 默认值为None,当前仅支持64 | INT32 | - |
| softmax_scale | 可选属性 | 默认值为None,当前为必传,代表缩放系数,作为q与ori_kv和cmp_kv矩阵乘后Muls的scalar值 | FLOAT32 | - |
| cmp_ratio | 可选属性 | 表示对ori_kv的压缩率,默认值为None,kv压缩场景支持4/128,非kv压缩场景仅支持传1 | INT32 | - |
| ori_mask_mode | 可选属性 | 表示q和ori_kv计算的mask模式,仅支持输入默认值4,代表band模式的mask | INT32 | - |
| cmp_mask_mode | 可选属性 | 表示q和cmp_kv计算的mask模式,仅支持输入默认值3,代表rightDownCausal模式的mask,对应以右顶点为划分的下三角场景 | INT32 | - |
| ori_win_left | 可选属性 | 表示q和ori_kv计算中q对过去token计算的数量,仅支持默认值127 | INT32 | - |
| ori_win_right | 可选属性 | 表示q和ori_kv计算中q对未来token计算的数量,仅支持默认值0 | INT32 | - |
| layout_q | 可选属性 | 用于标识输入q的数据排布格式,支持BSND和TND,默认值为BSND | STRING | - |
| layout_kv | 可选属性 | 用于标识输入ori_kv和cmp_kv的数据排布格式,仅支持传入默认值PA_ND(PageAttention) | STRING | - |
| return_softmax_lse | 可选属性 | 预留参数,当前暂不支持,表示是否返回softmax_lse。True表示返回,False表示不返回,默认值为False | BOOL | - |
| attention_out | 输出 | 当layout_q为BSND时shape为[B, S1, N1, D],当layout_q为TND时shape为[T1, N1, D] | BFLOAT16 | ND |
| softmax_lse | 输出 | 输出q乘k的结果先取max得到softmax_max,q乘k的结果减去softmax_max,再取exp,最后取sum,得到softmax_sum,最后对softmax_sum取log,再加上softmax_max得到的结果。当layout_q为BSND时shape为[B, N2, S1, N1/N2],当layout_q为TND时shape为[N2, T1, N1/N2]。目前softmax_lse输出为无效值 | FLOAT32 | ND |
约束说明
- 该接口支持推理场景下使用。
- 该接口支持aclgraph模式。
- 该接口当前支持三种计算场景:场景一,仅传入ori_kv时为Sliding Window Attention计算;场景二,传入ori_kv及cmp_kv时为Sliding Window Attention + Compressed Attention计算;场景三,传入ori_kv、cmp_kv及cmp_sparse_indices时为Sliding Window Attention + Sparse Compressed Attention计算。
- 参数q中的D仅支持512。ori_kv、cmp_kv的D值仅支持640,按kv_rope、kv_nope及nope_quant_scale顺序拼接,尾部pad 128B对齐至640。其中kv_rope数据类型为
bfloat16,rope_head_dim为64;kv_nope数据类型为float8_e4m3fn或hifloat8,nope_head_dim为448;nope_quant_scale数据类型为float8_e8m0fnu,nope_quant_scale_dim = nope_head_dim / tile_size = 7,整体封装为float8_e4m3fn或hifloat8。 - 参数ori_kv、cmp_kv的数据类型必须保持一致。
- 参数q中的N1当前支持64/128,ori_kv、cmp_kv中的KV_N仅支持1。
- 参数ori_kv和cmp_kv中的block_size1和block_size2需为16的倍数,最大支持1024;block_num1及block_num2为PageAttention时block总数。
- 参数ori_sparse_indices与cmp_sparse_indices中的K1与K2为一次离散选取的block数,需要保证每行有效值均在前半部分,无效值均在后半部分,当前不支持传入ori_sparse_indices,cmp_sparse_indices中K2仅支持512/1024。
- 参数cu_seqlens_q、cu_seqlens_ori_kv及cu_seqlens_cmp_kv维度为B + 1,要求其值为当前Batch与前序Batch有效token数的累加值,后一个元素的值必须大于等于前一个元素的值。
- 参数seqused_q及seqused_kv维度为B,要求其值表示每个Batch中的有效token数。
- 参数ori_block_table的shape为2维,其中第一维长度为B,第二维长度不小于所有Batch中最大的S2对应的block数量,即S2_max / block_size1向上取整。
- 参数cmp_block_table的shape为2维,其中第一维长度为B,第二维长度不小于floor(S2_max / cmp_ratio)对应的block数量,即floor(S2_max / cmp_ratio) / block_size2向上取整。
- ori_mask_mode及cmp_mask_mode所表示的mask模式的详细介绍见sparse_mode参数说明。
- q、ori_kv、cmp_kv参数维度含义:B(Batch Size)表示输入样本批量大小、S(Sequence Length)表示输入样本序列长度、H(Hidden Size)表示hidden层的大小、N(Head Num)表示多头数、D(Head Dim)表示hidden层最小的单元尺寸,且满足D=H/N、T表示所有Batch输入样本序列长度的累加和。
- Q_S和S1表示q shape中的S,S2表示ori_kv shape中的S,Q_N和N1表示num_q_heads,KV_N和N2表示num_ori_kv_heads和num_cmp_kv_heads;T1表示q shape中的T。