显存调优
===========================

Last updated: 12/11/2025. Author: LinMingZhe

显存数据调优是大模型训练中不可或缺的一环,模型规模、数据规模不断scale,对应的调优能力不可或缺,此文旨在帮助大家克服各种OOM。

此处所说的“显存”借用了VRAM(Video Random Access Memory)的概念,实际指NPU的HBM(High Bandwidth Memory),有时也称device内存,与host内存相对。

显存数据采集
------------

显存快照
::::::::::::::

显存快照是torch原生的内存管理器的状态日志记录,可以将torch内存池内部各个tensor和segment的生命周期同运行堆栈相关联,还可以查看显存管理器管理的其他信息。
昇腾支持显存快照的功能,对一般性的显存调优,显存快照非常有用。

内置模型显存快照
^^^^^^^^^^^^^^

套件集成了昇腾内存快照采集工具,以提供对模型运行情况的分析。内置模型均已适配,只需修改 ``mindspeed_mm/tools/tools.json`` 文件即可生效。

对复用训练流程的模型,同样仅需修改配置。支持的配置项如下。

.. code:: json5

    {
      "memory_profile": {
        "enable": false,    // 内存采集功能开关
        "start_step": 0,    // 开始录制的步数。数值为训练步数的起始点,0代表初始化过程
        "end_step": 2,      // 结束录制的步数。数值为训练步数的起始点,0代表初始化过程
        "save_path": "./memory_snapshot",  // 快照文件保存路径
        "dump_ranks": [     // 录制快照的rank列表,从0开始
          0
        ],
        "stacks": "all",    // 堆栈信息录制。可选项:python/all
        "max_entries": null // 最大记录数,null则无限制
      }
    }

自定义训练流程显存快照
^^^^^^^^^^^^^^^^^^

对独立的训练流程,可参考下列代码,对训练脚本做适配以使用profiler特性。参数配置同上。

.. code:: python

    from megatron.training import get_args
    from mindspeed_mm.tools.mem_profiler import memory_profiler

    args = get_args()                                   # 获取配置
    memory_profiler.reset(args.mm.tool.memory_profile)  # 使用配置刷新profiler状态
    training_preparation()                              # 运行训练准备代码
    while iteration < args.train_iters:                 # 训练主循环
        memory_profiler.step()                          # 调用profiler记录一个迭代
        train_one_step()                                # 训练一个迭代
    memory_profiler.stop()                              # 停止采集

任意脚本显存快照
^^^^^^^^^^^^^^^^^^

对于不具备典型训练结构的脚本,或者局部的手动调试,可直接调用基础函数。

.. code:: python

    code_not_record()
    from mindspeed_mm.tools.mem_profiler import _record
    _record()
    code_to_record()

dump与开始录制可以在不同文件内。

.. code:: python

    code_to_record()
    from mindspeed_mm.tools.mem_profiler import _dump, _stop
    _dump()
    _stop()

OOM显存快照
^^^^^^^^^^^^^

如果在训练中遇到了OOM,此时程序很可能在dump前退出。通过以下环境变量,可以采集全程的显存快照。

.. code-block:: shell

    export OOM_SNAPSHOT_ENABLE=1            # OOM快照生成开关
    export OOM_SNAPSHOT_PATH="/home/usr/"   # 保存路径

内存快照使用
^^^^^^^^^^^^^

dump执行完成后,会在输出目录生成 ``snapshot_`` 开头的 ``pickle`` 文件,可以在 `torch页面 <https://pytorch.org/memory_viz>`_ 可视化查看内存快照。
结合下面介绍的内存基本知识和分析手段,就可以进行内存调优了。

显存分析
------------

显存构成和特性
:::::::::::::

显存分析的前提是了解一个模型训练的过程的计算,和计算所必要的显存空间占用,分为时间维度和空间维度。
下面以一个典型的多模态模型训练过程为例拆解显存的构成。

- 静态显存:生命周期可能贯穿或等价于贯穿多个训练step的tensor占用的显存

  - 模型权重:在模型初始化时创建,默认情况下会在内存中长期存在。包含encoder权重、transformer权重等。
  - 模型梯度:在模型反向计算时创建,每一个参与计算的参数都会生成一个对应的梯度。优化器优化完成后可清空。由于部分后端会为所有模型权重创建梯度buffer并一直持有,因此归入静态显存。
  - 优化器状态:对于常见的Adam优化器,会为每一个可训练参数创建两个状态参数。

- 动态显存:在每个step中创建并且生命周期显著短于step的tensor占用的显存。

  - 输入数据:训练模型学习的多模态数据。
  - encode阶段局部变量:在encode过程中创建,生成latent后所有的局部变量均可释放。(encoder不参与训练时)
  - transformer前向阶段局部变量:在前向计算过程中创建,部分计算不会保存局部变量,申请后及时释放,一部分计算会保存局部变量用于反向传播计算。
  - transformer反向阶段局部变量:消耗前向保存的局部变量,反向过程产生的其他局部变量及时释放。
  - 优化器计算阶段局部变量:消耗反向产生的梯度,产生的激活值不明显。

上面的拆解在具体的模型的快照中可以找到对应的类型。
但是每个模型和训练任务都会有所不同,在应用了不同的优化和分布式特性后,相同的模型的构成也会产生变化,可以结合快照进行分析。
分析的思路就是找到每一类tensor在device上的生命周期和实际参与的计算,有大体了解,进而进一步分析。


瓶颈分析
:::::::::::::

内存占用存在短板效应,训练过程中的一个点超出了设备资源上限,就会导致整个训练过程都不可用。
因此分析瓶颈是显存调优的第一步,也就是找到显存峰值时的内存分配状态,从此出发分析调优空间。

分析瓶颈大约有下面一些步骤。

- 找到显存峰值、尖刺,作为重点观察点。
- 对每一个观察点,分析局部峰值产生时,内存的分布情况,静态显存的构成、动态显存的构成。
- 筛选显存占用较大的一类tensor,或者生命周期明显过长的一类tensor。
- 对每一类tensor,找到一个代表,分析其大小、申请时机、释放时机,是否符合预期。
  - 对于不符合预期的情况,分析tensor创建和引用的情况,比如堆栈引用、容器引用等,防止生命周期过长。
  - 对于符合预期的情况,分析优先应用哪些调优方法,参考下一章。


调优思想
:::::::::::::

在显存调优中,调优思想是共通的,下面的具体手段基本都在应用下面的几个原则。

- **将计算切分到不同卡**:将不相互依赖的计算分散到不同卡,其依赖的tensor也可以根据计算的分布分散到不同卡上。代表:DP、PP、TP、SP等。
- **卸载用不到的tensor**:将当前计算无关的tensor进行卸载。用到时再装载。代表:激活值卸载、FSDP offload、encoder offload。
- **串行计算**:对部分计算串行化,降低一次计算所依赖的tensor规模。代表:梯度累积、分块loss计算、VAE分块处理。
- **重计算**:直接释放计算中间结果,仅保留少量数据,在需要时重新计算出来。代表:重计算。


显存调优方法
------------

本章将会讲述MM显存调优方法,着眼于调优思路,内置模型涉及的具体特性的使用方法和约束,优先以模型readme和对应特性文档为准。
显存调优也是一种平衡术,获得显存收益往往伴随着开销,实际调优过程需要权衡利弊综合考量,抓住主要矛盾。
对于新模型开发,可以根据实际情况优先应用容易适配和必要的调优手段。

.. contents::
   :local:

数据并行(DP)
::::::

动态显存和MBS直接相关,因此通过扩大NPU数量,将多个数据分散到不同的卡上执行计算,就可以降低动态显存占用。
同时,可以利用计算过程对静态显存的依赖性,细粒度管理静态显存,在需要时做聚合等操作,可以降低静态现存占用。
两者相结合,就产生了DDP、FSDP,对应zero1、zero3。

DDP
^^^

套件的DDP使用megatron版本的分布式优化器。可以对静态显存中优化器状态进行切分。

.. code-block:: shell

    --use-distributed-optimizer

FSDP
^^^^

套件支持FSDP2特性,可以参考 :doc:`FSDP2使用说明 <docs/zh/features/fsdp2.rst>` 使用FSDP2。
启用此特性,默认情况下可以对静态显存做完全切分。

静态显存调优
::::::::::

FSDP offloading
^^^^^^^^^^^^^^^^^

显存调优手段中,offload十分常见,后面将多次提到,前面FSDP实际就包含了offload的高级特性,可以通过配置开启。

在内部模型中,直接修改yaml文件即可生效,将参数,梯度和优化器状态卸载到CPU内存。

.. code-block:: yaml

    offload_to_cpu: True

.. note::
    当设置 ``offload_to_cpu=True`` 时,需在入口脚本中设置通信组为双后端,即: ``--distributed-backend npu:hccl,cpu:gloo``

具体特性参考 `官方文档 <https://docs.pytorch.org/docs/stable/distributed.fsdp.fully_shard.html#torch.distributed.fsdp.CPUOffloadPolicy>`_

Encoder Offload
^^^^^^^^^^^^^^^

对于生成模型,encoder往往不参与训练,因此可以利用参数的时间局部性,在主干模型计算时offload权重,并利用缓存来摊薄搬运开销。

MM内置模型中,通过配置卸载周期来启用,适用于基于SoraModel的模型,如wan2.2。

.. code:: json

    "encoder_offload_interval": 8,

- 收益:非encoder阶段,从静态显存中释放encoder权重
- 代价:增加了权重搬运开销和缓存管理开销

模型并行
^^^^^^

TP(tensor parallel)、PP(pipeline parallel)是常见的模型并行策略,直接将权重切分到不同的device,可以降低静态显存占用,通过通信等操作实现算法等价。
在PP的基础上,有VPP等算法降低PP算法产生的空泡,提升计算效率。
MM内置的模型如qwen2.5VL、OpenSoraPlan1.3支持开启模型切分,详情见相应模型readme和特性文档。

- TP: 张量模型并行

  - 收益:切分模型,降低单卡计算量
  - 代价:通信开销较大,每一个MLP/Attention均需要通信激活值

- PP:流水线并行

  - 收益:切分模型,降低激活值,多卡分摊计算量
  - 代价:每两个PP阶段之间存在激活值通信,并且存在空泡,损失计算性能

专家并行
^^^^^^

专家并行是在MOE场景下,对MOE的专家进行切分,属于一种特殊的模型并行。

- 收益:显存收益,一个rank上只存放部分权重。
- 代价:性能损耗,有内存搬运开销和激活值通信开销。同时还会有专家均衡性的挑战。

动态显存调优
::::::::::

重计算
^^^^^

在生成模型的DiT、理解模型的decoder等主干transformer模型前反向计算阶段,激活值中间结果的保存是一个常见的瓶颈。
激活值重计算可以仅保存少量保存点,并在反向计算时从保存点再次计算出中间变量用于反向计算,避免了保存大部分中间结果到反向,可以大幅降低激活值。

对于FSDP2模型,重计算在yaml中配置生效,使用方法参考 :doc:`FSDP2使用说明 <docs/zh/features/fsdp2.rst>` 使用

    - ``recompute_modules``
        - 描述:配置激活值重计算,以计算换内存
        - 配置格式:与 `sub_modules_to_wrap` 一致
        - 约束:与megatron中的完全重计算功能存在冲突,需将其关闭

对于非FSDP2模型,重计算可通过args配置开启。

.. code-block:: shell

      GPT_ARGS="
          --recompute-granularity full \
          --recompute-method block \
          --recompute-num-layers 40 \
      "

- 收益:显著降低主干模型前反向计算动态显存峰值。
- 代价:增加了计算开销,完全重计算场景下,前向计算量翻倍。

梯度累积
^^^^^^^

micro batch size扩大时,一个step可以训练多个数据,激活值占用会相应加大。
梯度累积将一个global batch分成多个micro batches,一个step只处理一个micro batch,并将梯度累积起来。
可以实现使用较小的MBS等效于较大的MBS的单步训练数据量。

可在启动脚本参数中改变GRAD_ACC_STEP、MBS控制。

.. code-block:: shell

    MBS=1
    GRAD_ACC_STEP=1
    DP=$(($WORLD_SIZE/$TP/$PP/$CP))
    GBS=$(($MBS*$GRAD_ACC_STEP*$DP))
    # ...
    --micro-batch-size ${MBS} \
    --global-batch-size ${GBS}

- 收益:降低激活值显存占用。
- 代价:一次下发的计算量变小,在FSDP场景下引入更多通信。

序列并行:TP-SP/Ulysses/ring/USP
^^^^^^^

序列是transformer中重要的概念,序列维度在注意力计算中会相互依赖,在MLP的token层级计算,或element层级的计算中,相互不依赖。
利用这样的局部性,业界设计了多种序列并行算法,MM对主流的算法进行了适配。

- TP-SP

  - 收益:切分了MLP/attention模块之间,如LayerNorm和Dropout等计算过程的序列,可降低切分阶段的动态显存占用,减少冗余计算。
  - 代价:必须先启用TP。

- Ulysses序列并行

  - 收益:按照Attention计算中的头、其他计算的序列维度,在切分阶段分摊计算和显存激活值到不同卡上,降低动态显存占用。
  - 代价:增加通信量,CP大小必须被num head整除。

- RingAttention序列并行

  - 收益:持有部分Q,环状传递部分KV并计算,在切分阶段分摊动态显存和计算量。
  - 代价:增加通信量。

- USP混合序列并行(Ulysses + RingAttention)

  - 收益:在切分阶段分摊动态现存和计算量,是Ulysses的高性能和ring的自由切分的折中。
  - 代价:增加通信量。

多种序列并行算法可以根据实际情况选用,并和其他切分策略组合使用,考虑到实际吞吐等因素,具体情况会比较复杂。
已经有一些研究提供了一些建议,下面引述 `USP论文 <https://arxiv.org/pdf/2405.07719>`_ 中的一些观点,详细分析见论文原文。

    - 建议使用 Unified-SP 替代 SP-Ring 和 SP-Ulysses,因为它兼具两者功能,并带来额外优势
    - 建议优先使用 DP,仅在批次大小(bs)不足以切分时,再考虑是否启用 SP。
    - 建议在使用 SP 时,务必搭配 ZeRO-1/2;也可进一步启用 ZeRO-3 与 Offload 技术,以通信开销换取内存节省。
    - 在大规模场景下,SP 的通信开销优于 TP-sp;启用 GQA 可进一步降低 SP 的通信成本。
    - 将 TP-SP 切换为 SP 并不会提升训练时的序列长度;SP+ZeRO3 可实现与 TP-sp 相近的序列长度。
    - 建议采用更高程度的 SP 并行:当注意力头数受限时,可设置较大的 ring 维度,把长序列拆到更多计算设备上训练——这是 TP-sp 无法实现的独特优势。


VAE序列并行
^^^^^^^^^^^

在生成模型的Encoder阶段,主要的动态显存瓶颈在于VAE对视频数据的编码,尤其在处理高分辨率和帧数的视频时。
对于主流的VAE而言,可以利用conv的计算局部性,结合通信算法,使用多卡分担一条数据的计算。
但是由于各个模型的实现不同,具体算法也有所区别。MM内置的OpenSoraPlan1.3和CogVideoX等模型,适配了切分算法,使能方式和约束见readme说明。

- 收益:降低VAE encode阶段动态显存瓶颈。
- 代价:增加通信。部分算法受帧数等实际参数约束。

VAE分块处理
^^^^^^^^^^

解决VAE激活值瓶颈的另一个方法是开启分块处理。业界当前已经有比较成熟的分块处理算法。
VAE 分块处理通过将图像分割成多个较小的重叠图块,一次处理一块,而不是一次性处理整张图像,从而节省内存。

.. note::

    分块处理(tiling)不会显著降低生成质量,但是和原始计算并不等价

MM内置模型可以在model.json中开启分块处理。

.. code-block:: json5

    {
        "ae": {
            "enable_tiling": true # 部分模型为"use_tiling"
        }
    }

- 收益:显著降低VAE encode阶段动态显存瓶颈。
- 代价:计算不等价,下发的计算小而多,且可能包含重叠分块,计算耗时往往会增加

激活值卸载
^^^^^^^^

激活值卸载也是常用的做法,一般会把重计算的中间结果,或者cache等生命周期相对较长的激活值进行offload。

最简单的,可以直接使用to来实现offload。

.. code-block:: python

    for cache in caches:
        for k in cache:
            cache[k] = cache[k].to("cpu")
    for i in range(len(caches)):
        for k in caches[i]:
            caches[i] = caches[i].to("npu")
        run_your_code(**caches[i])

我们对一些场景的offload做了异步优化,可以参考async offload特性文档。

.. note::

    需要注意的是,torch的tensor生命周期管理复用了python的引用管理,因此在offload时需要确认卸载tensor的引用已清空,否则空间将不会被回收。

loss分块计算
^^^^^^^^^^

在典型自回归语言模型中,loss计算前通过LM Head将每个token映射到词表大小,然后通过softmax计算出每个词的概率,用于和真实结果计算出loss数值。
此计算过程会产生大量激活值,尤其在词表、序列长度都很大的情况下。
MM对LM Head应用列切,并对激活值使用动态softmax等算法,计算loss。
用户还可以使用chunkloss特性,利用序列loss互不依赖的特点,将序列维度切块计算loss并累积,进一步降低显存占用。

- 对词表维度的切分。

  - 收益:降低激活值显存占用,TP越大收益越大。
  - 代价:仅在TP启用时有效。

- 对序列维度分块计算累积loss。

  - 收益:显著降低loss计算阶段激活值显存占用。
  - 代价:下发计算变碎,降低硬件利用率。



参考资料
------------