DenseLightningIndexerGradKLLoss
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | × |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
功能说明
-
算子功能:DenseLightningIndexerGradKlLoss算子是LightningIndexer的反向算子,再额外融合了Loss计算功能。LightningIndexer算子将QueryToken和KeyToken之间的最高内在联系的TopK个筛选出来,从而减少长序列场景下Attention的计算量,加速长序列的网络的推理和训练的性能。稠密场景下的LightningIndexerGrad的输入query、key、query_index、key_index不用做稀疏化处理。
-
计算公式:
- Top-k value的计算公式:
It,:=Wt,:@ReLU(q~t,:@K~:t,:⊤)I_{t,:}=W_{t,:}@ReLU(\tilde{q}_{t,:}@\tilde{K}_{:t,:}^\top)
- Wt,:W_{t,:}是第tt个token对应的weightsweights;
- q~t,:\tilde{q}_{t,:}是q~\tilde{q}矩阵第tt个token对应的GG个query头合轴后的结果;
- K~:t,:\tilde{K}_{:t,:}为tt行K~\tilde{K}矩阵。
- 正向的Softmax对应公式:
pt,:=Softmax(qt,:@K:t,:⊤/d)p_{t,:} = \text{Softmax}(q_{t,:} @ K_{:t,:}^\top/\sqrt{d})
- pt,:p_{t,:}是第tt个token对应的Softmax结果;
- qt,:q_{t,:}是qq矩阵第tt个token对应的GG个query头合轴后的结果;
- ${K}_{:t,:}$为tt行KK矩阵。
- npu_lightning_indexer会单独训练,对应的loss function为:
Loss=∑tDKL(pt,:∣∣Softmax(It,:))Loss{=}\sum_tD_{KL}(p_{t,:}||Softmax(I_{t,:}))
其中,pt,:p_{t,:}是target distribution,通过对main attention score 进行所有的head的求和,然后把求和结果沿着上下文方向进行L1正则化得到。DKLD_{KL}为KL散度,其表达式为:
DKL(a∣∣b)=∑iailog(aibi)D_{KL}(a||b){=}\sum_ia_i\mathrm{log}{\left(\frac{a_i}{b_i}\right)}
- 通过求导可得Loss的梯度表达式:
dIt,:=Softmax(It,:)−pt,:dI\mathop{{}}\nolimits_{{t,:}}=Softmax \left( I\mathop{{}}\nolimits_{{t,:}} \left) -p\mathop{{}}\nolimits_{{t,:}}\right. \right.
利用链式法则可以进行weights,query和key矩阵的梯度计算:
dWt,:=dIt,:@(ReLU(St,:))⊤dW\mathop{{}}\nolimits_{{t,:}}=dI\mathop{{}}\nolimits_{{t,:}}\text{@} \left( ReLU \left( S\mathop{{}}\nolimits_{{t,:}} \left) \left) \mathop{{}}\nolimits^{\top}\right. \right. \right. \right.
dq~t,:=dSt,:@K~:t,:d\mathop{{\tilde{q}}}\nolimits_{{t,:}}=dS\mathop{{}}\nolimits_{{t,:}}@\tilde{K}\mathop{{}}\nolimits_{{:t,:}}
dK~:t,:=(dSt,:)⊤@q~:t,:d\tilde{K}\mathop{{}}\nolimits_{{:t,:}}=\left(dS\mathop{{}}\nolimits_{{t,:}} \left) \mathop{{}}\nolimits^{\top}@\tilde{q}\mathop{{}}\nolimits_{{:t, :}}\right. \right.
其中,SS为q~\tilde{q}和KK矩阵乘的结果。
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| query | 输入 | attention结构的输入Q。 | FLOAT16、BFLOAT16 | ND |
| key | 输入 | attention结构的输入K。 | FLOAT16、BFLOAT16 | ND |
| queryIndex | 输入 | lightningIndexer结构的输入queryIndex。 | FLOAT16、BFLOAT16 | ND |
| keyIndex | 输入 | lightningIndexer结构的输入keyIndex。 | FLOAT16、BFLOAT16 | ND |
| weights | 输入 | 权重。 | FLOAT16、BFLOAT16 | ND |
| softmaxMax | 输入 | Device侧的aclTensor,注意力正向计算的中间输出。 | FLOAT32 | ND |
| softmaxSum | 输入 | Device侧的aclTensor,注意力正向计算的中间输出。 | FLOAT32 | ND |
| softmaxMaxIndex | 输入 | Device侧的aclTensor,注意力正向计算的中间输出。 | FLOAT32 | ND |
| softmaxSumIndex | 输入 | Device侧的aclTensor,注意力正向计算的中间输出。 | FLOAT32 | ND |
| queryRope | 输入 | MLA rope部分:Query位置编码的输出。 | FLOAT16、BFLOAT16 | ND |
| keyRope | 输入 | MLA rope部分:Key位置编码的输出。 | FLOAT16、BFLOAT16 | ND |
| actualSeqLengthsQuery | 输入 | 每个Batch中,Query的有效token数。 | INT64 | ND |
| actualSeqLengthsKey | 输入 | 每个Batch中,Key的有效token数。 | INT64 | ND |
| scaleValue | 输入 | 缩放系数。 | double | - |
| layout | 输入 | layout格式。 | char* | - |
| sparseMode | 输入 | sparse的模式。 | INT64 | - |
| preTokens | 输入 | 用于稀疏计算,表示Attention需要和前几个token计算关联。 | INT64 | - |
| nextTokens | 输入 | 用于稀疏计算,表示Attention需要和后几个token计算关联。 | INT64 | - |
| dQueryIndex | 输出 | QueryIndex的梯度。 | FLOAT16、BFLOAT16 | ND |
| dKeyIndex | 输出 | KeyIndex的梯度。 | FLOAT16、BFLOAT16 | ND |
| dWeights | 输出 | Weights的梯度。 | FLOAT16、BFLOAT16 | ND |
| loss | 输出 | 损失函数值。 | FLOAT32 | ND |
约束说明
无
调用说明
| 调用方式 | 调用样例 | 说明 |
|---|---|---|
| aclnn调用 | test_aclnn_dense_lightning_indexer_grad_kl_loss | 通过aclnnDenseLightningIndexerGradKLLoss接口方式调用dense_lightning_indexer_grad_kl_loss算子。 |