FusedFloydAttention
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | × |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
功能说明
-
算子功能:训练场景下,使用FloydAttention算法实现多维自注意力的计算。
-
计算公式:
注意力的正向计算公式如下:
weights=Softmax(attenMask+scale∗(einsum(query,key1T)+einsum(query,key2T)))weights = Softmax(attenMask + scale*(einsum(query, key1^T) + einsum(query, key2^T)))
attention_out=einsum(weights,value1)+einsum(weights,value2)attention\_out = einsum(weights, value1) + einsum(weights, value2)
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| query | 输入 | 公式中的输入query。 | BFLOAT16、FLOAT16 | ND |
| key1 | 输入 | 公式中的输入key1。 | BFLOAT16、FLOAT16 | ND |
| value1 | 输入 | 公式中的输入value1。 | BFLOAT16、FLOAT16 | ND |
| attenMaskOptional | 可选输入 | 公式中的atten_mask,表示注意力掩码,取值为1代表该位不参与计算(不生效),为0代表该位参与计算。 | BOOL、UINT8 | ND |
| scaleValue | 可选属性 |
|
DOUBLE | - |
| softmaxMaxOut | 输出 | Softmax计算的Max中间结果,用于反向计算。 | FLOAT | ND |
| softmaxSumOut | 输出 | Softmax计算的Sum中间结果,用于反向计算。 | FLOAT | ND |
| attentionOut | 输出 | 公式中的attention_out。 | BFLOAT16、FLOAT16 | ND |
约束说明
-
该接口与PyTorch配合使用时,需要保证CANN相关包与PyTorch相关包的版本匹配
-
关于数据shape的约束,其中:
- B:取值范围为1~2K。
- H:取值范围为1~256。
- N:取值范围为16~1M且N%16==0。
- M:取值范围为128~1M且M%128==0。
- K:取值范围为128~1M且K%128==0。
- D:取值范围为32/64/128。
-
query与key1的第0/2/4轴需相同。
-
key1与value1 shape需相同。
-
key2与value2 shape需相同。
-
softmaxMax与softmaxSum shape需相同。
-
D只支持32/64/128。
调用说明
| 调用方式 | 调用样例 | 说明 |
|---|---|---|
| aclnn调用 | test_aclnn_fused_floyd_attention | 通过aclnnFusedFloydAttention接口方式调用FusedFloydAttention算子。 |