Flash Attention Infer算子性能优化
写在前面
该文档主要说明,如何从算子结构和源码实现出发,对examples/49_ascend950_flash_attention_infer中的FlashAttention Infer(FAI)样例进行性能分析和调优。本文主要从以下几个角度进行FAI样例调优方案分析:
- FAI样例的主要性能瓶颈可能来自哪里。
- 哪些参数是编译期模板参数,哪些参数来自tiling或runtime输入。
- 调整TileShape、CV流水、分核策略、输入布局或Paged Attention参数时,会影响哪些片上资源和执行路径。
1. 样例介绍
该样例面向Flash Attention推理前向算子FlashAttentionInfer,计算流程为:
S = Q * K^T * scale + mask
P = softmax(S)
O = P * V
在源码实现中,该计算被拆成四个主要阶段:
| 阶段 | 执行单元 | 主要工作 | 数据流 |
|---|---|---|---|
| QK | AIC | 计算Q * K^T |
Q/K -> S |
| Online Softmax | AIV | 处理mask、max、sum、exp并生成P | S -> P |
| PV | AIC | 计算P * V |
P/V -> Otmp |
| RescaleO | AIV | 按online softmax状态归一化更新O | Otmp -> O |
FAI计算涉及多个CV操作,优化FAI算子的目标集中于让QK -> Online Softmax -> PV -> RescaleO这条链路在AIC/AIV之间尽量形成稳定流水,并在多核之间尽量避免长尾。
当前样例的几个关键实现点如下:
// examples/49_ascend950_flash_attention_infer/fai_tiling.h
constexpr int32_t BLOCK_BASE_SIZE = 128;
// examples/49_ascend950_flash_attention_infer/fai_kernel.h
using L1TileShape = tla::Shape<_128, _128, _128>;
using L0TileShape = L1TileShape;
// examples/49_ascend950_flash_attention_infer/fai_kernel_utils.h
constexpr uint32_t KERNEL_TASK_NUM = 3;
L1TileShape三个维度分别对应qSeqBase、kvSeqBase和embedBase。当前样例中L1TileShape::K要求与head dimension匹配,默认适用于headDim = 128的场景。
需要特别注意,BLOCK_BASE_SIZE和L1TileShape不是两个完全独立的参数。BLOCK_BASE_SIZE用于tiling阶段统计q/kv基本块和多核权重,L1TileShape用于kernel实际执行的tile形状。只修改其中一侧,可能导致tiling权重模型和kernel实际循环粒度不一致。
2. 优化方案
调整TileShape参数组合
TileShape决定每次QK、Softmax、PV和RescaleO处理的基本粒度。较大的qSeqBase或kvSeqBase通常可以减少外层循环次数、同步次数和小块搬运次数,但会增加UB、L1和L0资源占用,也可能在mask或尾块场景中引入更多无效计算。
调整前至少需要检查以下静态资源项:
MM1_RESULT_SIZE = qSeqBase / CV_RATIO * kvSeqBase * sizeof(ElementS);
MM2_RESULT_SIZE = qSeqBase / CV_RATIO * Align64(embedBase) * sizeof(ElementOTmp);
SHARE_UB_SIZE = CeilDiv(qSeqBase, NUM2) * sizeof(ElementS);
MM2_LEFT_SIZE = qSeqBase * kvSeqBase * sizeof(ElementP);
这些资源项主要影响:
| 资源 | 主要占用 |
|---|---|
| UB | bmm1TensorList[2]、bmm2TensorList[2]、sumUb、expUb、maxUb |
| L1 | mm2AL1TensorList[KERNEL_TASK_NUM] |
| L0A/L0B/L0C | QK/PV的BlockMmad tile |
| 同步事件 | C1/V1、V1/C2、C2/V2之间的流水同步 |
常见调整方向如下:
| 现象 | 调整方向 | 风险 |
|---|---|---|
| q/kv外层循环次数多,同步频繁 | 增大qSeqBase或kvSeqBase |
UB/L1/L0占用上升 |
| AIV Softmax或RescaleO频繁启动 | 优先增大qSeqBase |
q尾块浪费可能增加 |
| QK/PV小块matmul过多 | 优先增大kvSeqBase |
causal或稀疏mask下无效kv块可能增加 |
| causal mask尾块浪费明显 | 谨慎增大kvSeqBase,必要时减小kv tile |
循环和同步次数会上升 |
headDim != 128 |
同步修改L1TileShape::K并检查静态buffer |
BlockMmad和中间buffer约束可能不满足 |
如果只是把TileShape调大而不计算资源预算,优化很容易变成随机试错。更合理的做法是先根据目标shape列出候选qSeqBase/kvSeqBase/embedBase,确认UB/L1/L0资源可容纳,再做性能验证。
优化AIC/AIV协同流水
FAI中,AIC负责QK/PV矩阵乘,AIV负责Online Softmax和RescaleO。虽然同一个kvSeq块内部存在阶段依赖,但相邻kvSeq块之间可以形成流水重叠。理想steady-state形态为:
task t: AIC QK
task t+1: AIV OnlineSoftmax
task t+2: AIC PV
task t+3: AIV RescaleO
源码中与流水相关的关键对象包括:
| 对象 | 作用 |
|---|---|
taskId |
流水任务编号 |
runInfo[4] |
保存不同流水阶段的执行状态 |
bmm1TensorList[2] |
QK输出UB双缓冲 |
bmm2TensorList[2] |
PV输出UB双缓冲 |
mm2AL1TensorList[KERNEL_TASK_NUM] |
Softmax结果P的L1缓存 |
sumUb/maxUb/expUb[KERNEL_TASK_NUM] |
Online Softmax状态缓存 |
同步路径如下:
| 同步flag | 方向 | 含义 |
|---|---|---|
SYNC_C1_V1_FLAG |
AIC -> AIV | QK结果就绪,AIV可执行Softmax |
SYNC_V1_C2_FLAG |
AIV -> AIC | P已搬到L1,AIC可执行PV |
SYNC_C2_V2_FLAG |
AIC -> AIV | PV结果就绪,AIV可更新O |
样例中的关键数据通路是AIV将Softmax结果P从UB直接搬到AIC侧L1:
using CopyUbToL1P = Tile::CopyUb2L1Tla<ArchTag, decltype(vf1OutUb), TensorDst>;
该路径避免了UB -> GM -> L1中转,降低GM带宽压力,也减少PV启动前的等待。
- 如果AIC经常等待AIV,通常说明Vector侧成为瓶颈:应优先检查Softmax、mask读取、RescaleO、小
qSeqBase导致的频繁AIV调度,以及KERNEL_TASK_NUM是否不足以覆盖Vector延迟。 - 如果AIV经常等待AIC,通常说明Cube或搬运成为瓶颈:应优先检查QK/PV tile粒度、K/V布局连续性、MTE2搬运和Paged Attention跳读。
增加KERNEL_TASK_NUM可能提升流水覆盖能力,但它也会增加L1中的mm2AL1TensorList占用、UB中的Softmax状态占用、同步事件压力以及末尾flush阶段复杂度。因此应先确认TileShape和多核切分合理,再考虑调整流水深度。
优化多核负载均衡
Causal或稀疏mask场景下,不同qSeq块可见的kvSeq块数量不同。如果简单按qSeq行数平均分核,早期和后期q块的有效计算量可能差异很大,长尾核会决定整体耗时。
FAI在tiling阶段按有效block数做贪心切分,核心逻辑位于examples/49_ascend950_flash_attention_infer/fai_tiling.h:
| 函数或字段 | 作用 |
|---|---|
GetCalcBlockNumsOneHead |
统计单head下mask后的有效qSeq x kvSeq块数量 |
ComputeSplitNBSeq |
沿Batch/Head/qSeq三轴做贪心切分 |
bnAxisStartIdx |
记录每个AI Core的起始batch/head |
sparseStartIdx |
记录每个AI Core的起始qSeq block |
可以用以下抽象权重理解分核策略:
core_weight[i] = sum(valid_kv_blocks for qSeq blocks assigned to core i)
target_weight = total_valid_blocks / blockDim
imbalance = max(core_weight) / average(core_weight)
如果核间有效block数差距很大,应优先检查mask模式、actualSeqLengths、actualSeqLengthsKV和blockDim。如果有效任务数本身很少,使用过多AI Core只会增加同步和调度开销。Paged Attention场景下还要注意,block table跳读成本没有完全体现在有效block数量中;即使各核有效block数接近,也可能因为KV cache物理布局不同而出现耗时差异。
贪心分核也有局限:当单个qSeq block的权重大于目标core weight时,它无法继续拆分该block;当batch很小、qSeq很短或mask裁剪后有效块很少时,也无法填满所有core。这类场景下应考虑减少blockDim、调整qSeqBase,或重新设计更细粒度的切分策略。
优化输入布局、MTE2和FixPipe
当性能受搬运影响时,首先要确认业务侧输入物理布局与kernel TLA布局一致。当前样例中,Q/O按[batch * qSeqlen, qHeads * embed]方向连续,V按[batch * kvSeqlen, kvHeads * embed]方向连续,K使用ColumnMajor以适配Q * K^T访问。
MTE2相关问题通常来自:
- 上层实际layout与kernel TLA layout不一致。
- K/V在QK或PV访问路径上跨大stride。
- GQA场景中
qHeads / kvHeads分组不连续。 - mask张量不是按
[batch * qSeqlen, kvSeqlen]连续存放。 - Paged Attention下block table导致KV cache频繁跳读。
如果业务侧无法改变输入物理布局,可以考虑新增专用LayoutTag或TileCopy适配。但这类修改会同时影响QK、PV和mask访问路径,必须重新验证正确性和性能。
FixPipe相关问题通常出现在O写回路径,尤其是小embed、尾块多、O地址不对齐或workspace分段不对齐时。建议保证Q/K/V/O首地址尽量按512B对齐,连续维stride尽量满足搬运对齐要求。若后续引入GM workspace缓存中间结果,分配多个workspace分段时应使用类似RoundUp(offset, 512)的方式保证每段起始地址对齐。
优化Paged Attention路径
启用PAGED_CACHE_FLAG后,kernel会按blockSize对kvSeqlen向上取整,并通过blockTables间接访问KV cache:
kvSeqlen = RoundUp(kvSeqlen, blockSize);
blockSize过小会减少尾块浪费,但会增加block table访问次数和KV跳读概率;blockSize过大有利于page内连续访问,但可能增加尾块无效计算。优化Paged Attention时,应优先保证KV cache page内部连续、page起始地址对齐、block table顺序符合kernel访问顺序,并确保actualSeqLengthsKV准确参与分核。
如果只有Paged Attention场景出现MTE2偏高或核间耗时差异,优先检查block table和KV cache物理布局,而不是直接缩小TileShape。TileShape只能改变单次计算粒度,不能消除由page组织方式造成的随机跳读。
3. 总结
融合算子场景可以参考此案例优化