Collecting Data for Verifying Data Consistency Between verl Training and Inference Based on Megatron
Overview
Before comparing data consistency between verl training and inference, make sure that the training and inference inputs have the same shape. Otherwise, the precision data dumped during training cannot be matched with the data dumped during inference.
Input Alignment Analysis for verl Training and Inference
By default, the input shapes differ between training and inference because of how each process works:
- Inference runs in two stages: a prefill stage followed by k decode steps.
  - In the prefill stage, the input is the prompt.
  - In each decode step, the input is the token generated by the previous step, and the model attends to the KV cache built up so far. The tokens produced over the k steps form the final response.
- During training, the input is the prompt plus the response generated by inference, and the output is logits.
In short, the inference input (in the prefill stage) is the prompt alone, whereas the training input is the prompt plus the response. Therefore, the training input must be reduced to the prompt only so that it matches the inference input.
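The shape alignment can be illustrated with a minimal sketch in plain Python. It assumes that every training row holds the prompt tokens followed by a response of fixed width (padding included), which is the same assumption the input_ids slicing in megatron_actor.py makes. The helper name strip_response and the token IDs are hypothetical.

```python
from typing import List


def strip_response(input_ids: List[List[int]], response_length: int) -> List[List[int]]:
    """Drop the trailing response tokens from every row of a training batch.

    Hypothetical helper for illustration only: each row is assumed to hold the
    prompt tokens followed by exactly `response_length` response tokens.
    """
    if response_length < 0:
        raise ValueError("response_length must be non-negative")
    if response_length == 0:
        # row[:-0] would be an empty list, so return the rows unchanged instead.
        return [list(row) for row in input_ids]
    return [row[:-response_length] for row in input_ids]


# Prefill input seen by inference: two prompts of 3 tokens each (toy token IDs).
prompts = [[101, 7, 8], [101, 9, 10]]
# Responses generated by inference: 2 tokens each.
responses = [[42, 2], [43, 2]]
# Training input: prompt followed by response, 3 + 2 = 5 tokens per row.
train_input = [p + r for p, r in zip(prompts, responses)]

trimmed = strip_response(train_input, response_length=len(responses[0]))
print(len(train_input[0]), len(trimmed[0]))  # 5 3
print(trimmed == prompts)  # True
```

The explicit zero check matters because `row[:-0]` in Python, like `tensor[:, :-0]` in PyTorch, yields an empty slice rather than the full row.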
Preprocessing Operations
To make the training forward shape match the inference prefill shape, the response must be removed from the training input. This requires the following prerequisites, which are set in the training script:
- Ensure that the training batch is not split:
  - The number of mini batches used for gradient updates in each training epoch must be 1:
    mini_batch_num = train_batch_size / train_ppo_mini_batch_size = 1
    where train_batch_size is the total number of samples in a training batch, and train_ppo_mini_batch_size is the number of samples in each mini batch.
  - The number of gradient accumulation steps (gac) must be 1:
    gac = train_ppo_mini_batch_size * n_resp_per_prompt / train_ppo_micro_batch_size_per_gpu / DP = 1
    where n_resp_per_prompt is the number of responses generated per prompt, train_ppo_micro_batch_size_per_gpu is the micro batch size processed on each GPU, and DP is the data parallelism degree:
    DP = world_size / TP / PP / CP
    (TP, PP, and CP are the tensor, pipeline, and context parallelism degrees.)
  Set these parameters in the training script as follows:
    data.train_batch_size=${train_batch_size} \
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_ppo_mini_batch_size} \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_ppo_micro_batch_size_per_gpu} \
    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
- Ensure that no padding is used during training by enabling use_remove_padding, which removes padding tokens to optimize computation:
    actor_rollout_ref.model.use_remove_padding=True \
- Set the following environment variables in the training script. DUMP_ON=1 enables the msProbe dump, and PROMPTS_ONLY=1 makes the actor compute log probabilities on the prompt only:
    export DUMP_ON=1
    export PROMPTS_ONLY=1
- Ensure that the data collected during training and inference correspond one-to-one on each rank. balance_batch reorders batch data to balance the load across ranks, which breaks this correspondence, so disable it:
    trainer.balance_batch=False \
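The two batch-size conditions can be checked before launching a job. The following is a minimal sketch that evaluates the mini_batch_num, gac, and DP formulas above; the function names and the example numbers are hypothetical, and the parameter names simply mirror the script variables.

```python
from typing import Tuple


def data_parallel_size(world_size: int, tp: int = 1, pp: int = 1, cp: int = 1) -> int:
    """DP = world_size / TP / PP / CP."""
    model_parallel = tp * pp * cp
    if world_size % model_parallel != 0:
        raise ValueError("world_size must be divisible by TP * PP * CP")
    return world_size // model_parallel


def batch_split_factors(
    train_batch_size: int,
    train_ppo_mini_batch_size: int,
    n_resp_per_prompt: int,
    train_ppo_micro_batch_size_per_gpu: int,
    world_size: int,
    tp: int = 1,
    pp: int = 1,
    cp: int = 1,
) -> Tuple[float, float]:
    """Return (mini_batch_num, gac); both must be 1 for the comparison setup."""
    dp = data_parallel_size(world_size, tp, pp, cp)
    mini_batch_num = train_batch_size / train_ppo_mini_batch_size
    gac = train_ppo_mini_batch_size * n_resp_per_prompt / train_ppo_micro_batch_size_per_gpu / dp
    return mini_batch_num, gac


# Example: 8 devices with TP=2 give DP=4; 4 prompts, 1 response each, micro batch 1.
print(batch_split_factors(4, 4, 1, 1, world_size=8, tp=2))  # (1.0, 1.0): batch is not split
# Doubling the responses per prompt doubles gac, so the batch would be split.
print(batch_split_factors(4, 4, 2, 1, world_size=8, tp=2))  # (1.0, 2.0)
```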
Modification to verl Code
To remove the response from the training input, modify verl/workers/actor/megatron_actor.py, verl/utils/debug/metrics.py, and verl/trainer/ppo/rollout_corr_helper.py. The following examples are based on release/v0.6.1. Added lines are marked with +, removed lines with -, and ... marks unchanged code that is omitted.
verl/workers/actor/megatron_actor.py
 ...
     def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Tensor:
         """..."""
         ...
+        # Check whether log_probs are computed only for the prompt (excluding the response).
+        compute_prompts_only = int(os.getenv("PROMPTS_ONLY", "0"))
+        if compute_prompts_only:
+            # Remove the response part from input_ids, attention_mask, and position_ids.
+            if "responses" in data.batch:
+                response_length = data.batch["responses"].size(1)
+                data.batch["input_ids"] = data.batch["input_ids"][:, :-response_length]
+                data.batch["attention_mask"] = data.batch["attention_mask"][:, :-response_length]
+                if data.batch["position_ids"].dim() == 3:  # qwen2vl mrope
+                    data.batch["position_ids"] = data.batch["position_ids"][:, :, :-response_length]
+                else:
+                    data.batch["position_ids"] = data.batch["position_ids"][:, :-response_length]
+            # Remove the response-related tensors from the batch.
+            data.batch.pop("responses", None)
+            if "rollout_log_probs" in data.batch:
+                data.batch.pop("rollout_log_probs", None)
+            if "response_mask" in data.batch:
+                data.batch.pop("response_mask", None)
         def compute_logprobs_fn(output, data, use_dynamic_bsz=False, indices=None):
-            response = data["responses"]
-            response_length = response.size(1)
-            log_probs = output["log_probs"][:, -response_length - 1 : -1].contiguous()
+            if "responses" in data and data["responses"] is not None:
+                response = data["responses"]
+                response_length = response.size(1)
+                log_probs = output["log_probs"][:, -response_length - 1 : -1].contiguous()
+            else:
+                # Prompt-only mode: return the log_probs of all prompt positions except the last one,
+                # which has no next token to predict.
+                log_probs = output["log_probs"][:, :-1].contiguous()
             return {"log_probs": log_probs}
         ...
         if recompute_old_log_prob:
-            select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
+            # By default, recompute_old_log_prob is used.
+            select_keys = ["input_ids", "attention_mask", "position_ids"]
+            if "responses" in data.batch:
+                select_keys.append("responses")
             batch = data.select(batch_keys=select_keys).batch
             input_ids = batch["input_ids"]
             batch_size = input_ids.size(0)
-            response = batch["responses"]
-            response_length = response.size(1)
+            if "responses" in batch and batch["responses"] is not None:
+                response = batch["responses"]
+                response_length = response.size(1)
+            else:
+                response = None
+                response_length = 0
             with torch.no_grad():
                 ...
-                    log_probs = torch.empty(
-                        size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device
-                    )
+                    # Prompt-only mode: log_probs has shape [batch_size, prompt_length - 1]
+                    # (the last position has no next token to predict).
+                    if response_length > 0:
+                        log_probs_shape = (batch_size, response_length)
+                    else:
+                        prompt_length = input_ids.size(1)
+                        log_probs_shape = (batch_size, prompt_length - 1) if prompt_length > 1 else (batch_size, 0)
+                    log_probs = torch.empty(
+                        size=log_probs_shape, dtype=torch.float32, device=input_ids.device
+                    )
                 ...
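-                        entropys = torch.empty(
-                            size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device
-                        )
+                        # Keep this shape consistent with log_probs_shape above: in prompt-only mode the
+                        # entropy also covers every prompt position except the last one.
+                        if response_length > 0:
+                            entropy_shape = (batch_size, response_length)
+                        else:
+                            prompt_length = input_ids.size(1)
+                            entropy_shape = (batch_size, prompt_length - 1) if prompt_length > 1 else (batch_size, 0)
+                        entropys = torch.empty(
+                            size=entropy_shape, dtype=torch.float32, device=input_ids.device
+                        )
         ...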
     def forward_backward_batch(...):
         """..."""
         ...
         def loss_func(output, data, meta_info):
             ...
-            responses = data["responses"]
-            response_length = responses.size(1)
-            response_mask = data["response_mask"].to(bool)
-            loss_agg_mode = self.config.loss_agg_mode
-            # compute policy loss
-            log_prob = log_probs[:, -response_length - 1 : -1].contiguous()
+            # Check whether there is a response or only a prompt.
+            if "responses" in data and data["responses"] is not None:
+                responses = data["responses"]
+                response_length = responses.size(1)
+                response_mask = data["response_mask"].to(bool)
+                # Calculate policy loss.
+                log_prob = log_probs[:, -response_length - 1 : -1].contiguous()
+            else:
+                # Prompt-only mode: use all log_probs except the last position.
+                response_length = 0
+                response_mask = None
+                log_prob = log_probs[:, :-1].contiguous() if log_probs.size(1) > 0 else log_probs
+            loss_agg_mode = self.config.loss_agg_mode
             ...
                 rollout_is_weights = data.get("rollout_is_weights", None)
+                # Prompt-only mode: log_prob already excludes the last position, so mask every remaining position.
+                if response_mask is None:
+                    prompt_mask = torch.ones_like(log_prob, dtype=torch.bool)
+                    response_mask = prompt_mask
             ...
                 from verl.trainer.ppo.rollout_corr_helper import compute_rollout_corr_metrics_from_logprobs
+                # For prompts: Use the same mask as above.
+                if response_mask is None:
+                    prompt_mask = torch.ones_like(log_prob, dtype=torch.bool)
+                    response_mask = prompt_mask
             ...
             if calculate_entropy:
-                entropy = output["entropy"][:, -response_length - 1 : -1].contiguous()
-                if not forward_only:
-                    entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
+                if response_length > 0:
+                    entropy = output["entropy"][:, -response_length - 1 : -1].contiguous()
+                else:
+                    # Prompt-only mode: use the entropy of all positions except the last one.
+                    entropy = output["entropy"][:, :-1].contiguous() if output["entropy"].size(1) > 0 else output["entropy"]
+                if not forward_only:
+                    if response_mask is not None:
+                        entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
+                    else:
+                        entropy_loss = agg_loss(loss_mat=entropy, loss_mask=None, loss_agg_mode=loss_agg_mode)
             ...
                 kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type)
-                kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode)
+                if response_mask is not None:
+                    kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode)
+                else:
+                    kl_loss = agg_loss(loss_mat=kld, loss_mask=None, loss_agg_mode=self.config.loss_agg_mode)
         ...
         def forward_step(batch_iter, model, return_schedule_plan: bool = False):
             ...
-            responses = batch["responses"]
-            response_length = responses.size(1)
-            label = position_ids.clone()
-            label[:, -response_length - 1 : -1] = responses
-            label_mask = attention_mask.clone()
-            label_mask[:, : -response_length - 1] = False
-            label_mask[:, -1] = False
+            # Check whether there is a response or only a prompt.
+            if "responses" in batch and batch["responses"] is not None:
+                responses = batch["responses"]
+                response_length = responses.size(1)
+                label = position_ids.clone()
+                label[:, -response_length - 1 : -1] = responses
+                label_mask = attention_mask.clone()
+                label_mask[:, : -response_length - 1] = False
+                label_mask[:, -1] = False
+            else:
+                # Prompt-only mode: compute the log_probs of all prompt tokens (next-token prediction).
+                # The label at position t is the input token at position t + 1.
+                response_length = 0
+                label = input_ids.clone()
+                label[:, :-1] = input_ids[:, 1:]
+                label_mask = attention_mask.clone()
+                # The last position has no next token to predict, so exclude it.
+                if label_mask.size(1) > 0:
+                    label_mask[:, -1] = False
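The prompt-only branch of forward_step is meant to follow the same next-token convention as the response branch: position t is scored against the token at position t + 1. The following minimal sketch checks that convention on a single unpadded sequence (every attention-mask entry is assumed to be 1) using plain Python lists; both helper names are hypothetical and only mirror the tensor logic.

```python
from typing import List, Tuple


def response_mode_labels(input_ids: List[int], response_length: int) -> Tuple[List[int], List[bool]]:
    """Mirror the response branch: label[-R-1:-1] = responses, and only those positions are scored."""
    seq_len = len(input_ids)
    responses = input_ids[seq_len - response_length:]
    label = [0] * seq_len  # placeholder values at positions that stay masked
    mask = [False] * seq_len
    for i, token in enumerate(responses):
        pos = seq_len - response_length - 1 + i
        label[pos] = token
        mask[pos] = True
    return label, mask


def prompt_only_labels(input_ids: List[int]) -> Tuple[List[int], List[bool]]:
    """Next-token labels for every position: label[t] = input_ids[t + 1]; the last position is masked."""
    if not input_ids:
        return [], []
    label = input_ids[1:] + [0]  # the final slot is a placeholder and is masked out
    mask = [True] * (len(input_ids) - 1) + [False]
    return label, mask


sequence = [11, 12, 13, 14, 15]  # 3 prompt tokens followed by 2 response tokens
r_label, r_mask = response_mode_labels(sequence, response_length=2)
p_label, p_mask = prompt_only_labels(sequence)

# Wherever the response branch scores a position, the shifted labels agree with it.
agree = all(p_label[t] == r_label[t] for t in range(len(sequence)) if r_mask[t])
print(r_mask)  # [False, False, True, True, False]
print(p_mask)  # [True, True, True, True, False]
print(agree)   # True
```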
verl/utils/debug/metrics.py
 ...
 def calculate_debug_metrics(data: DataProto) -> dict:
     """..."""
+    if "rollout_log_probs" not in data.batch:
+        logger.warning("rollout_log_probs not found in batch, skipping debug metrics calculation")
+        return {
+            "training/rollout_probs_diff_valid": 0,
+            "training/rollout_probs_diff_max": 0.0,
+            "training/rollout_probs_diff_mean": 0.0,
+            "training/rollout_probs_diff_std": 0.0,
+            "training/rollout_actor_probs_pearson_corr": 0.0,
+        }
+
+    if "old_log_probs" not in data.batch:
+        logger.warning("old_log_probs not found in batch, skipping debug metrics calculation")
+        return {
+            "training/rollout_probs_diff_valid": 0,
+            "training/rollout_probs_diff_max": 0.0,
+            "training/rollout_probs_diff_mean": 0.0,
+            "training/rollout_probs_diff_std": 0.0,
+            "training/rollout_actor_probs_pearson_corr": 0.0,
+        }
+
+    if "responses" not in data.batch:
+        logger.warning(
+            "responses not found in batch (possibly compute_prompts_only mode), skipping debug metrics calculation"
+        )
+        return {
+            "training/rollout_probs_diff_valid": 0,
+            "training/rollout_probs_diff_max": 0.0,
+            "training/rollout_probs_diff_mean": 0.0,
+            "training/rollout_probs_diff_std": 0.0,
+            "training/rollout_actor_probs_pearson_corr": 0.0,
+        }
+
     rollout_old_log_probs = data.batch["rollout_log_probs"]
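The three guard clauses in calculate_debug_metrics return the same placeholder dictionary. A minimal sketch of factoring them into one check is shown below; guard_debug_metrics, EMPTY_DEBUG_METRICS, and REQUIRED_KEYS are hypothetical names, and a plain dict stands in for data.batch (only the `in` operator is needed).

```python
import logging
from typing import Container, Dict, Optional

logger = logging.getLogger(__name__)

# Placeholder metrics returned when the rollout-vs-actor comparison cannot run.
EMPTY_DEBUG_METRICS: Dict[str, float] = {
    "training/rollout_probs_diff_valid": 0,
    "training/rollout_probs_diff_max": 0.0,
    "training/rollout_probs_diff_mean": 0.0,
    "training/rollout_probs_diff_std": 0.0,
    "training/rollout_actor_probs_pearson_corr": 0.0,
}

# Keys checked in the same order as the original guard clauses.
REQUIRED_KEYS = ("rollout_log_probs", "old_log_probs", "responses")


def guard_debug_metrics(batch: Container[str]) -> Optional[Dict[str, float]]:
    """Return a copy of the placeholder metrics if a required key is missing, else None."""
    for key in REQUIRED_KEYS:
        if key not in batch:
            logger.warning("%s not found in batch, skipping debug metrics calculation", key)
            # Return a copy so callers cannot mutate the shared defaults.
            return dict(EMPTY_DEBUG_METRICS)
    return None


# A batch without "responses" gets the placeholder metrics.
print(guard_debug_metrics({"rollout_log_probs": [], "old_log_probs": []}) == EMPTY_DEBUG_METRICS)  # True
print(guard_debug_metrics({"rollout_log_probs": [], "old_log_probs": [], "responses": []}))  # None
```

Inside calculate_debug_metrics, the guards would then reduce to calling `guard_debug_metrics(data.batch)` and returning its result when it is not None.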
verl/trainer/ppo/rollout_corr_helper.py
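 ...
 def compute_rollout_correction_and_add_to_batch(
     batch: DataProto, rollout_corr_config: RolloutCorrectionConfig
 ) -> tuple[DataProto, dict]:
     """..."""
+    # Prompt-only mode: skip rollout correction, because old_log_probs then covers the prompt
+    # rather than the response. This check requires os to be imported in this module.
+    if int(os.getenv("PROMPTS_ONLY", "0")):
+        return batch, {}
     rollout_is = rollout_corr_config.get("rollout_is", None)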
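The PROMPTS_ONLY and DUMP_ON checks parse the variable with int(...), which raises ValueError for values such as "true". If a more forgiving switch is wanted, a hypothetical helper like env_flag below could replace those calls; it is a sketch, not part of verl or msProbe, and it is stricter than int() in one respect: values such as "2" are rejected instead of being treated as enabled.

```python
import os


def env_flag(name: str, default: bool = False) -> bool:
    """Parse a boolean environment variable such as PROMPTS_ONLY or DUMP_ON.

    Hypothetical helper: accepts 1/0, true/false, yes/no, on/off (case-insensitive);
    an unset or empty variable returns `default`.
    """
    raw = os.getenv(name)
    if raw is None or raw.strip() == "":
        return default
    value = raw.strip().lower()
    if value in ("1", "true", "yes", "on"):
        return True
    if value in ("0", "false", "no", "off"):
        return False
    raise ValueError(f"Unrecognized value for {name}: {raw!r}")


# With PROMPTS_ONLY=true exported, env_flag("PROMPTS_ONLY") returns True,
# whereas int(os.getenv("PROMPTS_ONLY", "0")) raises ValueError.
```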
Data Collection
Call the PrecisionDebugger API of the msProbe tool in verl/workers/megatron_workers.py to dump data. For details about PrecisionDebugger, see Precision Data Collection in PyTorch.
Modify the code as highlighted in the following example (added lines are marked with +):
 ...
 class ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension):
     """..."""
     def __init__(self, config: DictConfig, role: str, **kwargs):
         ...
         self._ref_is_offload_param = self.config.ref.megatron.get("param_offload", False)
+        # Modification to __init__
+        # Instantiate PrecisionDebugger.
+        # Set the environment variable DUMP_ON to quickly enable or disable the dump function.
+        dump_flag = int(os.environ.get("DUMP_ON", 0))
+        if dump_flag:
+            from msprobe.pytorch import PrecisionDebugger, seed_all
+            seed_all(mode=True)
+            self.debugger = PrecisionDebugger(task='tensor', level='L0', step=[0], dump_path='0_dump_path/')
+            self.dump_path_prefix = self.debugger.config.dump_path
+        else:
+            self.debugger = None
         ...
     @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout"))
     @GPUMemoryLogger(role="generate_sequences", logger=logger)
     @DistProfiler.annotate(color="red")
     def generate_sequences(self, prompts: DataProto):
         ...
         with simple_timer("generate_sequences", timing_generate):
+            # Enable the tool's dump function in generate_sequences to collect forward inference data.
+            if self.debugger:
+                infer_model = self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.get_model()
+                # The forward dump data in the inference phase is stored in the generate_sequences folder.
+                self.debugger.service.config.dump_path = os.path.join(self.dump_path_prefix, 'generate_sequences')
+                self.debugger.start(model=infer_model, token_range=[0, 0])
+            output = self.rollout.generate_sequences(prompts=prompts)
+            if self.debugger:
+                self.debugger.stop()
+                self.debugger.service._reset_status()
         ...
     @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
     @GPUMemoryLogger(role="compute_log_prob", logger=logger)
     @DistProfiler.annotate(color="blue")
     def compute_log_prob(self, data: DataProto):
         ...
         data.meta_info["temperature"] = self.config.rollout.temperature
+        # Enable the tool's dump function in compute_log_prob to collect the module-level input and output data in the training phase.
+        if self.debugger:
+            # The dump data in the training phase is stored in the update_actor folder.
+            self.debugger.service.config.dump_path = os.path.join(self.dump_path_prefix, 'update_actor')
+            self.debugger.start(model=self.actor.actor_module)
+        output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True)
+        if self.debugger:
+            self.debugger.stop()
+            self.debugger.step()
         ...
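In both worker methods, debugger.stop() is skipped if the wrapped call raises, which can leave the dump running. A minimal sketch of a context manager that always stops the dump is shown below. dump_scope is a hypothetical helper that relies only on the attributes the worker code already uses (service.config.dump_path, start(), stop()); RecordingDebugger is a test stand-in, not part of msProbe.

```python
import os
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def dump_scope(debugger, dump_path_prefix, subdir, **start_kwargs):
    """Dump into <dump_path_prefix>/<subdir> for the duration of the block.

    Hypothetical helper: pass None as `debugger` to disable dumping.
    stop() runs even if the wrapped call raises.
    """
    if debugger is None:
        yield
        return
    debugger.service.config.dump_path = os.path.join(dump_path_prefix, subdir)
    debugger.start(**start_kwargs)
    try:
        yield
    finally:
        debugger.stop()


class RecordingDebugger:
    """Test stand-in that records calls; it is not part of msProbe."""

    def __init__(self):
        self.service = SimpleNamespace(config=SimpleNamespace(dump_path=""))
        self.calls = []

    def start(self, **kwargs):
        self.calls.append(("start", kwargs))

    def stop(self):
        self.calls.append(("stop",))


dbg = RecordingDebugger()
try:
    with dump_scope(dbg, "0_dump_path/", "update_actor", model="actor_module"):
        raise RuntimeError("simulated failure in compute_log_prob")
except RuntimeError:
    pass
print(dbg.calls)  # [('start', {'model': 'actor_module'}), ('stop',)]
```

In compute_log_prob, the call would become `with dump_scope(self.debugger, self.dump_path_prefix, "update_actor", model=self.actor.actor_module):` around `self.actor.compute_log_prob(...)`, followed by `self.debugger.step()` when the debugger is enabled. The generate_sequences path also calls `service._reset_status()` after `stop()`, which this sketch leaves to the caller.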