From bf5d0c83bbbdfbac376bddf0b3b426a4f6cd3035 Mon Sep 17 00:00:00 2001
From: caojingyi <caojingyi2@huawei.com>
Date: Mon, 8 Dec 2025 08:49:09 +0800
Subject: [PATCH] dapo data rebalance
Introduces `data_rebalance` to mitigate long-tail load imbalance in multi-NPU inference.
Disables repeat interleaving during batch repetition, and applies fixed reordering after generation
to reduce per-NPU latency variance.
llm_rl/qwen3/verl/trainer/dapo_ray_trainer.py | 14 ++++++++++++--
1 file changed, 12 insertions(+), 2 deletions(-)
@@ -132,7 +132,14 @@ class RayDAPOTrainer(RayPPOTrainer):
batch_keys=["input_ids", "attention_mask", "position_ids"],
non_tensor_batch_keys=["raw_prompt_ids"],
)
- gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
+
+ data_rebalance = self.config.actor_rollout_ref.rollout.data_rebalance if hasattr(
+ self.config.actor_rollout_ref.rollout, 'data_rebalance') else True
+ gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n,
+ interleave=not data_rebalance)
+ if data_rebalance:
+ interleave_indices = torch.arange(gen_batch.batch.batch_size[0]).view(
+ -1, new_batch.batch.batch_size[0]).transpose(1, 0).reshape(-1)
is_last_step = self.global_steps >= self.total_training_steps
@@ -140,6 +147,8 @@ class RayDAPOTrainer(RayPPOTrainer):
# generate a batch
with marked_timer("gen", timing_raw, "red"):
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
+ if data_rebalance:
+ gen_batch_output.reorder(interleave_indices)
timing_raw.update(gen_batch_output.meta_info["timing"])
gen_batch_output.meta_info.pop("timing", None)
@@ -148,7 +157,8 @@ class RayDAPOTrainer(RayPPOTrainer):
gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info["do_sample"] = False
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
-
+ if data_rebalance:
+ gen_baseline_output.reorder(interleave_indices)
new_batch = new_batch.union(gen_baseline_output)
reward_baseline_tensor = self.reward_fn(new_batch)
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
--
2.50.1.windows.1