fsdp训练后端verl训推一致性比对数据采集

简介

verl训推一致性比对场景在比对前,需要保证训练和推理时的输入 shape 一致,才能确保比对时训练和推理dump的精度数据匹配。

verl训练推理输入对齐分析

一般情况下,训练和推理时的输入 shape 不一致,可从训练推理运行原理分析。

  • 推理运行分为 prefill 和 k 个 decode 两个步骤:

    1. 在 prefill 步骤时,推理输入为 prompt。
    2. 在 k 个 decode 步骤时,kv cache 加上上一个 decode 得到的输出 token ,最终输出推理的 response。
  • 训练运行时,输入为 prompt 加上推理输出的 response,最终输出 logits。

综合以上信息,推理输入为 prompt;训练输入为 prompt 加上推理输出的 response。

结论:需要将训练的输入调整为单 prompt,与推理的输入保持一致。

前置操作

要保证训练 forward 和推理 prefill 的 shape 一致,需要去掉训练输入中的 response,首先需要满足如下前提,并修改训练脚本:

  1. 保证训练中的batch size维度未被拆分。

    1. 需保证每轮训练中用于梯度更新的mini batch个数mini_batch_num = 1

      计算公式为:mini_batch_num = train_batch_size / train_ppo_mini_batch_size

      • train_batch_size: 训练中总的样本数。
      • train_ppo_mini_batch_size: 每个 mini batch 的样本数量。
    2. 需保证梯度累计步骤数gac (Gradient Accumulation Steps) = 1

      计算公式为:gac = train_ppo_mini_batch_size * n_resp_per_prompt / train_ppo_micro_batch_size_per_gpu / DP

      • train_ppo_mini_batch_size: 每个 mini batch 的样本数量。
      • n_resp_per_prompt: 每个提示(prompt)下的响应数。
      • train_ppo_micro_batch_size_per_gpu: 每个GPU上处理的 micro batch 大小。
      • DP: 数据并行度。

    上述参数在脚本中具体修改为:

    data.train_batch_size=${train_batch_size}
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_ppo_mini_batch_size}
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_ppo_micro_batch_size_per_gpu}
    actor_rollout_ref.rollout.n=${n_resp_per_prompt}
    
  2. 保证训练中无pad。

    actor_rollout_ref.model.use_remove_padding=True
    actor_rollout_ref.actor.use_dynamic_bsz=False
    
  3. 在训练脚本中修改环境变量。

    export DUMP_ON=1
    export PROMPTS_ONLY=1
    export TORCHDYNAMO_DISABLE=1
    
  4. 保证训练和推理采集的数据在每张卡上是一一对应的。 balance_batch: 自动平衡、均分batch数据

    trainer.balance_batch=False
    

verl代码修改

去掉训练输入中的 response,需要修改 verl/workers/actor/dp_actor.py,以 release/v0.6.1 为例,修改处高亮显示如下:

 ...
     ...
     def _forward_micro_batch(
         self, micro_batch, temperature, calculate_entropy=False
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """..."""

+        # _forward_micro_batch方法中修改
-        response_length = micro_batch["responses"].size(-1)
+        if "responses" in micro_batch and micro_batch["responses"] is not None:
+            response_length = micro_batch["responses"].size(-1)
+        else:
+            response_length = 0
         multi_modal_inputs = {}
         ...
 
     @GPUMemoryLogger(role="dp actor", logger=logger)
     def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Tensor:
         """..."""
         # set to eval
         self.actor_module.eval()

+        # compute_log_prob方法中修改
+        compute_prompts_only = int(os.getenv("PROMPTS_ONLY", "0"))
+        if compute_prompts_only:
+            if "responses" in data.batch:
+                responses_len = data.batch["responses"].size(1)
+                data.batch["input_ids"] = data.batch["input_ids"][:, :-responses_len]
+                data.batch["attention_mask"] = data.batch["attention_mask"][:, :-responses_len]
+                if data.batch["position_ids"].dim() == 3:
+                    data.batch["position_ids"] = data.batch["position_ids"][:, :, :-responses_len]
+                else:
+                    data.batch["position_ids"] = data.batch["position_ids"][:, :-responses_len]
+                # remove responses from batch
+                data.batch["responses"] = None
+                if "rollout_log_probs" in data.batch:
+                    data.batch["rollout_log_probs"] = None
+                if "response_mask" in data.batch:
+                    data.batch["response_mask"] = None         
+ 
         
         micro_batch_size = data.meta_info["micro_batch_size"]
         ...
 
     @GPUMemoryLogger(role="dp actor", logger=logger)
     def update_policy(self, data: DataProto):
         # make sure we are in training mode
         self.actor_module.train()
 
         temperature = data.meta_info["temperature"]  # temperature must be in the data.meta_info to avoid silent error

+        # update_policy方法中修改
+        compute_prompts_only = int(os.getenv("PROMPTS_ONLY", "0"))
+        if compute_prompts_only:
+            if "responses" in data.batch:
+                responses_len = data.batch["responses"].size(1)
+                data.batch["input_ids"] = data.batch["input_ids"][:, :-responses_len]
+                data.batch["attention_mask"] = data.batch["attention_mask"][:, :-responses_len]
+                if data.batch["position_ids"].dim() == 3:
+                    data.batch["position_ids"] = data.batch["position_ids"][:, :, :-responses_len]
+                else:
+                    data.batch["position_ids"] = data.batch["position_ids"][:, :-responses_len]
+                # remove responses from batch
+                data.batch["responses"] = None
+                if "rollout_log_probs" in data.batch:
+                    data.batch["rollout_log_probs"] = None
+                if "response_mask" in data.batch:
+                    data.batch["response_mask"] = None
+ 
         select_keys = [
             "responses",
             "response_mask",
             "input_ids",
             "attention_mask",
             "position_ids",
             "old_log_probs",
             "advantages",
         ]
         ...
                     ...
                     # Extract pre-computed rollout correction weights if present
                     # Weights are computed centrally in trainer and added when algorithm.rollout_is=True
                     rollout_is_weights = model_inputs.get("rollout_is_weights", None)

+                    # update_policy方法中修改                     
+                    if response_mask is None:
+                        prompt_mask = torch.ones_like(log_prob, dtype=torch.bool)
+                        response_mask = prompt_mask
+ 
                     # gpg -> verl.trainer.ppo.core_algos.compute_policy_loss_gpg
                     # clip_cov -> verl.trainer.ppo.core_algos.compute_policy_loss_clip_cov
                     policy_loss_fn = get_policy_loss_fn(loss_mode)
 
                     # Compute policy loss (any function is expected to return 2 values)
                     pg_loss, pg_metrics = policy_loss_fn(
                         old_log_prob=old_log_prob,
                         log_prob=log_prob,
                         advantages=advantages,
                         response_mask=response_mask,
                         loss_agg_mode=loss_agg_mode,
                         config=self.config,
                         rollout_is_weights=rollout_is_weights,
                     )
                     micro_batch_metrics.update(pg_metrics)
                     ...

数据采集

在 verl/workers/fsdp_workers.py 中添加msProbe工具的PrecisionDebugger接口进行dump操作。PrecisionDebugger接口更多介绍请参见《PyTorch场景精度数据采集》。

修改示例代码高亮显示如下:

 ...
 class ActorRolloutRefWorker(Worker, DistProfilerExtension):
     """
     This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy
     or a hybrid engine based on the config.rollout
     """
 
     def __init__(self, config: DictConfig, role: str, **kwargs):
         ...
         # normalize rollout config
         if self._is_rollout and self.config.rollout.log_prob_micro_batch_size is not None:
             self.config.rollout.log_prob_micro_batch_size //= (
                 self.device_mesh.size() // self.ulysses_sequence_parallel_size
             )
             self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size
         # normalize ref config
         if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None:
             self.config.ref.log_prob_micro_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size
             self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size

+        # __init__方法中修改
+        # 实例化PrecisionDebugger
+        dump_flag = int(os.environ.get("DUMP_ON", 0))  # 设置环境变量DUMP_ON用于快速开关dump功能
+        if dump_flag:
+            from msprobe.pytorch import PrecisionDebugger, seed_all
+            seed_all(mode=True)
+            self.debugger = PrecisionDebugger(task='tensor', level='L0', dump_path='example_dump_path', step=[0])
+            self.dump_path_prefix = self.debugger.config.dump_path
+        else:
+            self.debugger = None
 
     ...
     @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
     @DistProfiler.annotate(color="red", role="actor_update")
     def update_actor(self, data: DataProto):
         ...
         with self.ulysses_sharding_manager:
             data = data.to("cpu")  # data will to device with each micro batch on actor.update_policy
+            # update_actor方法中修改
+            if self.debugger:
+                self.debugger.service.config.dump_path = os.path.join(self.dump_path_prefix, 'update_actor')  # 训练结果保存在update_actor文件夹
+                self.debugger.start(model=self.actor.actor_module)
             # perform training
             with Timer(name="update_policy", logger=None) as timer:
                 metrics = self.actor.update_policy(data=data)
+            if self.debugger:
+                self.debugger.stop()
+                self.debugger.step()
             delta_time = timer.last
             global_num_tokens = data.meta_info["global_token_num"]
             ...
     
     @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout"))
     @DistProfiler.annotate(color="red", role="rollout_generate")
     def generate_sequences(self, prompts: DataProto):
         ...
         with simple_timer("generate_sequences", timing_generate):
+            # generate_sequences方法中修改
+            if self.debugger:
+                self.debugger.service.config.dump_path = os.path.join(self.dump_path_prefix, 'generate_sequences')  # 推理结果保存在generate_sequences文件夹
+                infer_model = self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.get_model()
+                self.debugger.start(model=infer_model, token_range=[0, 0])
             output = self.rollout.generate_sequences(prompts=prompts)
+            if self.debugger:
+                self.debugger.stop()
+                self.debugger.service._reset_status()
 
         if self._is_actor:
             loop.run_until_complete(self.trainer_mode())
             log_gpu_memory_usage("After switch to trainer mode", logger=logger)
         ...