Mamba-CP
背景
Mamba为解决transformer模型序列长度2次方复杂度提出,成为长序列训练的重要架构,在序列长度大幅增长时,激活值对显存压力大幅增长,仍然急需CP大幅降低超长序列带来的显存压力,当前外部Mamba开源框架CP仍处于空白;
问题
在Mamba的SSM递归运算步骤中存在时间依赖关系,传统CP必须等上一CP rank运算完毕将结果传递到下一CP rank方可执行下一步运算,引入空闲等待,设计了一种并行Mamba-CP方案使得所有rank可以并发执行状态传递计算,相对传统CP性能大幅提升;
传统CP见Mamba-2 paper Figure 5
解决方案
针对存在时间依赖关系的状态传递部分,对各个CP rank中local_decay及local_state进行AllGather,使所有CP rank可以并发执行状态传递计算,同时还对前向的AllGather和反向的ReduceScatter进行了计算通信掩盖;
使用场景
- 与TP和SP正交,可以在开启TP的基础上进一步开启CP降低显存;
- TP有n_groups整除限制,CP无限制;
- CP在显存不足场景,开启CP相比开启重计算降低显存方式性能更优;
使用方法
| 重要参数 | 参数说明 |
|---|---|
| --context-parallel-algo mamba_cp_algo | 长序列并行算法选项,默认项为ulysses_cp_algo,当设置为mamba_cp_algo时开启Mamba-CP。 |
| --context-parallel-size [int] | 开启CP对应的数量,默认为1,根据用户需求配置。 |
使用效果
节省显存方式:重计算、额外开CP等;重计算为常见方式,但会引入额外30%耗时,在长序列、显存受限场景通常寻求如何去组合特性,在不超出硬件显存的前提下尽可能提升性能,在相同显存占用基础上,对开启CP节省显存和开启重计算方式节省显存进行了性能对比如下:
CP开启前后显存优化及性能变化:
| 序列长度 | 并行配置 | 显存占用 | 内存优化 | 性能 | 性能变化 |
|---|---|---|---|---|---|
| 32K | TP4CP1 | 56129MB | —— | 3761.1ms | —— |
| 32K | TP4CP2 | 32613MB | 42% | 3862.3ms | -2.69% |
CP与重计算缩减显存+性能对比:
| 序列长度 | 并行配置 | 显存占用 | 性能 | 加速比例 |
|---|---|---|---|---|
| 32K | TP4CP1 + 全重计算 | 同等显存30G | 4728.8ms | —— |
| 32K | TP4CP2 | 同等显存30G | 3862.3ms | +22.43% |
注意事项:
- 在面临Mamba-CP场景需要省显存情况下,优先开启CP,然后再开重计算;