Async Activation Offload

背景与挑战

随着大模型参数规模的增长和序列长度的上升,训练过程中对显存的需求急剧上升。目前对激活值显存的优化方案主要依赖于重计算技术和序列并行技术。这些技术存在以下瓶颈:

  • 重计算通过丢弃前向传播中的激活值,并在反向传播时重新计算来节省显存,带来了大量的冗余计算。
  • 序列并行虽然能将单个序列的计算分配到多个设备上达到降低显存的目的,但是频繁的跨设备通信可能难以被有效掩盖

针对上述挑战,可以使用异步激活值卸载(Async Activation Offload)策略

解决方案

  • 显存优化:将激活值张量从device侧卸载至host侧,显著降低峰值显存占用
  • 异步执行:利用多流机制实现卸载(D2H)和加载(H2D)的异步,使拷贝过程被计算掩盖
  • 提前预取:反向过程中,通过prefetch机制提前加载后续需要的张量,隐藏加载延迟

使用方法

该特性支持按“块”(block)组织张量生命周期,灵活管理模型不同block的激活值。使用示例如下:

with async_save_on_cpu(
    h2d_stream=h2d_stream,
    d2h_stream=d2h_stream,
    block_idx=block_idx,
    depth=depth,
    custom_check_fn=your_check_fn
):
    # 模型某个block的前向计算代码,此处仅作为示例
    output = layer(input)

参数详解

  • h2d_stream/d2h_stream:H2D和D2H流,建议全局单独新建一条流单独用来执行H2D和D2H任务,实现和计算流异步的效果
  • block_idx:当前block在模型中的编号
  • depth:模型的总层数
  • custom_check_fn:自定义校验函数,只有校验之后返回True的激活值才会被offload,建议根据实际情况筛选出计算量大,激活值参数量小的部分,并结合重计算策略,激活值参数量大,计算耗时短的进行重计算,激活值参数量小,计算耗时长的进行offload。否则H2D和D2H的开销过大,难以被计算掩盖。

使用案例及效果

  • 多模态模型长序列场景:self attention的计算量随序列长度呈平方关系增长,使用该方案卸载self attention前向计算的激活值,并在重计算时跳过self attention的重计算。典型场景下端到端性能收益20%以上。
  • FSDP2场景:FSDP2分布式策略下,对模型参数进行切分和聚合,较短序列长度下,计算耗时无法掩盖通信耗时。可以使用该方案,将重计算入口的激活值卸载,节省出显存后增大micro batch size或序列长度提高计算比例。典型场景下端到端性能收益60%以上。