文件最后提交记录最后更新时间
feat: torch fully_shard patch, optimize interaction between fully_shard and checkpoint_wrapper Co-authored-by: liyingxuan<liyingxuan3@huawei.com> # message auto-generated for no-merge-commit merge: !2290 merge master into master feat: torch fully_shard patch, optimize interaction between fully_shard and checkpoint_wrapper Created-by: liyx616 Commit-by: liyingxuan Merged-by: ascend-robot Description: 之前关闭的pr(冲突太多):https://gitcode.com/Ascend/MindSpeed-MM/pull/2289,已按照老pr中的检视意见修改 ## What this PR does / why we need it? torch原生的fully_shard如果被包在重计算内部的时候,会在反向重计算的时候触发 pre_forward和pre_backward,重复unshard参数,并且没有及时释放,会导致性能和显存的劣化 通过給torch原生的fully_shard打patch,对这个缺陷进行修复。 修复的设计思路如下: 1. fully_shard添加hook_module入参,该入参表示fsdp2 做的那些pre_forward, post_forward, pre_backward, post_backward这些hook,都添加到这个hook_module上,这个hook_module一般重计算是设哪个,就设哪个,确保fully_shard的hook可以添加在重计算的外面 2. 原来全局只有一个comm_ctx, 导致所有的allgather, reduce scatter都要等上一个子模块的copyout和chunkcat执行完成,但是copyout和chunkcat是在计算流上执行的,容易导致通信被阻塞。当前全局设置了多个comm_ctx(id会自动推导,无需自己设定),同属于一个hook_module的参数allgather和reduce scatter无需相互等待。观测到hook_module变化时,会等待上一个hook_module所有comm_ctx中的事件结束 ## Does this PR introduce any user-facing change? 修改了torch原生的fully_shard使用方法。 场景1:不需要对layer内部进行细粒度切分,或者不开启重计算,按照原来的写法即可 ```python for layer in model.layers: fully_shard(layer, **fsdp_kwargs) fully_shard(model, **fsdp_kwargs) ``` 场景2:需要对layer内部进行细粒度切分,并且需要使用重计算 ```python for i, layer in enumerate(model.layers): model.layers[i] = checkpoint_wrapper(layer) # 两个子模块可以有不同的devicemesh,或者其他fsdp kwargs,支持灵活配置 fully_shard(layer.attn, hook_module=layer, **fsdp_kwargs1) fully_shard(layer.mlp, hook_module=layer, **fsdp_kwargs2) fully_shard(layer, hook_module=layer, **fsdp_kwargs) fully_shard(model, **fsdp_kwargs) # 如果hook_module不设置的话,默认就是传入的第一个nn.Module ``` ## How was this patch tested? Please explain how to verify the correctness and effectiveness of this feature, as well as its usage constraints and limitations. See merge request: Ascend/MindSpeed-MM!22902 个月前