PromptFlashAttention
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列加速卡产品 | √ |
| Atlas 训练系列产品 | × |
功能说明
-
算子功能:全量推理场景的FlashAttention算子,支持sparse优化、支持actualSeqLengthsKv优化、支持INT8量化功能,支持高精度或者高性能模式选择。
-
计算公式:
self-attention(自注意力)利用输入样本自身的关系构建了一种注意力模型。其原理是假设有一个长度为nn的输入样本序列xx,xx的每个元素都是一个dd维向量,可以将每个dd维向量看作一个token embedding,将这样一条序列经过3个权重矩阵变换得到3个维度为n∗dn*d的矩阵。
self-attention的计算公式一般定义如下,其中QQ、KK、VV为输入样本的重要属性元素,是输入样本经过空间变换得到,且可以统一到一个特征空间中。公式及算子名称中的"Attention"为"self-attention"的简写。
Attention(Q,K,V)=Score(Q,K)VAttention(Q,K,V)=Score(Q,K)V
本算子中Score函数采用Softmax函数,self-attention计算公式为:
Attention(Q,K,V)=Softmax(QKTd)VAttention(Q,K,V)=Softmax(\frac{QK^T}{\sqrt{d}})V
其中:QQ和KTK^T的乘积代表输入xx的注意力,为避免该值变得过大,通常除以dd的开根号进行缩放,并对每行进行softmax归一化,与VV相乘后得到一个n∗dn*d的矩阵。
参数说明
| 参数名 | 输入/输出 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| query | 输入 | 公式中的输入Q。 | FLOAT16、BFLOAT16、INT8 | ND |
| key | 输入 | 公式中的输入K。 | FLOAT16、BFLOAT16、INT8 | ND |
| value | 输入 | 公式中的输入V。 | FLOAT16、BFLOAT16、INT8 | ND |
| attentionOut | 输出 | 公式中的输出。 | FLOAT16、BFLOAT16、INT8 | ND |
- Atlas A2 训练系列产品/Atlas A2 推理系列产品、Ascend 950PR/Ascend 950DT:数据类型支持FLOAT16、BFLOAT16、INT8。
- Atlas 推理系列加速卡产品:仅支持FLOAT16。
约束说明
-
该接口与PyTorch配合使用时,需要保证CANN相关包与PyTorch相关包的版本匹配。
-
入参为空的处理:算子内部需要判断参数query是否为空,如果是空则直接返回。参数query不为空Tensor,参数key、value为空tensor,则attentionOut填充为全零。attentionOut为空Tensor时,AscendCLNN框架会处理。其余在上述参数说明中标注了“可传入nullptr”的入参为空指针时,不进行处理。
-
query,key,value输入,功能使用限制如下:
-
Atlas A2 训练系列产品/Atlas A2 推理系列产品、Ascend 950PR/Ascend 950DT:
-
支持B轴小于等于65536(64k),输入类型包含INT8时D轴非32对齐或输入类型为FLOAT16或BFLOAT16时D轴非16对齐时,B轴仅支持到128。
-
支持N轴小于等于256。
-
S支持小于等于20971520(20M)。部分长序列场景下,如果计算量过大可能会导致pfa算子执行超时(aicore error类型报错,errorStr为:timeout or trap error),此场景下建议做S切分处理,注:这里计算量会受B、S、N、D等的影响,值越大计算量越大。典型的会超时的长序列(即B、S、N、D的乘积较大)场景包括但不限于:
B Q_N Q_S D KV_N KV_S 1 20 2097152 256 1 2097152 1 2 20971520 256 2 20971520 20 1 2097152 256 1 2097152 1 10 2097152 512 1 2097152 -
支持D轴小于等于512。inputLayout为BSH或者BSND时,要求N*D小于65535。
-
-
Atlas A2 训练系列产品/Atlas A2 推理系列产品:在TND场景下query,key,value输入的综合限制:
- T小于等于65536。
- N等于8/16/32/64/128,且Q_N、K_N、V_N相等。
- Q_D、K_D等于192,V_D等于128/192。
- 数据类型仅支持BFLOAT16。
- sparse模式仅支持sparse=0且不传mask,或sparse=3且传入mask。
- 当sparse=3时,要求每个batch单独的actualSeqLengths < actualSeqLengthsKv。
-
Atlas 推理系列加速卡产品:
- 在inputLayout为BSH时,支持B轴小于等于300,其余情况B轴小于等于128;支持N轴小于等于256;支持S轴小于等于65535(64k), Q_S或KV_S非128对齐,Q_S和KV_S不等长的场景不支持配置atten_mask;支持D轴小于等于512。
-
-
当inputLayout为BNSD_BSND时,输入query的shape是BNSD,输出attentionOut的shape为BSND;其余情况attentionOut的shape需要与入参query的shape保持一致。
调用说明
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| aclnn接口 | test_aclnn_PromptFlashAttentionV3 | 通过aclnnPromptFlashAttentionV3调用PromptFlashAttentionV3算子 |