显存调优
===========================
Last updated: 12/11/2025. Author: LinMingZhe
显存数据调优是大模型训练中不可或缺的一环,模型规模、数据规模不断scale,对应的调优能力不可或缺,此文旨在帮助大家克服各种OOM。
此处所说的“显存”借用了VRAM(Video Random Access Memory)的概念,实际指NPU的HBM(High Bandwidth Memory),有时也称device内存,与host内存相对。
显存数据采集
------------
显存快照
::::::::::::::
显存快照是torch原生的内存管理器的状态日志记录,可以将torch内存池内部各个tensor和segment的生命周期同运行堆栈相关联,还可以查看显存管理器管理的其他信息。
昇腾支持显存快照的功能,对一般性的显存调优,显存快照非常有用。
内置模型显存快照
^^^^^^^^^^^^^^
套件集成了昇腾内存快照采集工具,以提供对模型运行情况的分析。内置模型均已适配,只需修改 ``mindspeed_mm/tools/tools.json`` 文件即可生效。
对复用训练流程的模型,同样仅需修改配置。支持的配置项如下。
.. code:: json5
{
"memory_profile": {
"enable": false, // 内存采集功能开关
"start_step": 0, // 开始录制的步数。数值为训练步数的起始点,0代表初始化过程
"end_step": 2, // 结束录制的步数。数值为训练步数的起始点,0代表初始化过程
"save_path": "./memory_snapshot", // 快照文件保存路径
"dump_ranks": [ // 录制快照的rank列表,从0开始
0
],
"stacks": "all", // 堆栈信息录制。可选项:python/all
"max_entries": null // 最大记录数,null则无限制
}
}
自定义训练流程显存快照
^^^^^^^^^^^^^^^^^^
对独立的训练流程,可参考下列代码,对训练脚本做适配以使用profiler特性。参数配置同上。
.. code:: python
from megatron.training import get_args
from mindspeed_mm.tools.mem_profiler import memory_profiler
args = get_args() # 获取配置
memory_profiler.reset(args.mm.tool.memory_profile) # 使用配置刷新profiler状态
training_preparation() # 运行训练准备代码
while iteration < args.train_iters: # 训练主循环
memory_profiler.step() # 调用profiler记录一个迭代
train_one_step() # 训练一个迭代
memory_profiler.stop() # 停止采集
任意脚本显存快照
^^^^^^^^^^^^^^^^^^
对于不具备典型训练结构的脚本,或者局部的手动调试,可直接调用基础函数。
.. code:: python
code_not_record()
from mindspeed_mm.tools.mem_profiler import _record
_record()
code_to_record()
dump与开始录制可以在不同文件内。
.. code:: python
code_to_record()
from mindspeed_mm.tools.mem_profiler import _dump, _stop
_dump()
_stop()
OOM显存快照
^^^^^^^^^^^^^
如果在训练中遇到了OOM,此时程序很可能在dump前退出。通过以下环境变量,可以采集全程的显存快照。
.. code-block:: shell
export OOM_SNAPSHOT_ENABLE=1 # OOM快照生成开关
export OOM_SNAPSHOT_PATH="/home/usr/" # 保存路径
内存快照使用
^^^^^^^^^^^^^
dump执行完成后,会在输出目录生成 ``snapshot_`` 开头的 ``pickle`` 文件,可以在 `torch页面 <https://pytorch.org/memory_viz>`_ 可视化查看内存快照。
结合下面介绍的内存基本知识和分析手段,就可以进行内存调优了。
显存分析
------------
显存构成和特性
:::::::::::::
显存分析的前提是了解一个模型训练的过程的计算,和计算所必要的显存空间占用,分为时间维度和空间维度。
下面以一个典型的多模态模型训练过程为例拆解显存的构成。
- 静态显存:生命周期可能贯穿或等价于贯穿多个训练step的tensor占用的显存
- 模型权重:在模型初始化时创建,默认情况下会在内存中长期存在。包含encoder权重、transformer权重等。
- 模型梯度:在模型反向计算时创建,每一个参与计算的参数都会生成一个对应的梯度。优化器优化完成后可清空。由于部分后端会为所有模型权重创建梯度buffer并一直持有,因此归入静态显存。
- 优化器状态:对于常见的Adam优化器,会为每一个可训练参数创建两个状态参数。
- 动态显存:在每个step中创建并且生命周期显著短于step的tensor占用的显存。
- 输入数据:训练模型学习的多模态数据。
- encode阶段局部变量:在encode过程中创建,生成latent后所有的局部变量均可释放。(encoder不参与训练时)
- transformer前向阶段局部变量:在前向计算过程中创建,部分计算不会保存局部变量,申请后及时释放,一部分计算会保存局部变量用于反向传播计算。
- transformer反向阶段局部变量:消耗前向保存的局部变量,反向过程产生的其他局部变量及时释放。
- 优化器计算阶段局部变量:消耗反向产生的梯度,产生的激活值不明显。
上面的拆解在具体的模型的快照中可以找到对应的类型。
但是每个模型和训练任务都会有所不同,在应用了不同的优化和分布式特性后,相同的模型的构成也会产生变化,可以结合快照进行分析。
分析的思路就是找到每一类tensor在device上的生命周期和实际参与的计算,有大体了解,进而进一步分析。
瓶颈分析
:::::::::::::
内存占用存在短板效应,训练过程中的一个点超出了设备资源上限,就会导致整个训练过程都不可用。
因此分析瓶颈是显存调优的第一步,也就是找到显存峰值时的内存分配状态,从此出发分析调优空间。
分析瓶颈大约有下面一些步骤。
- 找到显存峰值、尖刺,作为重点观察点。
- 对每一个观察点,分析局部峰值产生时,内存的分布情况,静态显存的构成、动态显存的构成。
- 筛选显存占用较大的一类tensor,或者生命周期明显过长的一类tensor。
- 对每一类tensor,找到一个代表,分析其大小、申请时机、释放时机,是否符合预期。
- 对于不符合预期的情况,分析tensor创建和引用的情况,比如堆栈引用、容器引用等,防止生命周期过长。
- 对于符合预期的情况,分析优先应用哪些调优方法,参考下一章。
调优思想
:::::::::::::
在显存调优中,调优思想是共通的,下面的具体手段基本都在应用下面的几个原则。
- **将计算切分到不同卡**:将不相互依赖的计算分散到不同卡,其依赖的tensor也可以根据计算的分布分散到不同卡上。代表:DP、PP、TP、SP等。
- **卸载用不到的tensor**:将当前计算无关的tensor进行卸载。用到时再装载。代表:激活值卸载、FSDP offload、encoder offload。
- **串行计算**:对部分计算串行化,降低一次计算所依赖的tensor规模。代表:梯度累积、分块loss计算、VAE分块处理。
- **重计算**:直接释放计算中间结果,仅保留少量数据,在需要时重新计算出来。代表:重计算。
显存调优方法
------------
本章将会讲述MM显存调优方法,着眼于调优思路,内置模型涉及的具体特性的使用方法和约束,优先以模型readme和对应特性文档为准。
显存调优也是一种平衡术,获得显存收益往往伴随着开销,实际调优过程需要权衡利弊综合考量,抓住主要矛盾。
对于新模型开发,可以根据实际情况优先应用容易适配和必要的调优手段。
.. contents::
:local:
数据并行(DP)
::::::
动态显存和MBS直接相关,因此通过扩大NPU数量,将多个数据分散到不同的卡上执行计算,就可以降低动态显存占用。
同时,可以利用计算过程对静态显存的依赖性,细粒度管理静态显存,在需要时做聚合等操作,可以降低静态现存占用。
两者相结合,就产生了DDP、FSDP,对应zero1、zero3。
DDP
^^^
套件的DDP使用megatron版本的分布式优化器。可以对静态显存中优化器状态进行切分。
.. code-block:: shell
--use-distributed-optimizer
FSDP
^^^^
套件支持FSDP2特性,可以参考 :doc:`FSDP2使用说明 <docs/zh/features/fsdp2.rst>` 使用FSDP2。
启用此特性,默认情况下可以对静态显存做完全切分。
静态显存调优
::::::::::
FSDP offloading
^^^^^^^^^^^^^^^^^
显存调优手段中,offload十分常见,后面将多次提到,前面FSDP实际就包含了offload的高级特性,可以通过配置开启。
在内部模型中,直接修改yaml文件即可生效,将参数,梯度和优化器状态卸载到CPU内存。
.. code-block:: yaml
offload_to_cpu: True
.. note::
当设置 ``offload_to_cpu=True`` 时,需在入口脚本中设置通信组为双后端,即: ``--distributed-backend npu:hccl,cpu:gloo``
具体特性参考 `官方文档 <https://docs.pytorch.org/docs/stable/distributed.fsdp.fully_shard.html#torch.distributed.fsdp.CPUOffloadPolicy>`_
Encoder Offload
^^^^^^^^^^^^^^^
对于生成模型,encoder往往不参与训练,因此可以利用参数的时间局部性,在主干模型计算时offload权重,并利用缓存来摊薄搬运开销。
MM内置模型中,通过配置卸载周期来启用,适用于基于SoraModel的模型,如wan2.2。
.. code:: json
"encoder_offload_interval": 8,
- 收益:非encoder阶段,从静态显存中释放encoder权重
- 代价:增加了权重搬运开销和缓存管理开销
模型并行
^^^^^^
TP(tensor parallel)、PP(pipeline parallel)是常见的模型并行策略,直接将权重切分到不同的device,可以降低静态显存占用,通过通信等操作实现算法等价。
在PP的基础上,有VPP等算法降低PP算法产生的空泡,提升计算效率。
MM内置的模型如qwen2.5VL、OpenSoraPlan1.3支持开启模型切分,详情见相应模型readme和特性文档。
- TP: 张量模型并行
- 收益:切分模型,降低单卡计算量
- 代价:通信开销较大,每一个MLP/Attention均需要通信激活值
- PP:流水线并行
- 收益:切分模型,降低激活值,多卡分摊计算量
- 代价:每两个PP阶段之间存在激活值通信,并且存在空泡,损失计算性能
专家并行
^^^^^^
专家并行是在MOE场景下,对MOE的专家进行切分,属于一种特殊的模型并行。
- 收益:显存收益,一个rank上只存放部分权重。
- 代价:性能损耗,有内存搬运开销和激活值通信开销。同时还会有专家均衡性的挑战。
动态显存调优
::::::::::
重计算
^^^^^
在生成模型的DiT、理解模型的decoder等主干transformer模型前反向计算阶段,激活值中间结果的保存是一个常见的瓶颈。
激活值重计算可以仅保存少量保存点,并在反向计算时从保存点再次计算出中间变量用于反向计算,避免了保存大部分中间结果到反向,可以大幅降低激活值。
对于FSDP2模型,重计算在yaml中配置生效,使用方法参考 :doc:`FSDP2使用说明 <docs/zh/features/fsdp2.rst>` 使用
- ``recompute_modules``
- 描述:配置激活值重计算,以计算换内存
- 配置格式:与 `sub_modules_to_wrap` 一致
- 约束:与megatron中的完全重计算功能存在冲突,需将其关闭
对于非FSDP2模型,重计算可通过args配置开启。
.. code-block:: shell
GPT_ARGS="
--recompute-granularity full \
--recompute-method block \
--recompute-num-layers 40 \
"
- 收益:显著降低主干模型前反向计算动态显存峰值。
- 代价:增加了计算开销,完全重计算场景下,前向计算量翻倍。
梯度累积
^^^^^^^
micro batch size扩大时,一个step可以训练多个数据,激活值占用会相应加大。
梯度累积将一个global batch分成多个micro batches,一个step只处理一个micro batch,并将梯度累积起来。
可以实现使用较小的MBS等效于较大的MBS的单步训练数据量。
可在启动脚本参数中改变GRAD_ACC_STEP、MBS控制。
.. code-block:: shell
MBS=1
GRAD_ACC_STEP=1
DP=$(($WORLD_SIZE/$TP/$PP/$CP))
GBS=$(($MBS*$GRAD_ACC_STEP*$DP))
# ...
--micro-batch-size ${MBS} \
--global-batch-size ${GBS}
- 收益:降低激活值显存占用。
- 代价:一次下发的计算量变小,在FSDP场景下引入更多通信。
序列并行:TP-SP/Ulysses/ring/USP
^^^^^^^
序列是transformer中重要的概念,序列维度在注意力计算中会相互依赖,在MLP的token层级计算,或element层级的计算中,相互不依赖。
利用这样的局部性,业界设计了多种序列并行算法,MM对主流的算法进行了适配。
- TP-SP
- 收益:切分了MLP/attention模块之间,如LayerNorm和Dropout等计算过程的序列,可降低切分阶段的动态显存占用,减少冗余计算。
- 代价:必须先启用TP。
- Ulysses序列并行
- 收益:按照Attention计算中的头、其他计算的序列维度,在切分阶段分摊计算和显存激活值到不同卡上,降低动态显存占用。
- 代价:增加通信量,CP大小必须被num head整除。
- RingAttention序列并行
- 收益:持有部分Q,环状传递部分KV并计算,在切分阶段分摊动态显存和计算量。
- 代价:增加通信量。
- USP混合序列并行(Ulysses + RingAttention)
- 收益:在切分阶段分摊动态现存和计算量,是Ulysses的高性能和ring的自由切分的折中。
- 代价:增加通信量。
多种序列并行算法可以根据实际情况选用,并和其他切分策略组合使用,考虑到实际吞吐等因素,具体情况会比较复杂。
已经有一些研究提供了一些建议,下面引述 `USP论文 <https://arxiv.org/pdf/2405.07719>`_ 中的一些观点,详细分析见论文原文。
- 建议使用 Unified-SP 替代 SP-Ring 和 SP-Ulysses,因为它兼具两者功能,并带来额外优势
- 建议优先使用 DP,仅在批次大小(bs)不足以切分时,再考虑是否启用 SP。
- 建议在使用 SP 时,务必搭配 ZeRO-1/2;也可进一步启用 ZeRO-3 与 Offload 技术,以通信开销换取内存节省。
- 在大规模场景下,SP 的通信开销优于 TP-sp;启用 GQA 可进一步降低 SP 的通信成本。
- 将 TP-SP 切换为 SP 并不会提升训练时的序列长度;SP+ZeRO3 可实现与 TP-sp 相近的序列长度。
- 建议采用更高程度的 SP 并行:当注意力头数受限时,可设置较大的 ring 维度,把长序列拆到更多计算设备上训练——这是 TP-sp 无法实现的独特优势。
VAE序列并行
^^^^^^^^^^^
在生成模型的Encoder阶段,主要的动态显存瓶颈在于VAE对视频数据的编码,尤其在处理高分辨率和帧数的视频时。
对于主流的VAE而言,可以利用conv的计算局部性,结合通信算法,使用多卡分担一条数据的计算。
但是由于各个模型的实现不同,具体算法也有所区别。MM内置的OpenSoraPlan1.3和CogVideoX等模型,适配了切分算法,使能方式和约束见readme说明。
- 收益:降低VAE encode阶段动态显存瓶颈。
- 代价:增加通信。部分算法受帧数等实际参数约束。
VAE分块处理
^^^^^^^^^^
解决VAE激活值瓶颈的另一个方法是开启分块处理。业界当前已经有比较成熟的分块处理算法。
VAE 分块处理通过将图像分割成多个较小的重叠图块,一次处理一块,而不是一次性处理整张图像,从而节省内存。
.. note::
分块处理(tiling)不会显著降低生成质量,但是和原始计算并不等价
MM内置模型可以在model.json中开启分块处理。
.. code-block:: json5
{
"ae": {
"enable_tiling": true # 部分模型为"use_tiling"
}
}
- 收益:显著降低VAE encode阶段动态显存瓶颈。
- 代价:计算不等价,下发的计算小而多,且可能包含重叠分块,计算耗时往往会增加
激活值卸载
^^^^^^^^
激活值卸载也是常用的做法,一般会把重计算的中间结果,或者cache等生命周期相对较长的激活值进行offload。
最简单的,可以直接使用to来实现offload。
.. code-block:: python
for cache in caches:
for k in cache:
cache[k] = cache[k].to("cpu")
for i in range(len(caches)):
for k in caches[i]:
caches[i] = caches[i].to("npu")
run_your_code(**caches[i])
我们对一些场景的offload做了异步优化,可以参考async offload特性文档。
.. note::
需要注意的是,torch的tensor生命周期管理复用了python的引用管理,因此在offload时需要确认卸载tensor的引用已清空,否则空间将不会被回收。
loss分块计算
^^^^^^^^^^
在典型自回归语言模型中,loss计算前通过LM Head将每个token映射到词表大小,然后通过softmax计算出每个词的概率,用于和真实结果计算出loss数值。
此计算过程会产生大量激活值,尤其在词表、序列长度都很大的情况下。
MM对LM Head应用列切,并对激活值使用动态softmax等算法,计算loss。
用户还可以使用chunkloss特性,利用序列loss互不依赖的特点,将序列维度切块计算loss并累积,进一步降低显存占用。
- 对词表维度的切分。
- 收益:降低激活值显存占用,TP越大收益越大。
- 代价:仅在TP启用时有效。
- 对序列维度分块计算累积loss。
- 收益:显著降低loss计算阶段激活值显存占用。
- 代价:下发计算变碎,降低硬件利用率。
参考资料
------------