SparseFlashAttentionGrad

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能:根据topkIndices对key和value选取大小为selectedBlockSize的数据重排,接着进行训练场景下计算注意力的反向输出。

  • 计算公式:根据传入的topkIndice对keyIn和value选取数量为selectedBlockCount个大小为selectedBlockSize的数据重排,公式如下:

    selectedKey = Gather(key,topkIndices[i]), 0 <=i< selectBlockCount selectedKey\text{ }=\text{ }Gather \left( key,topkIndices \left[ i \left] \left) ,\text{ }0\text{ } < =i < \text{ }selectBlockCount\right. \right. \right. \right.

    selectedValue = Gather(value,topkIndices[i]), 0 <=i< selectBlockCount selectedValue\text{ }=\text{ }Gather \left( value,topkIndices \left[ i \left] \left) ,\text{ }0\text{ } < =i < \text{ }selectBlockCount\right. \right. \right. \right.

阶段1:根据矩阵乘法导数规则,计算dPdPdVdV:

dPt,:=dOt,:@VT dP\mathop{{}}\nolimits_{{t,:}}=dO\mathop{{}}\nolimits_{{t,:}}\text{@}V\mathop{{}}\nolimits^{{T}}

dV[u]=PTt,:@dOt,: dV \left[ u \left] =P\mathop{{}}\nolimits_{{T}}^{{t,:}}\text{@}dO\mathop{{}}\nolimits_{{t,:}}\right. \right.

阶段2:计算dSdS:

dSt,:=[Pt,:@(dPt,:−FlashSoftmaxGrad(dO,O))] d\mathop{{S}}\nolimits_{{t,:}}= \left[ P\mathop{{}}\nolimits_{{t,:}}@ \left( dP\mathop{{}}\nolimits_{{t,:}}-FlashSoftmaxGrad \left( dO,O \left) \left) \right] \right. \right. \right. \right.

阶段3:计算dQdQdKdK:

dQt,:=dSt,:@K[u]:t,:/dk,: d\mathop{{Q}}\nolimits_{{t,:}}=d\mathop{{S}}\nolimits_{{t,:}}@K \left[ u \left] \mathop{{}}\nolimits_{{:t,:}}/\sqrt{{d\mathop{{}}\nolimits_{{k,:}}}}\right. \right.

dK[u]:t,:=dSt,:tT@Q/dt,: dK \left[ u \left] \mathop{{}}\nolimits_{{:t,:}}=dS\mathop{{}}\nolimits_{{t,:t}}\mathop{{}}\nolimits^{{T}}\text{@}Q/\sqrt{{d\mathop{{}}\nolimits_{{t,:}}}}\right. \right.

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
query 输入 attention结构的输入Q。 BFLOAT16、FLOAT16 ND
key 输入 attention结构的输入K。 BFLOAT16、FLOAT16 ND
value 输入 attention结构的输入v。 BFLOAT16、FLOAT16 ND
sparseIndices 输入 稀疏场景下选择的权重较高的注意力索引。 INT32 ND
dOut 输入 注意力输出矩阵的梯度。 BFLOAT16、FLOAT16 ND
out 输入 注意力输出矩阵。 BFLOAT16、FLOAT16 ND
softmaxMax 输入 注意力正向计算的中间输出。 FLOAT32 ND
softmaxSum 输入 注意力正向计算的中间输出。 FLOAT32 ND
actualSeqLengthsQueryOptional 输入 每个Batch中,Query的有效token数。 INT32 ND
actualSeqLengthskvOptional 输入 每个Batch中,Key、value的有效token数。 INT32 ND
queryRopeOptional 输入 MLA rope部分:Query位置编码的输出。 BFLOAT16、FLOAT16 ND
keyRopeOptional 输入 MLA rope部分:Key位置编码的输出。 BFLOAT16、FLOAT16 ND
scaleValue 属性 缩放系数。 FLOAT32 -
sparseBlockSize 属性 选择的块的大小。 INT64 -
layout 属性 layout格式。 STRING -
sparseMode 属性 sparse的模式。 INT64 -
preTokens 属性 Attention算子里, 对S矩阵的滑窗起始位置。 INT64 -
nextTokens 属性 Attention算子里, 对S矩阵的滑窗终止位置。 INT64 -
deterministic 属性 确定性计算。 BOOL -
dQuery 输出 表示query的梯度。 BFLOAT16、FLOAT16 ND
dKey 输出 表示key的梯度。 BFLOAT16、FLOAT16 ND
dValue 输出 表示value的梯度。 BFLOAT16、FLOAT16 ND
dQueryRopeOptional 输出 表示queryRope的梯度。 BFLOAT16、FLOAT16 ND
dKeyRopeOptional 输出 表示keyRope的梯度。 BFLOAT16、FLOAT16 ND

约束说明

  • 参数query中的D和key、value的D值相等为512,参数query_rope中的Dr和key_rope的Dr值相等为64。
  • 参数query、key、value的数据类型必须保持一致。
  • 当前只支持value和key完全一致的场景。
  • 当前仅支持sparseMode=0或3(无mask或以右顶点为划分的下三角场景)
  • 仅支持BSND或TND layout;关于数据shape的约束如下:
    • B:取值范围1~256。
    • S1、S2:1~128K;S1、S2支持不等长。
    • N1支持1/2/4/8/16/32/64/128。
      • Ascend 950PR/Ascend 950DT:
        • 额外还支持48、24、12、6、3。
    • N2:仅支持1。
    • D:仅支持512。
    • Drope:仅支持64。
    • topk:1024、2048、3072、4096、5120、6144、7168、8192。
      • 不建议topk * sparseBlockSize超过100k,由于内部算法硬件限制可能会导致oom。
  • 确定性计算:
    • SparseFlashAttentionGrad默认非确定性实现,支持通过aclrtCtxSetSysParamOpt开启确定性。

调用示例

调用方式 样例代码 说明
aclnn接口 test_aclnn_sparse_flash_attention_grad 通过 aclnnSparseFlashAttentionGrad 接口方式调用算子