FusedFloydAttentionGrad

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT ×
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能:训练场景下,计算Floyd注意力的反向输出,FloydAttn相较于传统FA主要是计算qk/pv注意力时会额外将seq作为batch轴从而转换为batchMatmul。

  • 计算公式:

    已知注意力的正向计算公式为:

    P=Softmax(Mask(scale∗(Q∗K1T+Q∗K2T),atten_mask))Y=(P∗V1+P∗V2)P=Softmax(Mask(scale*(Q*K_1^T + Q*K_2^T), atten\_mask)) \\ Y=(P*V_1+P*V_2)

    则注意力的反向计算公式为:

    S=Softmax(S)S=Softmax(S)

    dV1=PTdYdV_1=P^TdY

    dV2=PTdYdV_2=P^TdY

    dQ=((dS)∗K1)d+((dS)∗K2)ddQ=\frac{((dS)*K_1)}{\sqrt{d}}+\frac{((dS)*K_2)}{\sqrt{d}}

    dK1=((dS)T∗Q)ddK_1=\frac{((dS)^T*Q)}{\sqrt{d}}

    dK2=((dS)T∗Q)ddK_2=\frac{((dS)^T*Q)}{\sqrt{d}}

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
query 输入 公式中的输入Q。 FLOAT16、BFLOAT16 ND
key1 输入 公式中的输入K1。 FLOAT16、BFLOAT16 ND
value1 输入 公式中的输入V1。 FLOAT16、BFLOAT16 ND
key2 输入 公式中的输入K2。 FLOAT16、BFLOAT16 ND
value2 输入 公式中的输入V2。 FLOAT16、BFLOAT16 ND
dy 输入 公式中的输入dY。 FLOAT16、BFLOAT16 ND
attenMaskOptional 可选输入 公式中的atten_mask,表示注意力掩码,取值为1代表该位不参与计算(不生效),为0代表该位参与计算。 BOOL、UINT8 ND
scaleValue 可选属性
  • 公式中的scale,表示缩放系数,作为计算流中Muls的scalar值。
  • 默认值为1.0。
DOUBLE -
dqOut 输出 公式中的dQ,表示query的梯度。 FLOAT16、BFLOAT16 ND
dk1Out 输出 公式中的dK1,表示key1的梯度。 FLOAT16、BFLOAT16 ND
dv1Out 输出 公式中的dV1,表示value1的梯度。 FLOAT16、BFLOAT16 ND
dk2Out 输出 公式中的dK2,表示key2的梯度。 FLOAT16、BFLOAT16 ND
dv2Out 输出 公式中的dV2,表示value2的梯度。 FLOAT16、BFLOAT16 ND

约束说明

  • 该接口与PyTorch配合使用时,需要保证CANN相关包与PyTorch相关包的版本匹配

  • 关于数据shape的约束,其中:

    • B:取值范围为1~2K。
    • H:取值范围为1~256。
    • N:取值范围为16~1M且N%16==0。
    • M:取值范围为128~1M且M%128==0。
    • K:取值范围为128~1M且K%128==0。
    • D:取值范围为32/64/128。
  • query与key1的第0/2/4轴需相同。

  • key1与value1 shape需相同。

  • key2与value2 shape需相同。

  • query与dy/attentionIn shape需相同。

  • softmaxMax与softmaxSum shape需相同。

  • D只支持32/64/128。

调用说明

调用方式 调用样例 说明
aclnn调用 test_aclnn_fused_floyd_attention_grad 通过接口方式调用aclnnFusedFloydAttentionGrad算子。