Norm重计算

背景与挑战

大模型训练过程中,往往会面临的显存不足的问题。

解决方案

类似于激活函数重计算,本特性支持了Norm层的重计算,运用激活函数重计算特性中的checkpoint机制,对Norm层进行重计算处理。 具体细节可参见文献Accelerating the Training of Large Language Models using Efficient Activation Rematerialization and Optimal Hybrid Parallelism

使用场景

主要用于训练场景,用户内存不足或要进一步节省内存时。

使用方法

需在训练脚本中加入以下参数配置。

--recompute-norm # 开启Norm重计算 --recompute-norm-num-layers ${num} # num表示Norm重计算的层数

说明

  • Norm重计算特性仅支持mcore分支,不支持legacy分支,即仅支持在开启--use-mcore-models时,通过--recompute-norm使能。

  • Norm重计算兼容激活函数重计算、全重计算同时开启:

    • 同时开启时,仅支持--recompute-method设置为block。
    • 同时开启时,将按照指定的全重计算和Norm重计算的层数做各自类型的重计算,即不会有一层既做全重计算又做Norm重计算。
  • 执行优先级是先计算全重计算层,后Norm重计算层。

使用效果

开启后可节省RMSNorm/LayerNorm层的输出激活内存,并且由于Norm计算速度较快,重计算后对整体性能影响较小。对于开启TP及SP的场景,由于该激活内存在TP域内已进行切分,开启后效果不明显,对于未使用TP及SP的模型可考虑使用。