Data balance

1. 背景

多模态模型由于原生分辨率等因素,在分布式训练中,数据负载不均,集群资源利用率下降。

Data Balance 模块旨在解决数据负载不均问题,通过在线数据重排布,实现DP间负载均衡,提升资源利用率。

2. 解决方案

注:当前支持ViT数据负载均衡,LLM部分未使用packing策略,后续将支持LLM部分packing,结合当前ViT部分在线负载均衡方案,提升整网资源利用率。

2.1 关键组件

组件 说明
DataBalance 类 协调数据平衡流程的主类
排序算法 post_global_balancing_greedy_without_pad(后续将支持更多排序算法)
PrefetchMicroBatchIterator 异步预取数据,隐藏延迟
辅助函数 数据分割、映射、重组和通信等功能

2.2 执行流程

flowchart TD
    A[开始训练] --> B[获取原始批次(GBS)]
    B --> C[DP间数据长度获取]
    C --> D[负载均衡数据索引重排]
    D --> E[ranktable mapping]
    E --> F[All2All数据分发]
    F --> G[负载均衡迭代器构建]
    G --> H[训练步骤]
    H --> I[创建负载均衡迭代器]
    I --> J[前反向、梯度更新]
    J --> K{是否下次迭代}
    K -->|是| B
    K -->|否| L[结束训练]

2.2.1 数据分发样例

紫色为初始状态,蓝色为目标状态。通过数据映射路径与All2All通信,实现数据的重排布。

all2all

3. 使用指南

3.1 启用配置

在Qwen 2.5 Omni模型训练命令中添加:

--use-data-balance

3.2 核心 API

3.2.1 DataBalance 初始化

from mindspeed_mm.utils.data_balance.data_balance import DataBalance

data_balancer = DataBalance(
    virtual_pipeline_model_parallel_size=args.virtual_pipeline_model_parallel_size,
    model_config_path=args.model_config_path,
    sorting_algo_name=args.sorting_algo_name,
    len_model=len_model,
    train_data_iterator=train_data_iterator
)

3.2.2 创建负载均衡数据迭代器

batch_generator = data_balancer.build_balanced_train_data_iterator(
    is_vit_last_stage=is_vit_last_stage,
    max_batch_capacity=max_batch_capacity,
    micro_batch_size=micro_batch_size,
    num_microbatches=num_microbatches,
    data_type='image'
)

# 在训练循环中使用
for batch in batch_generator:
    # 训练步骤
    ...

5. 最佳实践

  • 算法选择:默认使用 post_global_balancing_greedy_without_pad(后续将支持更多算法)
  • 并行配置:确保图像编码器 DP 度 > 1

6. 性能收益

  • 训练速度:计算负载均衡,缩短训练时间,典型场景收益5%+。