From 52ae67da007d2de872fd2669b0f65d0de7b6be97 Mon Sep 17 00:00:00 2001
From: caojingyi <caojingyi@noreply.gitcode.com>
Date: Tue, 18 Nov 2025 11:55:45 +0800
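Subject: [PATCH] Update verl: data rebalance

Introduce a `data_rebalance` rollout option to mitigate long-tail load
imbalance in multi-NPU inference. When the option is enabled (it is treated as
enabled if it is absent from the rollout config), repeat interleaving is
disabled during batch repetition, and a fixed reordering is applied to the
outputs after generation to restore the interleaved order. This reduces
per-NPU latency variance.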

---
 llm_rl/qwen3/verl/trainer/ppo/ray_trainer.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/llm_rl/qwen3/verl/trainer/ppo/ray_trainer.py b/llm_rl/qwen3/verl/trainer/ppo/ray_trainer.py
index 6f50a6e..0abdd34 100644
--- a/llm_rl/qwen3/verl/trainer/ppo/ray_trainer.py
+++ b/llm_rl/qwen3/verl/trainer/ppo/ray_trainer.py
@@ -1037,7 +1037,14 @@ class RayPPOTrainer:
 
                 # pass global_steps to trace
                 gen_batch.meta_info["global_steps"] = self.global_steps
-                gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
+
+                data_rebalance = self.config.actor_rollout_ref.rollout.data_rebalance if hasattr(
+                    self.config.actor_rollout_ref.rollout, 'data_rebalance') else True
+                gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n,
+                                             interleave=not data_rebalance)
+                if data_rebalance:
+                    interleave_indices = torch.arange(gen_batch.batch.batch_size[0]).view(
+                        -1, batch.batch.batch_size[0]).transpose(1, 0).reshape(-1)
 
                 is_last_step = self.global_steps >= self.total_training_steps
                 with marked_timer("step", timing_raw):
@@ -1048,6 +1055,9 @@ class RayPPOTrainer:
                         else:
                             gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch)
 
+                        if data_rebalance:
+                            gen_batch_output.reorder(interleave_indices)
+                        
                         timing_raw.update(gen_batch_output.meta_info["timing"])
                         gen_batch_output.meta_info.pop("timing", None)
 
@@ -1062,6 +1072,10 @@ class RayPPOTrainer:
                                 gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
                             else:
                                 gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch)
+                            
+                            if data_rebalance:
+                                gen_baseline_output.reorder(interleave_indices)
+
                             batch = batch.union(gen_baseline_output)
                             reward_baseline_tensor = self.reward_fn(batch)
                             reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
-- 
2.50.1.windows.1