文件最后提交记录最后更新时间
feat: flash attn refactor, support BSND, BNSD, 1TND, 1NTD layout to improve performance Co-authored-by: liyingxuan<liyingxuan3@huawei.com> # message auto-generated for no-merge-commit merge: !2577 merge flash_attn_refactor_v2 into master feat: flash attn refactor, support BSND, BNSD, 1TND, 1NTD layout to improve performance Created-by: liyx616 Commit-by: liyingxuan Merged-by: ascend-robot Description: 原PR:(检视意见均已处理) https://gitcode.com/Ascend/MindSpeed-MM/pull/2466 UT的diff较多,所以重新提交了一个pr,删除ut中对NON_MEGATRON环境变量的设置,目前是发现所有的UT会共用一个调用栈,如果设置该环境变量,会导致MEGATRON后端的UT进入mindspeed_mm/fsdp/ops的patch中,对看护的范围有影响,所以删除该设置。 ## What this PR does / why we need it? 1. 重构了一下 flash attn公共接口的代码,这个接口当时是patch了transformers库中的attention_interface接口,这个接口原来的输入是1NTD,输出是1TND,在NPU的分支中,FA都是用tnd计算的,性能较差,而且为了保证接口的一致性,会插入一些额外的转置,这里进行了重构 2. 修正apply_transformers_attention_patch中的笔误 ## Does this PR introduce any user-facing change? 对使用者来说没有影响,不直接暴露这个接口的用法, 对开发者来说,如果是直接迁移transformers的代码,input_layout字段会默认设置为1NTD和transformers保持一致 如果在某些场景下,追求卓越的性能,可以修改输入的qkv的layout,并将转置之后输入attention_interface接口的layout传入input_layout字段 具体的内部接口变更说明如下: - 【新增】input_layout(默认为1NTD,和transformers默认的保持一致):输入的layout,可以在外部的rope或者其他模块中,少做一次transpose,传入不同的layout,目前是支持BNSD,BSND,1TND和1NTD四种,其他的layout会直接raise error - 【等效】cu_seq_lens_q & cu_seq_lens_k:packing形式下需要传入不同的序列是如何打包的,这个其实本来就已经是transformers原生接口的入参了,只是被放在kwargs里面了,这里将其展开,因为我会根据cu_seq_lens_q和kv的长度判断,是否可以退化为bnsd or bsnd - 【变更】ring_fa_layout变更为ring_in_bnsd(默认为False):因为mindspeed core提供的do_ring_attention的接口输入的要求是sbh或者tnd,如果要用bnsd计算的话,输入也得是sbh,所以这里其实很容易引起歧义,同时因为新增了输入的layout,如果支持所有的ring_fa_layout将会有很多的分支需要写,但是其实大部分没有必要支持,比如说输入是bnsd,那就没必要适配tnd的ring了,如果输入时tnd,那也没必要适配sbh和bnsd的ring了。所以我在代码中会自动根据输入的layout判断使用哪个ring layout。ring_in_bnsd表示在do_ring_attention内部,是否需要使用bnsd来计算fa(增加转置,但是长序列下性能会更好) 映射关系如下:(注意,仅在npu的flash_attention_2分支下生效,gpu和sdpa在不支持的操作下,会显式raise error) | input_layout<br>(默认为1NTD) | cu_seqlens | 实际计算layout | CP分支逻辑 | 输出 | | :--- | :--- | :--- | :--- | :--- | | 1NTD | None <br> or len(cu_seqlens) <=2 <br> (ring会结合seq_split_seqlens是否为none判断) | BNSD | ulysses:bnsd<br>ring / hybrid:转置为sbh输入do_ring_attention接口,内部是否用bnsd算看ring_fa_layout的参数,如果为True,则用bnsd,如果为false则为默认的sbh,注意输出转置为bsnd | 转置为bsnd输出 | | 1NTD | len(cu_seqlens) >=2 <br> (ring会结合seq_split_seqlens是否为none判断) | TND(所以需要先转置+squeeze成TND) | Ulysses:tnd<br>ring/hybrid:TND | unsqueeze为bsnd输出 | | 1TND | None <br> or len(cu_seqlens) <=2 <br> (ring会结合seq_split_seqlens是否为none判断) | BNSD | ulysses:bnsd<br>ring / hybrid:转置为sbh输入do_ring_attention接口,内部是否用bnsd算看ring_fa_layout的参数,如果为True,则用bnsd,如果为false则为默认的sbh,注意输出转置为bsnd | bsnd | | 1TND | len(cu_seqlens) >=2 <br> (ring会结合seq_split_seqlens是否为none判断) | TND(需要先squeeze成TND) | Ulysses:tnd<br>ring/hybrid:TND | unsqueeze为bsnd输出 | | BSND | / | BSND | ulysses:bsnd<br>ring / hybrid: 转置为sbh输入do_ring_attention接口,内部是否用bnsd算看ring_fa_layout的参数,如果为True,则用bnsd,如果为false则为默认的sbh,注意输出转置为bsnd | bsnd | | BNSD | / | BNSD | ulysses:bnsd<br>ring / hybrid:转置为sbh输入do_ring_attention接口,内部是否用bnsd算看ring_fa_layout的参数,如果为True,则用bnsd,如果为false则为默认的sbh,注意输出转置为bnsd | bnsd | ## How was this patch tested? 已包含UT:ut/fsdp/ops/flash_attn.py 存量UT整改: 删除ut中对NON_MEGATRON环境变量的设置,目前是发现所有的UT会共用一个调用栈,如果设置该环境变量,会导致MEGATRON后端的UT进入mindspeed_mm/fsdp/ops的patch中,对看护的范围有影响,所以删除该设置。 ut/models/common/test_attention.py 删除,当前该UT是复用了其他UT创建的通信组,如果添加新的UT,会改变UT执行的顺序,导致该通信组不可用,pytest卡死,尝试了一些修改均不可用,考虑到这部分代码目前几乎没有人使用了,FSDP2分支的Flashattention看护会更重要,所以先把这个UT下掉了 已长跑验证qwen3vl和qwen3.5的精度 遗留:packing场景 + causal mask + ring attention 当前未跑通,因为存量模型不涉及这个case,因此当前先直接raise error See merge request: Ascend/MindSpeed-MM!25775 天前