| feat: flash attn refactor, support BSND, BNSD, 1TND, 1NTD layout to improve performance
Co-authored-by: liyingxuan<liyingxuan3@huawei.com>
# message auto-generated for no-merge-commit merge:
!2577 merge flash_attn_refactor_v2 into master
feat: flash attn refactor, support BSND, BNSD, 1TND, 1NTD layout to improve performance
Created-by: liyx616
Commit-by: liyingxuan
Merged-by: ascend-robot
Description: 原PR:(检视意见均已处理)
https://gitcode.com/Ascend/MindSpeed-MM/pull/2466
UT的diff较多,所以重新提交了一个pr,删除ut中对NON_MEGATRON环境变量的设置,目前是发现所有的UT会共用一个调用栈,如果设置该环境变量,会导致MEGATRON后端的UT进入mindspeed_mm/fsdp/ops的patch中,对看护的范围有影响,所以删除该设置。
## What this PR does / why we need it?
1. 重构了一下 flash attn公共接口的代码,这个接口当时是patch了transformers库中的attention_interface接口,这个接口原来的输入是1NTD,输出是1TND,在NPU的分支中,FA都是用tnd计算的,性能较差,而且为了保证接口的一致性,会插入一些额外的转置,这里进行了重构
2. 修正apply_transformers_attention_patch中的笔误
## Does this PR introduce any user-facing change?
对使用者来说没有影响,不直接暴露这个接口的用法,
对开发者来说,如果是直接迁移transformers的代码,input_layout字段会默认设置为1NTD和transformers保持一致
如果在某些场景下,追求卓越的性能,可以修改输入的qkv的layout,并将转置之后输入attention_interface接口的layout传入input_layout字段
具体的内部接口变更说明如下:
- 【新增】input_layout(默认为1NTD,和transformers默认的保持一致):输入的layout,可以在外部的rope或者其他模块中,少做一次transpose,传入不同的layout,目前是支持BNSD,BSND,1TND和1NTD四种,其他的layout会直接raise error
- 【等效】cu_seq_lens_q & cu_seq_lens_k:packing形式下需要传入不同的序列是如何打包的,这个其实本来就已经是transformers原生接口的入参了,只是被放在kwargs里面了,这里将其展开,因为我会根据cu_seq_lens_q和kv的长度判断,是否可以退化为bnsd or bsnd
- 【变更】ring_fa_layout变更为ring_in_bnsd(默认为False):因为mindspeed core提供的do_ring_attention的接口输入的要求是sbh或者tnd,如果要用bnsd计算的话,输入也得是sbh,所以这里其实很容易引起歧义,同时因为新增了输入的layout,如果支持所有的ring_fa_layout将会有很多的分支需要写,但是其实大部分没有必要支持,比如说输入是bnsd,那就没必要适配tnd的ring了,如果输入时tnd,那也没必要适配sbh和bnsd的ring了。所以我在代码中会自动根据输入的layout判断使用哪个ring layout。ring_in_bnsd表示在do_ring_attention内部,是否需要使用bnsd来计算fa(增加转置,但是长序列下性能会更好)
映射关系如下:(注意,仅在npu的flash_attention_2分支下生效,gpu和sdpa在不支持的操作下,会显式raise error)
| input_layout<br>(默认为1NTD) | cu_seqlens | 实际计算layout | CP分支逻辑 | 输出 |
| :--- | :--- | :--- | :--- | :--- |
| 1NTD | None <br> or len(cu_seqlens) <=2 <br> (ring会结合seq_split_seqlens是否为none判断) | BNSD | ulysses:bnsd<br>ring / hybrid:转置为sbh输入do_ring_attention接口,内部是否用bnsd算看ring_fa_layout的参数,如果为True,则用bnsd,如果为false则为默认的sbh,注意输出转置为bsnd | 转置为bsnd输出 |
| 1NTD | len(cu_seqlens) >=2 <br> (ring会结合seq_split_seqlens是否为none判断) | TND(所以需要先转置+squeeze成TND) | Ulysses:tnd<br>ring/hybrid:TND | unsqueeze为bsnd输出 |
| 1TND | None <br> or len(cu_seqlens) <=2 <br> (ring会结合seq_split_seqlens是否为none判断) | BNSD | ulysses:bnsd<br>ring / hybrid:转置为sbh输入do_ring_attention接口,内部是否用bnsd算看ring_fa_layout的参数,如果为True,则用bnsd,如果为false则为默认的sbh,注意输出转置为bsnd | bsnd |
| 1TND | len(cu_seqlens) >=2 <br> (ring会结合seq_split_seqlens是否为none判断) | TND(需要先squeeze成TND) | Ulysses:tnd<br>ring/hybrid:TND | unsqueeze为bsnd输出 |
| BSND | / | BSND | ulysses:bsnd<br>ring / hybrid: 转置为sbh输入do_ring_attention接口,内部是否用bnsd算看ring_fa_layout的参数,如果为True,则用bnsd,如果为false则为默认的sbh,注意输出转置为bsnd | bsnd |
| BNSD | / | BNSD | ulysses:bnsd<br>ring / hybrid:转置为sbh输入do_ring_attention接口,内部是否用bnsd算看ring_fa_layout的参数,如果为True,则用bnsd,如果为false则为默认的sbh,注意输出转置为bnsd | bnsd |
## How was this patch tested?
已包含UT:ut/fsdp/ops/flash_attn.py
存量UT整改:
删除ut中对NON_MEGATRON环境变量的设置,目前是发现所有的UT会共用一个调用栈,如果设置该环境变量,会导致MEGATRON后端的UT进入mindspeed_mm/fsdp/ops的patch中,对看护的范围有影响,所以删除该设置。
ut/models/common/test_attention.py 删除,当前该UT是复用了其他UT创建的通信组,如果添加新的UT,会改变UT执行的顺序,导致该通信组不可用,pytest卡死,尝试了一些修改均不可用,考虑到这部分代码目前几乎没有人使用了,FSDP2分支的Flashattention看护会更重要,所以先把这个UT下掉了
已长跑验证qwen3vl和qwen3.5的精度
遗留:packing场景 + causal mask + ring attention 当前未跑通,因为存量模型不涉及这个case,因此当前先直接raise error
See merge request: Ascend/MindSpeed-MM!2577 | 5 天前 |