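Data Collection for verl Training-Inference Consistency Comparison (Megatron Training Backend)
Introduction
In the verl training-inference consistency comparison scenario, training and inference must receive inputs of the same shape before the comparison is run. Otherwise the precision data dumped from training cannot be matched against the data dumped from inference.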
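Analysis of verl training and inference input alignment
In general, training and inference receive inputs of different shapes. The reason follows from how each of them runs:
- Inference runs in two stages, a prefill step followed by k decode steps:
  - In the prefill step, the inference input is the prompt.
  - In each of the k decode steps, the input is the KV cache plus the token produced by the previous step; together these steps produce the inference response.
- Training takes the prompt plus the response produced by inference as input, and outputs logits.
Taken together: the inference (prefill) input is the prompt, whereas the training input is the prompt plus the inference response.
Conclusion: the training input must be reduced to the prompt alone so that it matches the inference input.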
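The mismatch and the fix can be illustrated with plain Python lists standing in for token-ID tensors. This is a minimal sketch under stated assumptions: the helper names are hypothetical, each training row is the prompt followed by the response, and all rows share one response length (as with a padded `responses` tensor). Dropping the trailing response tokens restores the prefill shape.

```python
def prefill_and_training_shapes(prompt_len, response_len, batch_size):
    """Return the (batch, seq_len) shapes of the prefill input and the training input."""
    prefill = (batch_size, prompt_len)                   # prefill consumes the prompt only
    training = (batch_size, prompt_len + response_len)   # training consumes prompt + response
    return prefill, training


def strip_response(input_ids, response_length):
    """Hypothetical helper: drop the trailing response tokens from every row,
    the list analogue of input_ids[:, :-response_length] on a 2-D tensor."""
    if response_length <= 0:
        # Guard: row[:-0] is row[:0], i.e. an empty list, not the full row.
        return [list(row) for row in input_ids]
    return [list(row[:-response_length]) for row in input_ids]


# Two prompts of length 3, each followed by a response of length 2.
train_input = [[11, 12, 13, 21, 22],
               [14, 15, 16, 23, 24]]
prefill, training = prefill_and_training_shapes(prompt_len=3, response_len=2, batch_size=2)
prompts_only = strip_response(train_input, response_length=2)
print(prefill, training, (len(prompts_only), len(prompts_only[0])))  # (2, 3) (2, 5) (2, 3)
```

The `response_length <= 0` guard matters because a negative-zero slice end selects nothing in Python, so an unguarded `row[:-0]` would silently return an empty row.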
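Prerequisites
To make the training forward pass and the inference prefill see inputs of the same shape, the response must be removed from the training input. This first requires the following preconditions, which are met by modifying the training script:
- The batch size dimension must not be split during training.
- The number of mini batches used for gradient updates in each training step must be 1 (mini_batch_num = 1).
  Formula: mini_batch_num = train_batch_size / train_ppo_mini_batch_size
  - train_batch_size: total number of samples in one training step.
  - train_ppo_mini_batch_size: number of samples in each mini batch.
- The number of gradient accumulation steps must be 1 (gac = 1).
  Formula: gac = train_ppo_mini_batch_size * n_resp_per_prompt / train_ppo_micro_batch_size_per_gpu / DP
  - train_ppo_mini_batch_size: number of samples in each mini batch.
  - n_resp_per_prompt: number of responses generated for each prompt.
  - train_ppo_micro_batch_size_per_gpu: micro batch size processed on each GPU.
  - DP: data parallel size, DP = world_size / TP / PP / CP.
  In the script, these parameters are set as follows:
  data.train_batch_size=${train_batch_size} actor_rollout_ref.actor.ppo_mini_batch_size=${train_ppo_mini_batch_size} actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_ppo_micro_batch_size_per_gpu} actor_rollout_ref.rollout.n=${n_resp_per_prompt}
- Training must contain no padding. Enable use_remove_padding (the remove-padding optimization):
  actor_rollout_ref.model.use_remove_padding=True
- Set the following environment variables in the training script (DUMP_ON switches the msProbe dump on; PROMPTS_ONLY makes the patched training code drop the response):
  export DUMP_ON=1
  export PROMPTS_ONLY=1
- The data collected by training and by inference must correspond one-to-one on each card. Disable balance_batch, which automatically rebalances and evenly redistributes batch data:
  trainer.balance_batch=False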
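The two formulas above can be checked before launching a run. This is a minimal sketch assuming the parameter names used in this section; `check_alignment_config` and `data_parallel_size` are hypothetical helpers, not part of verl or msProbe.

```python
def data_parallel_size(world_size, tp=1, pp=1, cp=1):
    """DP = world_size / TP / PP / CP."""
    model_parallel = tp * pp * cp
    if world_size % model_parallel != 0:
        raise ValueError("world_size must be divisible by TP * PP * CP")
    return world_size // model_parallel


def check_alignment_config(train_batch_size, train_ppo_mini_batch_size,
                           n_resp_per_prompt, train_ppo_micro_batch_size_per_gpu,
                           world_size, tp=1, pp=1, cp=1):
    """Return (mini_batch_num, gac); raise ValueError unless both equal 1."""
    dp = data_parallel_size(world_size, tp, pp, cp)
    mini_batch_num = train_batch_size / train_ppo_mini_batch_size
    gac = (train_ppo_mini_batch_size * n_resp_per_prompt
           / train_ppo_micro_batch_size_per_gpu / dp)
    if mini_batch_num != 1:
        raise ValueError(f"mini_batch_num = {mini_batch_num}, expected 1")
    if gac != 1:
        raise ValueError(f"gac = {gac}, expected 1")
    return mini_batch_num, gac


# 8 cards with TP=2 give DP=4; 4 prompts x 2 responses = 8 samples, 2 per GPU.
print(check_alignment_config(train_batch_size=4, train_ppo_mini_batch_size=4,
                             n_resp_per_prompt=2, train_ppo_micro_batch_size_per_gpu=2,
                             world_size=8, tp=2))  # (1.0, 1.0)
```

In practice both conditions usually come down to choosing train_ppo_mini_batch_size = train_batch_size and train_ppo_micro_batch_size_per_gpu = train_ppo_mini_batch_size * n_resp_per_prompt / DP.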
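verl code modifications
To remove the response from the training input, modify verl/workers/actor/megatron_actor.py, verl/utils/debug/metrics.py and verl/trainer/ppo/rollout_corr_helper.py. Using release/v0.6.1 as an example, the changes are shown below (lines marked + are added, lines marked - are removed). The patches read PROMPTS_ONLY with os.getenv, so add import os to any of these files that does not already import it.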
verl/workers/actor/megatron_actor.py
...
...
def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Tensor:
"""..."""
...
+ # Check whether log_probs should be computed for the prompt only (excluding the response)
+ compute_prompts_only = int(os.getenv("PROMPTS_ONLY", "0"))
+ if compute_prompts_only:
+ # Remove the response part from input_ids, attention_mask and position_ids
+ if "responses" in data.batch:
+ response_length = data.batch["responses"].size(1)
+ data.batch["input_ids"] = data.batch["input_ids"][:, :-response_length]
+ data.batch["attention_mask"] = data.batch["attention_mask"][:, :-response_length]
+ if data.batch["position_ids"].dim() == 3: # qwen2vl mrope
+ data.batch["position_ids"] = data.batch["position_ids"][:, :, :-response_length]
+ else:
+ data.batch["position_ids"] = data.batch["position_ids"][:, :-response_length]
+ # Remove the response entries from the batch
+ data.batch.pop("responses", None)
+ if "rollout_log_probs" in data.batch:
+ data.batch.pop("rollout_log_probs", None)
+ if "response_mask" in data.batch:
+ data.batch.pop("response_mask", None)
def compute_logprobs_fn(output, data, use_dynamic_bsz=False, indices=None):
- response = data["responses"]
- response_length = response.size(1)
- log_probs = output["log_probs"][:, -response_length - 1 : -1].contiguous()
+ if "responses" in data and data["responses"] is not None:
+ response = data["responses"]
+ response_length = response.size(1)
+ log_probs = output["log_probs"][:, -response_length - 1 : -1].contiguous()
+ else:
+ # Prompt only: return log_probs for all prompt tokens except the last one (which has no next token to predict)
+ log_probs = output["log_probs"][:, :-1].contiguous()
return {"log_probs": log_probs}
...
if recompute_old_log_prob:
- select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
+ # recompute_old_log_prob is assumed to be enabled here.
+ select_keys = ["input_ids", "attention_mask", "position_ids"]
+ if "responses" in data.batch:
+ select_keys.append("responses")
batch = data.select(batch_keys=select_keys).batch
input_ids = batch["input_ids"]
batch_size = input_ids.size(0)
- response = batch["responses"]
- response_length = response.size(1)
+ if "responses" in batch and batch["responses"] is not None:
+ response = batch["responses"]
+ response_length = response.size(1)
+ else:
+ response = None
+ response_length = 0
with torch.no_grad():
...
- log_probs = torch.empty(
- size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device
- )
+ # Prompt only: log_probs has shape [batch_size, prompt_length - 1] (the last token, which has no next token to predict, is excluded)
+ if response_length > 0:
+ log_probs_shape = (batch_size, response_length)
+ else:
+ prompt_length = input_ids.size(1)
+ log_probs_shape = (batch_size, prompt_length - 1) if prompt_length > 1 else (batch_size, 0)
+ log_probs = torch.empty(
+ size=log_probs_shape, dtype=torch.float32, device=input_ids.device
+ )
...
- entropys = torch.empty(
- size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device
- )
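+ if response_length > 0:
+ entropy_shape = (batch_size, response_length)
+ else:
+ # Prompt only: match the log_probs shape [batch_size, prompt_length - 1]
+ prompt_length = input_ids.size(1)
+ entropy_shape = (batch_size, prompt_length - 1) if prompt_length > 1 else (batch_size, 0)
+ entropys = torch.empty(
+ size=entropy_shape, dtype=torch.float32, device=input_ids.device
+ )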
...
def forward_backward_batch(...):
"""..."""
...
def loss_func(output, data, meta_info):
...
- responses = data["responses"]
- response_length = responses.size(1)
- response_mask = data["response_mask"].to(bool)
- loss_agg_mode = self.config.loss_agg_mode
- # compute policy loss
- log_prob = log_probs[:, -response_length - 1 : -1].contiguous()
+ # Check whether the batch has responses or only prompts
+ if "responses" in data and data["responses"] is not None:
+ responses = data["responses"]
+ response_length = responses.size(1)
+ response_mask = data["response_mask"].to(bool)
+ # Compute the policy loss
+ log_prob = log_probs[:, -response_length - 1 : -1].contiguous()
+ else:
+ # Prompt only: use log_probs for every token except the last
+ response_length = 0
+ response_mask = None
+ log_prob = log_probs[:, :-1].contiguous() if log_probs.size(1) > 0 else log_probs
+ loss_agg_mode = self.config.loss_agg_mode
...
rollout_is_weights = data.get("rollout_is_weights", None)
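+ # Prompt only: create a mask for all prompt tokens (excluding the last one)
+ if response_mask is None:
+ # Mask every prompt token except the last one (log_prob already excludes it)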
+ prompt_mask = torch.ones_like(log_prob, dtype=torch.bool)
+ response_mask = prompt_mask
...
from verl.trainer.ppo.rollout_corr_helper import compute_rollout_corr_metrics_from_logprobs
+ # Prompt only: use the same mask as above
+ if response_mask is None:
+ prompt_mask = torch.ones_like(log_prob, dtype=torch.bool)
+ response_mask = prompt_mask
...
if calculate_entropy:
- entropy = output["entropy"][:, -response_length - 1 : -1].contiguous()
- if not forward_only:
- entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
+ if response_length > 0:
+ entropy = output["entropy"][:, -response_length - 1 : -1].contiguous()
+ else:
+ # Prompt only: use the entropy of every token except the last
+ entropy = output["entropy"][:, :-1].contiguous() if output["entropy"].size(1) > 0 else output["entropy"]
+ if not forward_only:
+ if response_mask is not None:
+ entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
+ else:
+ entropy_loss = agg_loss(loss_mat=entropy, loss_mask=None, loss_agg_mode=loss_agg_mode)
...
kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type)
- kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode)
+ if response_mask is not None:
+ kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode)
+ else:
+ kl_loss = agg_loss(loss_mat=kld, loss_mask=None, loss_agg_mode=self.config.loss_agg_mode)
...
def forward_step(batch_iter, model, return_schedule_plan: bool = False):
...
- responses = batch["responses"]
- response_length = responses.size(1)
- label = position_ids.clone()
- label[:, -response_length - 1 : -1] = responses
- label_mask = attention_mask.clone()
- label_mask[:, : -response_length - 1] = False
- label_mask[:, -1] = False
+ # Check whether the batch has responses or only prompts
+ if "responses" in batch and batch["responses"] is not None:
+ responses = batch["responses"]
+ response_length = responses.size(1)
+ label = position_ids.clone()
+ label[:, -response_length - 1 : -1] = responses
+ label_mask = attention_mask.clone()
+ label_mask[:, : -response_length - 1] = False
+ label_mask[:, -1] = False
+ else:
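+ # Prompt only: compute log_probs for all prompt tokens (next-token prediction)
+ # The label at position i is the token at position i + 1, i.e. input_ids shifted left by one
+ response_length = 0
+ label = input_ids.clone()
+ label[:, :-1] = input_ids[:, 1:]
+ label_mask = attention_mask.clone()
+ # Prompt only: compute log-probabilities for every token except the last
+ # (the last token has no next token to predict)
+ if label_mask.size(1) > 0:
+ label_mask[:, -1] = False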
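In forward_step, the label at position i is the token at position i + 1: the response branch writes responses into label[:, -response_length - 1 : -1], and the prompt-only comment describes input_ids shifted by one position. The sketch below shows that alignment on single rows instead of tensors; the helper names are hypothetical and the left-padded example row is an assumption for illustration.

```python
def response_labels(prompt_row, response_row, filler=0):
    """Mirror label[:, -response_length - 1 : -1] = responses for one row.

    Positions outside the written slice keep a placeholder value (the patch
    starts from position_ids.clone()); they are masked out anyway.
    """
    seq_len = len(prompt_row) + len(response_row)
    label = [filler] * seq_len
    r = len(response_row)
    label[-r - 1:-1] = list(response_row)
    return label


def prompt_only_labels(input_ids_row, attention_mask_row):
    """Next-token labels and label mask for one prompt-only row.

    label[i] is input_ids[i + 1]; the last position has no next token, so its
    mask entry is False. Padding positions (mask 0) also stay masked.
    """
    label = list(input_ids_row)
    label[:-1] = list(input_ids_row[1:])   # shift left by one position
    label_mask = [bool(m) for m in attention_mask_row]
    if label_mask:
        label_mask[-1] = False
    return label, label_mask


# The response branch and the shifted-input rule agree on the response positions.
prompt, response = [1, 2, 3], [4, 5]
full = prompt + response
shifted = full[1:] + [full[-1]]
r = len(response)
assert response_labels(prompt, response)[-r - 1:-1] == shifted[-r - 1:-1]

# Left-padded prompt [pad, 5, 6, 7] with attention_mask [0, 1, 1, 1].
print(prompt_only_labels([0, 5, 6, 7], [0, 1, 1, 1]))
# ([5, 6, 7, 7], [False, True, True, False])
```

At most prompt_length - 1 positions remain valid per row, which is why the prompt-only log_probs are sized [batch_size, prompt_length - 1].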
verl/utils/debug/metrics.py
...
def calculate_debug_metrics(data: DataProto) -> dict:
"""..."""
+ if "rollout_log_probs" not in data.batch:
+ logger.warning("rollout_log_probs not found in batch, skipping debug metrics calculation")
+ return {
+ "training/rollout_probs_diff_valid": 0,
+ "training/rollout_probs_diff_max": 0.0,
+ "training/rollout_probs_diff_mean": 0.0,
+ "training/rollout_probs_diff_std": 0.0,
+ "training/rollout_actor_probs_pearson_corr": 0.0,
+ }
+
+ if "old_log_probs" not in data.batch:
+ logger.warning("old_log_probs not found in batch, skipping debug metrics calculation")
+ return {
+ "training/rollout_probs_diff_valid": 0,
+ "training/rollout_probs_diff_max": 0.0,
+ "training/rollout_probs_diff_mean": 0.0,
+ "training/rollout_probs_diff_std": 0.0,
+ "training/rollout_actor_probs_pearson_corr": 0.0,
+ }
+
+ if "responses" not in data.batch:
+ logger.warning(
+ "responses not found in batch(possibly compute_prompts_only mode), skipping debug metrics calculation")
+ return {
+ "training/rollout_probs_diff_valid": 0,
+ "training/rollout_probs_diff_max": 0.0,
+ "training/rollout_probs_diff_mean": 0.0,
+ "training/rollout_probs_diff_std": 0.0,
+ "training/rollout_actor_probs_pearson_corr": 0.0,
+ }
+
rollout_old_log_probs = data.batch["rollout_log_probs"]
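The three guards added to calculate_debug_metrics return the same zero-valued dictionary whenever rollout_log_probs, old_log_probs or responses is missing from the batch (the warning text names compute_prompts_only mode as one cause). The sketch below condenses that logic; a plain set of key names stands in for data.batch, and the names DEFAULT_DEBUG_METRICS, missing_debug_keys and debug_metrics_or_default are hypothetical, not verl API.

```python
DEFAULT_DEBUG_METRICS = {
    "training/rollout_probs_diff_valid": 0,
    "training/rollout_probs_diff_max": 0.0,
    "training/rollout_probs_diff_mean": 0.0,
    "training/rollout_probs_diff_std": 0.0,
    "training/rollout_actor_probs_pearson_corr": 0.0,
}

# Same keys, same order as the three checks in the patch.
REQUIRED_KEYS = ("rollout_log_probs", "old_log_probs", "responses")


def missing_debug_keys(batch_keys):
    """Return the required keys absent from the batch, in check order."""
    return [k for k in REQUIRED_KEYS if k not in batch_keys]


def debug_metrics_or_default(batch_keys):
    """Return a fresh copy of the zero-valued defaults if any key is missing,
    else None (meaning: go on and compute the real metrics)."""
    if missing_debug_keys(batch_keys):
        return dict(DEFAULT_DEBUG_METRICS)
    return None


print(missing_debug_keys({"rollout_log_probs", "old_log_probs"}))  # ['responses']
```

Returning a copy keeps callers from mutating the shared defaults; in the patch each guard builds a new dictionary literal, which has the same effect.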
verl/trainer/ppo/rollout_corr_helper.py
...
def compute_rollout_correction_and_add_to_batch(
batch: DataProto, rollout_corr_config: RolloutCorrectionConfig
) -> tuple[DataProto, dict]:
"""..."""
+ if int(os.getenv("PROMPTS_ONLY", "0")):
+ return batch, {}
rollout_is = rollout_corr_config.get("rollout_is", None)
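Both switches, PROMPTS_ONLY here and DUMP_ON in the data collection step, are parsed with int(...), so only integer strings are valid values. A small sketch of that parsing; env_flag is a hypothetical helper name, and a plain dict can be passed in place of os.environ for testing.

```python
import os


def env_flag(name, environ=None):
    """Parse a 0/1 switch the way the patches do: int(os.getenv(name, "0")).

    Unset or "0" -> False, "1" -> True. A value such as "true" is not an
    integer string, so int() raises ValueError.
    """
    env = os.environ if environ is None else environ
    return bool(int(env.get(name, "0")))


print(env_flag("PROMPTS_ONLY", {"PROMPTS_ONLY": "1"}))  # True
print(env_flag("DUMP_ON", {}))                          # False
```

Because of the int() conversion, export PROMPTS_ONLY=true would make the patched code raise ValueError; use 0 or 1.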
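Data collection
In verl/workers/megatron_workers.py, add the PrecisionDebugger interface of the msProbe tool to perform the dump. For more information on PrecisionDebugger, see the msProbe guide 《PyTorch场景精度数据采集》 (PyTorch Precision Data Collection).
The modified example code is shown below (lines marked + are added):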
...
class ActorRolloutRefWorker(MegatronWorker, DistProfilerExtension):
"""..."""
def __init__(self, config: DictConfig, role: str, **kwargs):
...
self._ref_is_offload_param = self.config.ref.megatron.get("param_offload", False)
+ # Modification in __init__:
+ # instantiate PrecisionDebugger;
+ # the DUMP_ON environment variable is a quick switch for the dump function
+ dump_flag = int(os.environ.get("DUMP_ON", 0))
+ if dump_flag:
+ from msprobe.pytorch import PrecisionDebugger, seed_all
+ seed_all(mode=True)
+ self.debugger = PrecisionDebugger(task='tensor', level='L0', step=[0], dump_path='0_dump_path/')
+ self.dump_path_prefix = self.debugger.config.dump_path
+ else:
+ self.debugger = None
...
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout"))
@GPUMemoryLogger(role="generate_sequences", logger=logger)
@DistProfiler.annotate(color="red")
def generate_sequences(self, prompts: DataProto):
...
with simple_timer("generate_sequences", timing_generate):
+ # In generate_sequences (inference), enable the tool to dump the inference forward data
+ if self.debugger:
+ infer_model = self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.get_model()
+ # Inference forward dump data is saved in the generate_sequences folder
+ self.debugger.service.config.dump_path = os.path.join(self.dump_path_prefix, 'generate_sequences')
+ self.debugger.start(model=infer_model, token_range=[0, 0])
+ output = self.rollout.generate_sequences(prompts=prompts)
+ if self.debugger:
+ self.debugger.stop()
+ self.debugger.service._reset_status()
...
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
@GPUMemoryLogger(role="compute_log_prob", logger=logger)
@DistProfiler.annotate(color="blue")
def compute_log_prob(self, data: DataProto):
...
data.meta_info["temperature"] = self.config.rollout.temperature
+ # In compute_log_prob (training), enable the tool to dump module-level inputs and outputs
+ if self.debugger:
+ # Training dump data is saved in the update_actor folder
+ self.debugger.service.config.dump_path = os.path.join(self.dump_path_prefix, 'update_actor')
+ self.debugger.start(model=self.actor.actor_module)
+ output, entropys = self.actor.compute_log_prob(data=data, calculate_entropy=True)
+ if self.debugger:
+ self.debugger.stop()
+ self.debugger.step()
...
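Because balance_batch=False keeps the data on each card aligned, the inference dump and the training dump can be compared card by card. The sketch below builds the two stage directories exactly as the os.path.join calls above do, then pairs rank directories present in both. The per-step, per-rank layout (step0/rank0, step0/rank1, ...) is an assumption about msProbe's dump output and the helper names are hypothetical; check the actual directory tree of your dump before relying on it.

```python
import os


def stage_dump_dirs(dump_path_prefix):
    """Mirror the os.path.join calls in megatron_workers.py."""
    return {
        "inference": os.path.join(dump_path_prefix, "generate_sequences"),
        "training": os.path.join(dump_path_prefix, "update_actor"),
    }


def _rank_dirs(step_dir):
    # Assumed layout: one sub-directory per rank, named rank0, rank1, ...
    return {
        name for name in os.listdir(step_dir)
        if name.startswith("rank") and os.path.isdir(os.path.join(step_dir, name))
    }


def pair_rank_dirs(infer_step_dir, train_step_dir):
    """Pair the rank directories present in both dumps, sorted by rank number."""
    def rank_key(name):
        suffix = name[len("rank"):]
        return int(suffix) if suffix.isdigit() else -1

    common = sorted(_rank_dirs(infer_step_dir) & _rank_dirs(train_step_dir), key=rank_key)
    return [(os.path.join(infer_step_dir, r), os.path.join(train_step_dir, r)) for r in common]


dirs = stage_dump_dirs("0_dump_path/")
print(dirs["inference"], dirs["training"])
```

Ranks present in only one of the two dumps are left out of the pairs, which makes a mismatch in card coverage easy to spot before comparing.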