SparseFlashMlaGrad
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | × |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
功能说明
-
算子功能:计算
SparseFlashMla训练场景下注意力的反向输出,支持Sliding Window Attention、Compressed Attention以及Sparse Compressed Attention。 -
计算公式:
阶段一:根据不同cmp_ratio场景,对输入ori_kv与cmp_kv进行选择
- 当cmp_ratio = 1 (SWA):
selectedKv = orikvselectedKv\text{ }=\text{ }orikv
- 当cmp_ratio = 4 (SCFA):
selectedKv =concat(oriKv, Gather(cmpkv,topkIndices[i])), 0 <=i< selectBlockCountselectedKv\text{ }=concat(oriKv, \text{ }Gather \left( cmpkv,topkIndices \left[ i \left] \left)) ,\text{ }0\text{ } < =i < \text{ }selectBlockCount\right. \right. \right. \right.
- else (CFA):
selectedKv =concat(oriKv, cmpkv)selectedKv\text{ }=concat(oriKv, \text{ }cmpkv)
阶段二:计算P、dP、dS
P=SimpleSoftmax(Mask(Q @ selectedKvT⋅ scale),lse)P = SimpleSoftmax(Mask(Q \text{ }@\text{ } selectedKv^{{T}} \cdot \text{ } scale), lse)
dP=dO @ selectedKvTdP = dO \text{ }@\text{ } selectedKv^{{T}}
dS=P×(dP − SoftmaxGrad(dO,O))dS = P \times (dP\text{ } -\text{ } SoftmaxGrad(dO, O))
阶段三:计算dQ, dKV, dSinks
dQ=dS @ selectedKv ⋅ scaledQ = dS \text{ } @ \text{ } selectedKv \text{ } \cdot \text{ } scale
dKV=dST @ Q ⋅ scale+PT@ dOdKV = dS^{{T}} \text{ } @ \text{ } Q \text{ } \cdot \text{ } scale + P^{{T}} @ \text{ } dO
dSinks=ReduceSum(−P × dP × SimpleSoftmax(sinks,lse),dim=−1)dSinks = ReduceSum(-P \text{ }\times\text{ } dP \text{ }\times\text{ } SimpleSoftmax(sinks, lse), dim=-1)
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| query | 输入 | attention结构的输入Q。 | BFLOAT16、FLOAT16 | ND |
| oriKvOptional | 输入 | attention结构的原始输入K(V)。 | BFLOAT16、FLOAT16 | ND |
| cmpKvOptional | 输入 | 经过Compressor压缩后的K(V)。 | BFLOAT16、FLOAT16 | ND |
| dOut | 输入 | 注意力输出矩阵的梯度。 | BFLOAT16、FLOAT16 | ND |
| out | 输入 | 注意力输出矩阵。 | BFLOAT16、FLOAT16 | ND |
| lse | 输入 | 注意力正向计算的输出lse,计算公式详见正向文档。 | FLOAT32 | ND |
| oriSparseIndicesOptional | 输入 | 稀疏场景下选择的oriKvOptional中权重较高的注意力索引。 | INT32 | ND |
| cmpSparseIndicesOptional | 输入 | 稀疏场景下选择的cmpKvOptional中权重较高的注意力索引。 | INT32 | ND |
| cuSeqlensQOptional | 输入 | 每个Batch中,Query的有效token数。 | INT32 | ND |
| cuSeqlensOriKvOptional | 输入 | 每个Batch中,oriKvOptional的有效token数。 | INT32 | ND |
| cuSeqlensCmpKvOptional | 输入 | 每个Batch中,cmpKvOptional的有效token数。 | INT32 | ND |
| sequsedQOptional | 输入 | 表示不同batch中query实际参与运算的token数。 | INT32 | ND |
| sequsedOriKvOptional | 输入 | 表示不同batch中oriKvOptional实际参与运算的token数。 | INT32 | ND |
| sequsedCmpKvOptional | 输入 | 表示不同batch中cmpKvOptional实际参与运算的token数。 | INT32 | ND |
| cmpResidualKvOptional | 输入 | 表示每个batch S2 // cmpRatio后的余数。 | INT32 | ND |
| oriTopkLengthOptional | 输入 | 表示每行query对应的oriKvOptional实际可选的topk长度。 | INT32 | ND |
| cmpTopkLengthOptional | 输入 | 表示每行query对应的cmpKvOptional实际可选的topk长度。 | INT32 | ND |
| sinksOptional | 输入 | 注意力下沉tensor。 | FLOAT32 | ND |
| metadataOptional | 输入 | 表示tiling下沉的aicpu算子输出结果。 | INT32 | ND |
| scaleValue | 属性 | 缩放系数。 | FLOAT32 | - |
| cmpRatio | 属性 | 表示对oriKvOptional的压缩率。 | INT64 | - |
| oriMaskMode | 属性 | 表示query和oriKvOptional计算的mask模式。 | INT64 | - |
| cmpMaskMode | 属性 | 表示query和cmpKvOptional计算的mask模式。 | INT64 | - |
| oriWinLeft | 属性 | 表示query和oriKvOptional计算中query对过去token计算的数量。 | INT64 | - |
| oriWinRight | 属性 | 表示query和oriKvOptional计算中query对未来token计算的数量。 | INT64 | - |
| layoutQOptional | 属性 | 表示输入query的数据排布格式。 | STRING | - |
| layoutKvOptional | 属性 | 表示输入ori_kv和cmp_kv的数据排布格式。 | STRING | - |
| deterministic | 属性 | 表示是否开启确定性,应和全局保持一致。 | INT64 | - |
| dQueryOut | 输出 | 表示query的梯度。 | BFLOAT16、FLOAT16 | ND |
| dOriKvOutOptional | 输出 | 表示oriKvOptional的梯度。 | BFLOAT16、FLOAT16 | ND |
| dCmpKvOptional | 输出 | 表示cmpKvOptional的梯度。 | BFLOAT16、FLOAT16 | ND |
| dSinksOutOptional | 输出 | 表示sinksOptional的梯度。 | FLOAT32 | ND |
| oriSoftmaxL1NormOptional | 输出 | 表示query与oriKvOptional计算得出的softmax结果。 | FLOAT32 | ND |
| cmpSoftmaxL1NormOptional | 输出 | 表示query与cmpKvOptional计算得出的softmax结果。 | FLOAT32 | ND |
约束说明
- 仅支持BSND或TND layout;关于数据shape的约束如下:
- B:泛化支持。
- S1、S2:泛化支持;S1、S2支持不等长。
- N1:支持1~128。
- N2:仅支持1。
- D:仅支持512。
调用示例
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| aclnn接口 | test_aclnn_sparse_flash_mla_grad | 通过 aclnnSparseFlashMlaGrad 接口方式调用算子 |