DenseLightningIndexerGradKLLoss

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT ×
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能:DenseLightningIndexerGradKlLoss算子是LightningIndexer的反向算子,再额外融合了Loss计算功能。LightningIndexer算子将QueryToken和KeyToken之间的最高内在联系的TopK个筛选出来,从而减少长序列场景下Attention的计算量,加速长序列的网络的推理和训练的性能。稠密场景下的LightningIndexerGrad的输入query、key、query_index、key_index不用做稀疏化处理。

  • 计算公式:

    1. Top-k value的计算公式:

    It,:=Wt,:@ReLU(q~t,:@K~:t,:⊤)I_{t,:}=W_{t,:}@ReLU(\tilde{q}_{t,:}@\tilde{K}_{:t,:}^\top)

    • Wt,:W_{t,:}是第tt个token对应的weightsweights
    • q~t,:\tilde{q}_{t,:}q~\tilde{q}矩阵第tt个token对应的GG个query头合轴后的结果;
    • K~:t,:\tilde{K}_{:t,:}ttK~\tilde{K}矩阵。
    1. 正向的Softmax对应公式:

    pt,:=Softmax(qt,:@K:t,:⊤/d)p_{t,:} = \text{Softmax}(q_{t,:} @ K_{:t,:}^\top/\sqrt{d})

    • pt,:p_{t,:}是第tt个token对应的Softmax结果;
    • qt,:q_{t,:}qq矩阵第tt个token对应的GG个query头合轴后的结果;
    • ${K}_{:t,:}$为ttKK矩阵。
    1. npu_lightning_indexer会单独训练,对应的loss function为:

    Loss=∑tDKL(pt,:∣∣Softmax(It,:))Loss{=}\sum_tD_{KL}(p_{t,:}||Softmax(I_{t,:}))

    其中,pt,:p_{t,:}是target distribution,通过对main attention score 进行所有的head的求和,然后把求和结果沿着上下文方向进行L1正则化得到。DKLD_{KL}为KL散度,其表达式为:

    DKL(a∣∣b)=∑iailog(aibi)D_{KL}(a||b){=}\sum_ia_i\mathrm{log}{\left(\frac{a_i}{b_i}\right)}

    1. 通过求导可得Loss的梯度表达式:

    dIt,:=Softmax(It,:)−pt,:dI\mathop{{}}\nolimits_{{t,:}}=Softmax \left( I\mathop{{}}\nolimits_{{t,:}} \left) -p\mathop{{}}\nolimits_{{t,:}}\right. \right.

    利用链式法则可以进行weights,query和key矩阵的梯度计算:

    dWt,:=dIt,:@(ReLU(St,:))⊤dW\mathop{{}}\nolimits_{{t,:}}=dI\mathop{{}}\nolimits_{{t,:}}\text{@} \left( ReLU \left( S\mathop{{}}\nolimits_{{t,:}} \left) \left) \mathop{{}}\nolimits^{\top}\right. \right. \right. \right.

    dq~t,:=dSt,:@K~:t,:d\mathop{{\tilde{q}}}\nolimits_{{t,:}}=dS\mathop{{}}\nolimits_{{t,:}}@\tilde{K}\mathop{{}}\nolimits_{{:t,:}}

    dK~:t,:=(dSt,:)⊤@q~:t,:d\tilde{K}\mathop{{}}\nolimits_{{:t,:}}=\left(dS\mathop{{}}\nolimits_{{t,:}} \left) \mathop{{}}\nolimits^{\top}@\tilde{q}\mathop{{}}\nolimits_{{:t, :}}\right. \right.

    其中,SSq~\tilde{q}KK矩阵乘的结果。

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
query 输入 attention结构的输入Q。 FLOAT16、BFLOAT16 ND
key 输入 attention结构的输入K。 FLOAT16、BFLOAT16 ND
queryIndex 输入 lightningIndexer结构的输入queryIndex。 FLOAT16、BFLOAT16 ND
keyIndex 输入 lightningIndexer结构的输入keyIndex。 FLOAT16、BFLOAT16 ND
weights 输入 权重。 FLOAT16、BFLOAT16 ND
softmaxMax 输入 Device侧的aclTensor,注意力正向计算的中间输出。 FLOAT32 ND
softmaxSum 输入 Device侧的aclTensor,注意力正向计算的中间输出。 FLOAT32 ND
softmaxMaxIndex 输入 Device侧的aclTensor,注意力正向计算的中间输出。 FLOAT32 ND
softmaxSumIndex 输入 Device侧的aclTensor,注意力正向计算的中间输出。 FLOAT32 ND
queryRope 输入 MLA rope部分:Query位置编码的输出。 FLOAT16、BFLOAT16 ND
keyRope 输入 MLA rope部分:Key位置编码的输出。 FLOAT16、BFLOAT16 ND
actualSeqLengthsQuery 输入 每个Batch中,Query的有效token数。 INT64 ND
actualSeqLengthsKey 输入 每个Batch中,Key的有效token数。 INT64 ND
scaleValue 输入 缩放系数。 double -
layout 输入 layout格式。 char* -
sparseMode 输入 sparse的模式。 INT64 -
preTokens 输入 用于稀疏计算,表示Attention需要和前几个token计算关联。 INT64 -
nextTokens 输入 用于稀疏计算,表示Attention需要和后几个token计算关联。 INT64 -
dQueryIndex 输出 QueryIndex的梯度。 FLOAT16、BFLOAT16 ND
dKeyIndex 输出 KeyIndex的梯度。 FLOAT16、BFLOAT16 ND
dWeights 输出 Weights的梯度。 FLOAT16、BFLOAT16 ND
loss 输出 损失函数值。 FLOAT32 ND

约束说明

调用说明

调用方式 调用样例 说明
aclnn调用 test_aclnn_dense_lightning_indexer_grad_kl_loss 通过aclnnDenseLightningIndexerGradKLLoss接口方式调用dense_lightning_indexer_grad_kl_loss算子。