SparseFlashAttentionGrad
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
功能说明
-
算子功能:根据topkIndices对key和value选取大小为selectedBlockSize的数据重排,接着进行训练场景下计算注意力的反向输出。
-
计算公式:根据传入的topkIndice对keyIn和value选取数量为selectedBlockCount个大小为selectedBlockSize的数据重排,公式如下:
selectedKey = Gather(key,topkIndices[i]), 0 <=i< selectBlockCount selectedKey\text{ }=\text{ }Gather \left( key,topkIndices \left[ i \left] \left) ,\text{ }0\text{ } < =i < \text{ }selectBlockCount\right. \right. \right. \right.
selectedValue = Gather(value,topkIndices[i]), 0 <=i< selectBlockCount selectedValue\text{ }=\text{ }Gather \left( value,topkIndices \left[ i \left] \left) ,\text{ }0\text{ } < =i < \text{ }selectBlockCount\right. \right. \right. \right.
阶段1:根据矩阵乘法导数规则,计算dPdP和dVdV:
dPt,:=dOt,:@VT dP\mathop{{}}\nolimits_{{t,:}}=dO\mathop{{}}\nolimits_{{t,:}}\text{@}V\mathop{{}}\nolimits^{{T}}
dV[u]=PTt,:@dOt,: dV \left[ u \left] =P\mathop{{}}\nolimits_{{T}}^{{t,:}}\text{@}dO\mathop{{}}\nolimits_{{t,:}}\right. \right.
阶段2:计算dSdS:
dSt,:=[Pt,:@(dPt,:−FlashSoftmaxGrad(dO,O))] d\mathop{{S}}\nolimits_{{t,:}}= \left[ P\mathop{{}}\nolimits_{{t,:}}@ \left( dP\mathop{{}}\nolimits_{{t,:}}-FlashSoftmaxGrad \left( dO,O \left) \left) \right] \right. \right. \right. \right.
阶段3:计算dQdQ与dKdK:
dQt,:=dSt,:@K[u]:t,:/dk,: d\mathop{{Q}}\nolimits_{{t,:}}=d\mathop{{S}}\nolimits_{{t,:}}@K \left[ u \left] \mathop{{}}\nolimits_{{:t,:}}/\sqrt{{d\mathop{{}}\nolimits_{{k,:}}}}\right. \right.
dK[u]:t,:=dSt,:tT@Q/dt,: dK \left[ u \left] \mathop{{}}\nolimits_{{:t,:}}=dS\mathop{{}}\nolimits_{{t,:t}}\mathop{{}}\nolimits^{{T}}\text{@}Q/\sqrt{{d\mathop{{}}\nolimits_{{t,:}}}}\right. \right.
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| query | 输入 | attention结构的输入Q。 | BFLOAT16、FLOAT16 | ND |
| key | 输入 | attention结构的输入K。 | BFLOAT16、FLOAT16 | ND |
| value | 输入 | attention结构的输入v。 | BFLOAT16、FLOAT16 | ND |
| sparseIndices | 输入 | 稀疏场景下选择的权重较高的注意力索引。 | INT32 | ND |
| dOut | 输入 | 注意力输出矩阵的梯度。 | BFLOAT16、FLOAT16 | ND |
| out | 输入 | 注意力输出矩阵。 | BFLOAT16、FLOAT16 | ND |
| softmaxMax | 输入 | 注意力正向计算的中间输出。 | FLOAT32 | ND |
| softmaxSum | 输入 | 注意力正向计算的中间输出。 | FLOAT32 | ND |
| actualSeqLengthsQueryOptional | 输入 | 每个Batch中,Query的有效token数。 | INT32 | ND |
| actualSeqLengthskvOptional | 输入 | 每个Batch中,Key、value的有效token数。 | INT32 | ND |
| queryRopeOptional | 输入 | MLA rope部分:Query位置编码的输出。 | BFLOAT16、FLOAT16 | ND |
| keyRopeOptional | 输入 | MLA rope部分:Key位置编码的输出。 | BFLOAT16、FLOAT16 | ND |
| scaleValue | 属性 | 缩放系数。 | FLOAT32 | - |
| sparseBlockSize | 属性 | 选择的块的大小。 | INT64 | - |
| layout | 属性 | layout格式。 | STRING | - |
| sparseMode | 属性 | sparse的模式。 | INT64 | - |
| preTokens | 属性 | Attention算子里, 对S矩阵的滑窗起始位置。 | INT64 | - |
| nextTokens | 属性 | Attention算子里, 对S矩阵的滑窗终止位置。 | INT64 | - |
| deterministic | 属性 | 确定性计算。 | BOOL | - |
| dQuery | 输出 | 表示query的梯度。 | BFLOAT16、FLOAT16 | ND |
| dKey | 输出 | 表示key的梯度。 | BFLOAT16、FLOAT16 | ND |
| dValue | 输出 | 表示value的梯度。 | BFLOAT16、FLOAT16 | ND |
| dQueryRopeOptional | 输出 | 表示queryRope的梯度。 | BFLOAT16、FLOAT16 | ND |
| dKeyRopeOptional | 输出 | 表示keyRope的梯度。 | BFLOAT16、FLOAT16 | ND |
约束说明
- 参数query中的D和key、value的D值相等为512,参数query_rope中的Dr和key_rope的Dr值相等为64。
- 参数query、key、value的数据类型必须保持一致。
- 当前只支持value和key完全一致的场景。
- 当前仅支持sparseMode=0或3(无mask或以右顶点为划分的下三角场景)
- 仅支持BSND或TND layout;关于数据shape的约束如下:
- B:取值范围1~256。
- S1、S2:1~128K;S1、S2支持不等长。
- N1支持1/2/4/8/16/32/64/128。
- Ascend 950PR/Ascend 950DT:
- 额外还支持48、24、12、6、3。
- Ascend 950PR/Ascend 950DT:
- N2:仅支持1。
- D:仅支持512。
- Drope:仅支持64。
- topk:1024、2048、3072、4096、5120、6144、7168、8192。
- 不建议topk * sparseBlockSize超过100k,由于内部算法硬件限制可能会导致oom。
- 确定性计算:
- SparseFlashAttentionGrad默认非确定性实现,支持通过aclrtCtxSetSysParamOpt开启确定性。
调用示例
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| aclnn接口 | test_aclnn_sparse_flash_attention_grad | 通过 aclnnSparseFlashAttentionGrad 接口方式调用算子 |