DenseLightningIndexerSoftmaxLse
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品 | √ |
| Atlas A2 训练系列产品 | √ |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
功能说明
-
算子功能:DenseLightningIndexerSoftmaxLse算子是DenseLightningIndexerGradKlLoss算子计算Softmax输入的一个分支算子。
-
计算公式:
res=AttentionMask(ReduceSum(W⊙ReLU(Qindex@KindexT)))\text{res}=\text{AttentionMask}\left(\text{ReduceSum}\left(W\odot\text{ReLU}\left(Q_{index}@K_{index}^T\right)\right)\right)
maxIndex=max(res)\text{maxIndex}=\text{max}\left(res\right)
sumIndex=ReduceSum(exp(res−maxIndex))\text{sumIndex}=\text{ReduceSum}\left(\text{exp}\left(res-maxIndex\right)\right)
maxIndex,sumIndex作为输出传递给算子DenseLightningIndexerGradKlLoss作为输入计算Softmax使用。
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| queryIndex | 输入 | lightningIndexer结构的输入queryIndex。 | FLOAT16、BFLOAT16 | ND |
| keyIndex | 输入 | lightningIndexer结构的输入keyIndex。 | FLOAT16、BFLOAT16 | ND |
| weights | 输入 | 权重。 | FLOAT16、BFLOAT16、FLOAT32 | ND |
| actualSeqLengthsQuery | 输入 | 每个Batch中,Query的有效token数。 | INT64 | ND |
| actualSeqLengthsKey | 输入 | 每个Batch中,Key的有效token数。 | INT64 | ND |
| layout | 输入 | layout格式。 | - | - |
| sparseMode | 输入 | sparse的模式。 | INT64 | - |
| preTokens | 输入 | 用于稀疏计算,表示Attention需要和前几个token计算关联。 | INT64 | - |
| nextTokens | 输入 | 用于稀疏计算,表示Attention需要和后几个token计算关联。 | INT64 | - |
| softmaxMaxOut | 输出 | softmax计算使用的max值。 | FLOAT32 | ND |
| softmaxSumOut | 输出 | softmax计算使用的sum值。 | FLOAT32 | ND |
约束说明
无
调用说明
| 调用方式 | 调用样例 | 说明 |
|---|---|---|
| aclnn调用 | test_aclnn_dense_lightning_indexer_softmax_lse | 通过aclnnDenseLightningIndexerSoftmaxLse接口方式调用dense_lightning_indexer_softmax_lse算子。 |