feat: torch fully_shard patch, optimize interaction between fully_shard and checkpoint_wrapper
Co-authored-by: liyingxuan <liyingxuan3@huawei.com>
# message auto-generated for no-merge-commit merge:
!2290 merge master into master
feat: torch fully_shard patch, optimize interaction between fully_shard and checkpoint_wrapper
Created-by: liyx616
Commit-by: liyingxuan
Merged-by: ascend-robot
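Description: Supersedes a previously closed PR (closed because of too many conflicts): https://gitcode.com/Ascend/MindSpeed-MM/pull/2289. The review comments from that PR have been addressed here.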
## What this PR does / why we need it?
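When torch's native `fully_shard` is applied to a module that sits inside a recomputation (activation checkpointing) region, the recomputation in the backward pass triggers `pre_forward` and `pre_backward` a second time. The parameters are unsharded again and are not freed promptly, which degrades both performance and memory usage.
This PR fixes the defect by patching torch's native `fully_shard`.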
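
The root cause can be reproduced with stock PyTorch, independently of FSDP: module hooks registered inside a checkpointed region fire again when that region is recomputed during backward, while hooks on the module that owns the region fire only once. The sketch below is a minimal illustration, assuming `torch.utils.checkpoint.checkpoint` with `use_reentrant=False`; plain forward-pre hooks stand in for FSDP2's `pre_forward`, and the `Block` module and `calls` counter are illustrative names, not part of the patch.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """Illustrative layer whose inner Linear runs inside a recomputation region."""

    def __init__(self) -> None:
        super().__init__()
        self.inner = nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Everything called inside checkpoint() is re-run during backward
        return checkpoint(self.inner, x, use_reentrant=False)


calls = {"inner": 0, "outer": 0}


def counter(name: str):
    # Forward-pre hook standing in for FSDP2's pre_forward (which would unshard params)
    def hook(module, args):
        calls[name] += 1
    return hook


block = Block()
block.inner.register_forward_pre_hook(counter("inner"))  # hook inside the recompute region
block.register_forward_pre_hook(counter("outer"))        # hook on the module owning the region

x = torch.randn(2, 4, requires_grad=True)
block(x).sum().backward()
print(calls)  # {'inner': 2, 'outer': 1}
```

Registering the hooks on the module that owns the recomputation region, as the fix does, corresponds to the `outer` counter: it fires once per forward, so parameters are not unsharded again during recomputation.
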
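The fix is designed as follows:
1. Add a `hook_module` argument to `fully_shard`. All of the hooks that FSDP2 installs (`pre_forward`, `post_forward`, `pre_backward`, `post_backward`) are registered on this `hook_module` instead. In general, set `hook_module` to the module that recomputation is applied to, so that the `fully_shard` hooks sit outside the recomputation region.
2. Previously there was a single global `comm_ctx`, so every all-gather and reduce-scatter had to wait for the previous submodule's copy-out and chunk-cat to finish. Because copy-out and chunk-cat run on the compute stream, communication was easily blocked. There are now multiple global `comm_ctx` instances (their ids are derived automatically; you do not set them yourself), so the all-gathers and reduce-scatters of parameters that belong to the same `hook_module` no longer wait on one another. When the active `hook_module` changes, the patch waits for all events in every `comm_ctx` of the previous `hook_module` to complete.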
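
The scheduling policy in item 2 can be modelled without any device code. The sketch below is a hypothetical, pure-Python model, not the patch's implementation: `CommCtxPool`, `register` and `issue` are invented names, events are plain strings, and "waiting" only appends to a log. It assumes the automatically derived ctx id is the sharded module's position within its `hook_module` group, so the same `comm_ctx` ids are reused across layers; the patch's actual derivation rule may differ.

```python
from collections import defaultdict


class CommCtxPool:
    """Hypothetical model of per-hook_module comm contexts (not the patch's code)."""

    def __init__(self) -> None:
        self.pending = defaultdict(list)  # ctx id -> events still outstanding
        self.ctx_of = {}                  # sharded module name -> ctx id
        self.group = defaultdict(list)    # hook_module name -> ctx ids it uses
        self.active_hook = None
        self.log = []

    def register(self, module: str, hook_module: str) -> int:
        # Assumed rule: ctx id = position of `module` inside its hook_module group
        ctx_id = len(self.group[hook_module])
        self.group[hook_module].append(ctx_id)
        self.ctx_of[module] = ctx_id
        return ctx_id

    def _drain(self, ctx_id: int) -> None:
        # "Wait" on every outstanding event of one ctx
        for event in self.pending[ctx_id]:
            self.log.append(f"wait {event}")
        self.pending[ctx_id].clear()

    def issue(self, module: str, hook_module: str, op: str) -> None:
        if self.active_hook is not None and hook_module != self.active_hook:
            # hook_module changed: wait for all events of the previous group
            for ctx_id in self.group[self.active_hook]:
                self._drain(ctx_id)
        self.active_hook = hook_module
        ctx_id = self.ctx_of[module]
        self._drain(ctx_id)  # serialize only against earlier work in the same ctx
        event = f"{op}({module})"
        self.pending[ctx_id].append(event)
        self.log.append(f"issue {event} on ctx{ctx_id}")


pool = CommCtxPool()
for i in range(2):
    for part in ("attn", "mlp"):
        pool.register(f"layers.{i}.{part}", hook_module=f"layers.{i}")

pool.issue("layers.0.attn", "layers.0", "allgather")
pool.issue("layers.0.mlp", "layers.0", "allgather")   # does not wait on attn
pool.issue("layers.1.attn", "layers.1", "allgather")  # group switch drains layers.0
print("\n".join(pool.log))
```

Under this model, `layers.0.mlp` is issued on `ctx1` without waiting for the all-gather of `layers.0.attn` on `ctx0`, and only the switch to `layers.1` forces both outstanding events to be drained.
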
## Does this PR introduce any user-facing change?
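Yes: this PR changes how torch's native `fully_shard` is used.
Scenario 1: if you do not need fine-grained sharding inside a layer, or recomputation is not enabled, keep the existing usage: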
```python
for layer in model.layers:
    fully_shard(layer, **fsdp_kwargs)
fully_shard(model, **fsdp_kwargs)
```
Scenario 2: you need fine-grained sharding inside a layer together with recomputation:
```python
for i, layer in enumerate(model.layers):
    model.layers[i] = checkpoint_wrapper(layer)
    # The two submodules may use different device meshes or other FSDP kwargs, allowing flexible configuration
    fully_shard(layer.attn, hook_module=layer, **fsdp_kwargs1)
    fully_shard(layer.mlp, hook_module=layer, **fsdp_kwargs2)
    fully_shard(layer, hook_module=layer, **fsdp_kwargs)
fully_shard(model, **fsdp_kwargs)  # If hook_module is not set, it defaults to the first nn.Module passed in
```
## How was this patch tested?
Please explain how to verify the correctness and effectiveness of this feature, as well as its usage constraints and limitations.
See merge request: Ascend/MindSpeed-MM!2290