AttentionUpdate

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能:将各SP域PA算子的输出的中间结果lse,localOut两个局部变量结果更新成全局结果。
  • 计算公式:输入lseilse_iOiO_i、输出OO

lsemax=maxlseilse_{max} = \text{max}lse_i

lse=∑iexp(lsei−lsemax)lse = \sum_i \text{exp}(lse_i - lse_{max})

lsem=lsemax+log(lse)lse_m = lse_{max} + \text{log}(lse)

O=∑iOi⋅exp(lsei−lsem)O = \sum_i O_i \cdot \text{exp}(lse_i - lse_m)

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
lsei 输入 各SP域的局部lse。 FLOAT32 ND
Oi 输入 各SP域的局部attentionout。 FLOAT32,FLOAT16,BFLOAT16 ND
lsem 输出 更新后的全局lse。 FLOAT32 ND
O 输入 更新后的全局attentionout。 FLOAT32,FLOAT16,BFLOAT16 ND

约束说明

  • Atlas A2 训练系列产品/Atlas A2 推理系列产品、Atlas A3 训练系列产品/Atlas A3 推理系列产品:支持FLOAT32、FLOAT16、BFLOAT16的Oi和O。
  • Ascend 950PR/Ascend 950DT:支持FLOAT32、FLOAT16、BFLOAT16的Oi和O,且Oi和O数据类型相同。
  • 序列并行的并行度sp取值范围[1, 16]。
  • headDim取值范围[8, 512]且是8的倍数。
  • 不支持非连续的Tensor。
  • 支持空Tensor。

调用说明

调用方式 样例代码 说明
aclnn接口 test_aclnn_attention_update 通过aclnnAttentionUpdate接口方式调用AttentionUpdate算子。