From bf5d0c83bbbdfbac376bddf0b3b426a4f6cd3035 Mon Sep 17 00:00:00 2001
From: caojingyi <caojingyi2@huawei.com>
Date: Mon, 8 Dec 2025 08:49:09 +0800
Subject: [PATCH] dapo data rebalance

Introduces `data_rebalance` to mitigate long-tail load imbalance in multi-NPU inference.
Disables repeat interleaving during batch repetition, and applies fixed reordering after generation
to reduce per-NPU latency variance.
---
 llm_rl/qwen3/verl/trainer/dapo_ray_trainer.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/llm_rl/qwen3/verl/trainer/dapo_ray_trainer.py b/llm_rl/qwen3/verl/trainer/dapo_ray_trainer.py
index 7ad7372..5420d0a 100644
--- a/llm_rl/qwen3/verl/trainer/dapo_ray_trainer.py
+++ b/llm_rl/qwen3/verl/trainer/dapo_ray_trainer.py
@@ -132,7 +132,14 @@ class RayDAPOTrainer(RayPPOTrainer):
                         batch_keys=["input_ids", "attention_mask", "position_ids"],
                         non_tensor_batch_keys=["raw_prompt_ids"],
                     )
-                gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
+
+                data_rebalance = self.config.actor_rollout_ref.rollout.data_rebalance if hasattr(
+                    self.config.actor_rollout_ref.rollout, 'data_rebalance') else True
+                gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n,
+                                             interleave=not data_rebalance)
+                if data_rebalance:
+                    interleave_indices = torch.arange(gen_batch.batch.batch_size[0]).view(
+                        -1, new_batch.batch.batch_size[0]).transpose(1, 0).reshape(-1)
 
                 is_last_step = self.global_steps >= self.total_training_steps
 
@@ -140,6 +147,8 @@ class RayDAPOTrainer(RayPPOTrainer):
                     # generate a batch
                     with marked_timer("gen", timing_raw, "red"):
                         gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
+                        if data_rebalance:
+                            gen_batch_output.reorder(interleave_indices)
                         timing_raw.update(gen_batch_output.meta_info["timing"])
                         gen_batch_output.meta_info.pop("timing", None)
 
@@ -148,7 +157,8 @@ class RayDAPOTrainer(RayPPOTrainer):
                             gen_baseline_batch = deepcopy(gen_batch)
                             gen_baseline_batch.meta_info["do_sample"] = False
                             gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
-
+                            if data_rebalance:
+                                gen_baseline_output.reorder(interleave_indices)
                             new_batch = new_batch.union(gen_baseline_output)
                             reward_baseline_tensor = self.reward_fn(new_batch)
                             reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
-- 
2.50.1.windows.1