SparseLightningIndexerGradKLLoss
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
算子功能
-
算子功能:SparselightningIndexerGradKlLoss算子是LightningIndexer的反向算子,再额外融合了Loss计算功能。LightningIndexer算子将QueryToken和KeyToken之间的最高内在联系的TopK个筛选出来,存放在SparseIndices中,从而减少长序列场景下Attention的计算量,加速长序列的网络的推理和训练的性能。
-
计算公式: 用于取Top-k的value的计算公式可以表示为:
It,:=Wt,:@ReLU(qt,:@(K:t,:)T)I_{t,:}=W_{t,:}@ReLU(q_{t,:}@(K_{:t,:})^T)
其中,WW是第tt个token对应的weights,qq是第tt个token对应的GG个query头合轴后的矩阵,KK为tt行KK矩阵。
LightningIndexer会单独训练,对应的loss function为:
L(I)=∑tDKL(pt,:∣∣Softmax(It,:))L(I){=}\sum_tD_{KL}(p_{t,:}||Softmax(I_{t,:}))
其中,pp是target distribution,通过对main attention score 进行所有的head的求和,然后把求和结果沿着上下文方向进行L1正则化得到。DKLD_{KL}为KL散度,其表达式为:
DKL(a∣∣b)=∑iailog(aibi)D_{KL}(a||b){=}\sum_ia_i\mathrm{log}{\left(\frac{a_i}{b_i}\right)}
通过求导可得Loss的梯度表达式:
dIt,:=Softmax(It,:)−pt,:dI\mathop{{}}\nolimits_{{t,:}}=Softmax \left( I\mathop{{}}\nolimits_{{t,:}} \left) -p\mathop{{}}\nolimits_{{t,:}}\right. \right.
利用链式法则可以进行weights,query和key矩阵的梯度计算:
dWt,:=dIt,:@(ReLU(St,:))TdW\mathop{{}}\nolimits_{{t,:}}=dI\mathop{{}}\nolimits_{{t,:}}\text{@} \left( ReLU \left( S\mathop{{}}\nolimits_{{t,:}} \left) \left) \mathop{{}}\nolimits^{{T}}\right. \right. \right. \right.
dqt,:=dSt,:@K:t,:d\mathop{{q}}\nolimits_{{t,:}}=dS\mathop{{}}\nolimits_{{t,:}}@K\mathop{{}}\nolimits_{{:t,:}}
dK:t,:=(dSt,:)T@q:t,:dK\mathop{{}}\nolimits_{{:t,:}}= \left( dS\mathop{{}}\nolimits_{{t,:}} \left) \mathop{{}}\nolimits^{{T}}@q\mathop{{}}\nolimits_{{:t,:}}\right. \right.
其中,S为QK矩阵softmax的结果。
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| query | 输入 | attention结构的输入Q。 | FLOAT16、BFLOAT16 | ND |
| key | 输入 | attention结构的输入K。 | FLOAT16、BFLOAT16 | ND |
| queryIndex | 输入 | lightingIndexer结构的输入queryIndex。 | FLOAT16、BFLOAT16 | ND |
| keyIndex | 输入 | lightingIndexer结构的输入keyIndex。 | FLOAT16、BFLOAT16 | ND |
| weights | 输入 | 权重。 | FLOAT16、BFLOAT16、FLOAT32 | ND |
| sparseIndices | 输入 | topk_index,用来选择每个query对应的key和value。 | INT32 | ND |
| softmaxMax | 输入 | 注意力正向计算的中间输出。 | FLOAT32 | ND |
| softmaxSum | 输入 | 注意力正向计算的中间输出。 | FLOAT32 | ND |
| queryRope | 输入 | MLA rope部分:Query位置编码的输出。 | FLOAT16、BFLOAT16 | ND |
| keyRope | 输入 | MLA rope部分:Key位置编码的输出。 | FLOAT16、BFLOAT16 | ND |
| actualSeqLengthsQuery | 输入 | 每个Batch中,Query的有效token数。 | INT64 | ND |
| actualSeqLengthsKey | 输入 | 每个Batch中,Key的有效token数。 | INT64 | ND |
| scaleValue | 输入 | 缩放系数。 | DOUBLE | - |
| layout | 输入 | layout格式。 | STRING | - |
| sparseMode | 输入 | sparse的模式。 | INT64 | - |
| deterministic | 输入 | 确定性计算。 | BOOL | - |
| dQueryIndex | 输出 | QueryIndex的梯度。 | FLOAT16、BFLOAT16 | ND |
| dKeyIndex | 输出 | KeyIndex的梯度。 | FLOAT16、BFLOAT16 | ND |
| dWeights | 输出 | Weights的梯度。 | FLOAT16、BFLOAT16、FLOAT32 | ND |
| loss | 输出 | 损失函数值。 | FLOAT32 | ND |
- Atlas A2 训练系列产品/Atlas A2 推理系列产品、Atlas A3 训练系列产品/Atlas A3 推理系列产品:
- T1支持大于等于actualSeqLengthsQuery的累加和,T2支持大于等于actualSeqLengthsKey的累加和。
约束说明
-
确定性计算:
- SparseLightningIndexerGradKLLoss默认非确定性实现,不支持通过aclrtCtxSetSysParamOpt开启确定性。
-
公共约束
- 参数query、key、queryIndex、keyIndex的数据类型应保持一致。
- 参数weights不为float32时,参数query、key、queryIndex、keyIndex、weights的数据类型应保持一致。
- 入参为空的场景处理:
- query为空Tensor:直接返回。
- 公共约束里入参为空的场景和FAG保持一致。
sparseMode 含义 备注 0 defaultMask模式,如果attenmask未传入则不做mask操作,忽略preTokens和nextTokens;如果传入,则需要传入完整的attenmask矩阵,表示preTokens和nextTokens之间的部分需要计算 不支持 1 allMask,必须传入完整的attenmask矩阵 不支持 2 leftUpCausal模式的mask,需要传入优化后的attenmask矩阵 不支持 3 rightDownCausal模式的mask,对应以右顶点为划分的下三角场景,需要传入优化后的attenmask矩阵 支持 4 band模式的mask,需要传入优化后的attenmask矩阵 不支持 5 prefix 不支持 6 global 不支持 7 dilated 不支持 8 block_local 不支持 -
规格约束
规格项 规格 规格说明 deterministic bool Atlas A2 训练系列产品/Atlas A2 推理系列产品、Atlas A3 训练系列产品/Atlas A3 推理系列产品支持确定性计算
Ascend 950PR/Ascend 950DT支持确定性计算B Atlas A2 训练系列产品/Atlas A2 推理系列产品、Atlas A3 训练系列产品/Atlas A3 推理系列产品支持1~256
Ascend 950PR/Ascend 950DT支持1~128- S1、S2 S1支持1~8K,S2支持1~128K S1、S2支持不等长 N1 32、64、128 SparseFA为MQA Nidx1 Atlas A2 训练系列产品/Atlas A2 推理系列产品、Atlas A3 训练系列产品/Atlas A3 推理系列产品支持8、16、32、64
Ascend 950PR/Ascend 950DT支持32、64SparseFA为MQA N2 1 SparseFA为MQA,Nidx2=1 Nidx2 1 SparseFA为MQA,N2=1 D 512 query与query_index的D不同 Drope 64 - K 1024、2048、3072、4096、5120、6144、7168、8192 - layout BSND/TND - Ascend 950PR/Ascend 950DT:B仅支持1~128,N1额外支持48,Nidx1额外支持24。
-
典型值
规格项 典型值 query N1=128/64; D =512 queryIndex Atlas A2 训练系列产品/Atlas A2 推理系列产品、Atlas A3 训练系列产品/Atlas A3 推理系列产品支持N1 = 64/32/16/8; D = 128 ; S1 = 64k/128k
Ascend 950PR/Ascend 950DT支持N1 = 64/32;D = 128 ; S1 = 64k/128kkeyIndex D = 128 topk topk = 1024/2048/3072/4096/5120/6144/7168/8192 qRope d= 64
调用说明
| 调用方式 | 调用样例 | 说明 |
|---|---|---|
| aclnn调用 | test_aclnn_sparse_lightning_indexer_grad_kl_loss | 通过aclnnSparseLightningIndexerGradKLLoss接口方式调用SparseLightningIndexerGradKLLoss算子。 |