online-data-balance
=======================
1. 背景
-------------
多模态模型由于原生分辨率等因素,在分布式训练中,数据负载不均,集群资源利用率下降。
Data Balance 模块旨在解决数据负载不均问题,通过在线数据重排布,实现DP间负载均衡,提升资源利用率。
2. 解决方案
-------------
注:当前支持ViT数据负载均衡,LLM部分未使用packing策略,后续将支持LLM部分packing,结合当前ViT部分在线负载均衡方案,提升整网资源利用率。
2.1 关键组件
::::::::::::::::
.. list-table:: 关键组件说明
:widths: 30 70
:header-rows: 1
* - 组件
- 说明
* - **DataBalance 类**
- 协调数据平衡流程的主类
* - **排序算法**
- 如 `post_global_balancing_greedy_without_pad`(后续将支持更多排序算法)
* - **PrefetchMicroBatchIterator**
- 异步预取数据,隐藏延迟
* - **辅助函数**
- 数据分割、映射、重组和通信等功能
2.2 执行流程
:::::::::::::::::
.. mermaid::
flowchart TD
A[开始训练] --> B[获取原始批次(GBS)]
B --> C[DP间数据长度获取]
C --> D[负载均衡数据索引重排]
D --> E[ranktable mapping]
E --> F[All2All数据分发]
F --> G[负载均衡迭代器构建]
G --> H[训练步骤]
H --> I[创建负载均衡迭代器]
I --> J[前反向、梯度更新]
J --> K{是否下次迭代}
K -->|是| B
K -->|否| L[结束训练]
2.2.1 数据分发样例
^^^^^^^^^^^^^^^^^^^^
紫色为初始状态,蓝色为目标状态。通过数据映射路径与All2All通信,实现数据的重排布。
.. image:: ../_static/features/online-data-balance/all2all.png
:alt: all2all
3. 使用指南
---------------
3.1 启用配置
::::::::::::::::
在Qwen 2.5 Omni模型训练命令中添加:
.. code-block:: bash
--use-data-balance
3.2 核心 API
::::::::::::::::::
3.2.1 DataBalance 初始化
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. code-block:: python
from mindspeed_mm.utils.data_balance.data_balance import DataBalance
data_balancer = DataBalance(
virtual_pipeline_model_parallel_size=args.virtual_pipeline_model_parallel_size,
model_config_path=args.model_config_path,
sorting_algo_name=args.sorting_algo_name,
len_model=len_model,
train_data_iterator=train_data_iterator
)
3.2.2 创建负载均衡数据迭代器
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. code-block:: python
batch_generator = data_balancer.build_balanced_train_data_iterator(
is_vit_last_stage=is_vit_last_stage,
max_batch_capacity=max_batch_capacity,
micro_batch_size=micro_batch_size,
num_microbatches=num_microbatches,
data_type='image'
)
# 在训练循环中使用
for batch in batch_generator:
# 训练步骤
...
4. 最佳实践
----------------
- **算法选择**:默认使用 `post_global_balancing_greedy_without_pad`(后续将支持更多算法)
- **并行配置**:确保图像编码器 DP 度 > 1
5. 性能收益
-----------------
- **训练速度**:计算负载均衡,缩短训练时间,典型场景收益5%+。