Mamba-CP

背景

Mamba为解决transformer模型序列长度2次方复杂度提出,成为长序列训练的重要架构,在序列长度大幅增长时,激活值对显存压力大幅增长,仍然急需CP大幅降低超长序列带来的显存压力,当前外部Mamba开源框架CP仍处于空白;

问题

在Mamba的SSM递归运算步骤中存在时间依赖关系,传统CP必须等上一CP rank运算完毕将结果传递到下一CP rank方可执行下一步运算,引入空闲等待,设计了一种并行Mamba-CP方案使得所有rank可以并发执行状态传递计算,相对传统CP性能大幅提升;

传统CP见Mamba-2 paper Figure 5

解决方案

针对存在时间依赖关系的状态传递部分,对各个CP rank中local_decay及local_state进行AllGather,使所有CP rank可以并发执行状态传递计算,同时还对前向的AllGather和反向的ReduceScatter进行了计算通信掩盖;

使用场景

  1. 与TP和SP正交,可以在开启TP的基础上进一步开启CP降低显存;
  2. TP有n_groups整除限制,CP无限制;
  3. CP在显存不足场景,开启CP相比开启重计算降低显存方式性能更优;

使用方法

重要参数 参数说明
--context-parallel-algo mamba_cp_algo 长序列并行算法选项,默认项为ulysses_cp_algo,当设置为mamba_cp_algo时开启Mamba-CP。
--context-parallel-size [int] 开启CP对应的数量,默认为1,根据用户需求配置。

使用效果

节省显存方式:重计算、额外开CP等;重计算为常见方式,但会引入额外30%耗时,在长序列、显存受限场景通常寻求如何去组合特性,在不超出硬件显存的前提下尽可能提升性能,在相同显存占用基础上,对开启CP节省显存和开启重计算方式节省显存进行了性能对比如下:

CP开启前后显存优化及性能变化:

序列长度 并行配置 显存占用 内存优化 性能 性能变化
32K TP4CP1 56129MB —— 3761.1ms ——
32K TP4CP2 32613MB 42% 3862.3ms -2.69%

CP与重计算缩减显存+性能对比:

序列长度 并行配置 显存占用 性能 加速比例
32K TP4CP1 + 全重计算 同等显存30G 4728.8ms ——
32K TP4CP2 同等显存30G 3862.3ms +22.43%

注意事项:

  1. 在面临Mamba-CP场景需要省显存情况下,优先开启CP,然后再开重计算;