online-data-balance
=======================

1. 背景
-------------

多模态模型由于原生分辨率等因素,在分布式训练中,数据负载不均,集群资源利用率下降。

Data Balance 模块旨在解决数据负载不均问题,通过在线数据重排布,实现DP间负载均衡,提升资源利用率。

2. 解决方案
-------------

注:当前支持ViT数据负载均衡,LLM部分未使用packing策略,后续将支持LLM部分packing,结合当前ViT部分在线负载均衡方案,提升整网资源利用率。

2.1 关键组件
::::::::::::::::

.. list-table:: 关键组件说明
   :widths: 30 70
   :header-rows: 1

   * - 组件
     - 说明
   * - **DataBalance 类**
     - 协调数据平衡流程的主类
   * - **排序算法**
     - 如 `post_global_balancing_greedy_without_pad`(后续将支持更多排序算法)
   * - **PrefetchMicroBatchIterator**
     - 异步预取数据,隐藏延迟
   * - **辅助函数**
     - 数据分割、映射、重组和通信等功能

2.2 执行流程
:::::::::::::::::

.. mermaid::

    flowchart TD
        A[开始训练] --> B[获取原始批次(GBS)]
        B --> C[DP间数据长度获取]
        C --> D[负载均衡数据索引重排]
        D --> E[ranktable mapping]
        E --> F[All2All数据分发]
        F --> G[负载均衡迭代器构建]
        G --> H[训练步骤]
        H --> I[创建负载均衡迭代器]
        I --> J[前反向、梯度更新]
        J --> K{是否下次迭代}
        K -->|是| B
        K -->|否| L[结束训练]

2.2.1 数据分发样例
^^^^^^^^^^^^^^^^^^^^

紫色为初始状态,蓝色为目标状态。通过数据映射路径与All2All通信,实现数据的重排布。

.. image:: ../_static/features/online-data-balance/all2all.png
   :alt: all2all

3. 使用指南
---------------

3.1 启用配置
::::::::::::::::

在Qwen 2.5 Omni模型训练命令中添加:

.. code-block:: bash

    --use-data-balance

3.2 核心 API
::::::::::::::::::

3.2.1 DataBalance 初始化
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. code-block:: python

    from mindspeed_mm.utils.data_balance.data_balance import DataBalance

    data_balancer = DataBalance(
        virtual_pipeline_model_parallel_size=args.virtual_pipeline_model_parallel_size,
        model_config_path=args.model_config_path,
        sorting_algo_name=args.sorting_algo_name,
        len_model=len_model,
        train_data_iterator=train_data_iterator
    )

3.2.2 创建负载均衡数据迭代器
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. code-block:: python

    batch_generator = data_balancer.build_balanced_train_data_iterator(
        is_vit_last_stage=is_vit_last_stage,
        max_batch_capacity=max_batch_capacity,
        micro_batch_size=micro_batch_size,
        num_microbatches=num_microbatches,
        data_type='image'
    )

    # 在训练循环中使用
    for batch in batch_generator:
        # 训练步骤
        ...

4. 最佳实践
----------------

- **算法选择**:默认使用 `post_global_balancing_greedy_without_pad`(后续将支持更多算法)
- **并行配置**:确保图像编码器 DP 度 > 1

5. 性能收益
-----------------

- **训练速度**:计算负载均衡,缩短训练时间,典型场景收益5%+。