SparseFlashMlaGrad

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT ×
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能:计算SparseFlashMla训练场景下注意力的反向输出,支持Sliding Window Attention、Compressed Attention以及Sparse Compressed Attention。

  • 计算公式:

    阶段一:根据不同cmp_ratio场景,对输入ori_kv与cmp_kv进行选择

    • 当cmp_ratio = 1 (SWA):

    selectedKv = orikvselectedKv\text{ }=\text{ }orikv

    • 当cmp_ratio = 4 (SCFA):

    selectedKv =concat(oriKv, Gather(cmpkv,topkIndices[i])), 0 <=i< selectBlockCountselectedKv\text{ }=concat(oriKv, \text{ }Gather \left( cmpkv,topkIndices \left[ i \left] \left)) ,\text{ }0\text{ } < =i < \text{ }selectBlockCount\right. \right. \right. \right.

    • else (CFA):

    selectedKv =concat(oriKv, cmpkv)selectedKv\text{ }=concat(oriKv, \text{ }cmpkv)

    阶段二:计算P、dP、dS

    P=SimpleSoftmax(Mask(Q @ selectedKvT⋅ scale),lse)P = SimpleSoftmax(Mask(Q \text{ }@\text{ } selectedKv^{{T}} \cdot \text{ } scale), lse)

    dP=dO @ selectedKvTdP = dO \text{ }@\text{ } selectedKv^{{T}}

    dS=P×(dP − SoftmaxGrad(dO,O))dS = P \times (dP\text{ } -\text{ } SoftmaxGrad(dO, O))

    阶段三:计算dQ, dKV, dSinks

    dQ=dS @ selectedKv ⋅ scaledQ = dS \text{ } @ \text{ } selectedKv \text{ } \cdot \text{ } scale

    dKV=dST @ Q ⋅ scale+PT@ dOdKV = dS^{{T}} \text{ } @ \text{ } Q \text{ } \cdot \text{ } scale + P^{{T}} @ \text{ } dO

    dSinks=ReduceSum(−P × dP × SimpleSoftmax(sinks,lse),dim=−1)dSinks = ReduceSum(-P \text{ }\times\text{ } dP \text{ }\times\text{ } SimpleSoftmax(sinks, lse), dim=-1)

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
query 输入 attention结构的输入Q。 BFLOAT16、FLOAT16 ND
oriKvOptional 输入 attention结构的原始输入K(V)。 BFLOAT16、FLOAT16 ND
cmpKvOptional 输入 经过Compressor压缩后的K(V)。 BFLOAT16、FLOAT16 ND
dOut 输入 注意力输出矩阵的梯度。 BFLOAT16、FLOAT16 ND
out 输入 注意力输出矩阵。 BFLOAT16、FLOAT16 ND
lse 输入 注意力正向计算的输出lse,计算公式详见正向文档。 FLOAT32 ND
oriSparseIndicesOptional 输入 稀疏场景下选择的oriKvOptional中权重较高的注意力索引。 INT32 ND
cmpSparseIndicesOptional 输入 稀疏场景下选择的cmpKvOptional中权重较高的注意力索引。 INT32 ND
cuSeqlensQOptional 输入 每个Batch中,Query的有效token数。 INT32 ND
cuSeqlensOriKvOptional 输入 每个Batch中,oriKvOptional的有效token数。 INT32 ND
cuSeqlensCmpKvOptional 输入 每个Batch中,cmpKvOptional的有效token数。 INT32 ND
sequsedQOptional 输入 表示不同batch中query实际参与运算的token数。 INT32 ND
sequsedOriKvOptional 输入 表示不同batch中oriKvOptional实际参与运算的token数。 INT32 ND
sequsedCmpKvOptional 输入 表示不同batch中cmpKvOptional实际参与运算的token数。 INT32 ND
cmpResidualKvOptional 输入 表示每个batch S2 // cmpRatio后的余数。 INT32 ND
oriTopkLengthOptional 输入 表示每行query对应的oriKvOptional实际可选的topk长度。 INT32 ND
cmpTopkLengthOptional 输入 表示每行query对应的cmpKvOptional实际可选的topk长度。 INT32 ND
sinksOptional 输入 注意力下沉tensor。 FLOAT32 ND
metadataOptional 输入 表示tiling下沉的aicpu算子输出结果。 INT32 ND
scaleValue 属性 缩放系数。 FLOAT32 -
cmpRatio 属性 表示对oriKvOptional的压缩率。 INT64 -
oriMaskMode 属性 表示query和oriKvOptional计算的mask模式。 INT64 -
cmpMaskMode 属性 表示query和cmpKvOptional计算的mask模式。 INT64 -
oriWinLeft 属性 表示query和oriKvOptional计算中query对过去token计算的数量。 INT64 -
oriWinRight 属性 表示query和oriKvOptional计算中query对未来token计算的数量。 INT64 -
layoutQOptional 属性 表示输入query的数据排布格式。 STRING -
layoutKvOptional 属性 表示输入ori_kv和cmp_kv的数据排布格式。 STRING -
deterministic 属性 表示是否开启确定性,应和全局保持一致。 INT64 -
dQueryOut 输出 表示query的梯度。 BFLOAT16、FLOAT16 ND
dOriKvOutOptional 输出 表示oriKvOptional的梯度。 BFLOAT16、FLOAT16 ND
dCmpKvOptional 输出 表示cmpKvOptional的梯度。 BFLOAT16、FLOAT16 ND
dSinksOutOptional 输出 表示sinksOptional的梯度。 FLOAT32 ND
oriSoftmaxL1NormOptional 输出 表示query与oriKvOptional计算得出的softmax结果。 FLOAT32 ND
cmpSoftmaxL1NormOptional 输出 表示query与cmpKvOptional计算得出的softmax结果。 FLOAT32 ND

约束说明

  • 仅支持BSND或TND layout;关于数据shape的约束如下:
    • B:泛化支持。
    • S1、S2:泛化支持;S1、S2支持不等长。
    • N1:支持1~128。
    • N2:仅支持1。
    • D:仅支持512。

调用示例

调用方式 样例代码 说明
aclnn接口 test_aclnn_sparse_flash_mla_grad 通过 aclnnSparseFlashMlaGrad 接口方式调用算子