FSDP 训练后端 verl 训推一致性比对数据采集
简介
在 verl 训推一致性比对场景中,比对前需要保证训练和推理时的输入 shape 一致,才能确保训练和推理 dump 的精度数据能够一一匹配。
verl训练推理输入对齐分析
一般情况下,训练和推理时的输入 shape 并不一致,其原因可从训练与推理的运行原理进行分析。
-
推理运行分为 prefill 和 decode 两个阶段,其中 decode 共执行 k 步:
- 在 prefill 步骤时,推理输入为 prompt。
- 在 k 个 decode 步骤中,每一步的输入为 kv cache 加上上一步(prefill 或上一个 decode)生成的输出 token,k 步结束后得到推理输出的 response。
-
训练运行时,输入为 prompt 加上推理输出的 response,最终输出 logits。
综合以上信息,推理 prefill 阶段的输入为 prompt;训练的输入为 prompt 加上推理输出的 response。
结论:需要将训练的输入调整为仅包含 prompt,使其与推理 prefill 阶段的输入保持一致。
前置操作
为保证训练 forward 与推理 prefill 的输入 shape 一致,需要去掉训练输入中的 response。为此需先满足以下前提条件,并相应修改训练脚本:
-
保证训练中的batch size维度未被拆分。
-
需保证每轮训练中用于梯度更新的 mini batch 个数 mini_batch_num = 1。
计算公式为:mini_batch_num = train_batch_size / train_ppo_mini_batch_size
- train_batch_size: 每个训练步中的总样本(prompt)数。
- train_ppo_mini_batch_size: 每个 mini batch 的样本(prompt)数量。
-
需保证梯度累积步数 gac(Gradient Accumulation Steps)= 1。
计算公式为:gac = train_ppo_mini_batch_size * n_resp_per_prompt / train_ppo_micro_batch_size_per_gpu / DP
- train_ppo_mini_batch_size: 每个 mini batch 的样本(prompt)数量。
- n_resp_per_prompt: 每个 prompt 生成的 response 数量。
- train_ppo_micro_batch_size_per_gpu: 每个GPU上处理的 micro batch 大小。
- DP: 数据并行度。
上述参数在训练脚本中对应的配置项如下:
data.train_batch_size=${train_batch_size}
actor_rollout_ref.actor.ppo_mini_batch_size=${train_ppo_mini_batch_size}
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_ppo_micro_batch_size_per_gpu}
actor_rollout_ref.rollout.n=${n_resp_per_prompt}
-
保证训练输入中无 padding:
actor_rollout_ref.model.use_remove_padding=True
actor_rollout_ref.actor.use_dynamic_bsz=False
-
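在训练脚本中设置以下环境变量。其中,DUMP_ON=1 用于开启 msProbe 数据 dump;PROMPTS_ONLY=1 用于启用下文 verl 代码修改中"仅使用 prompt 作为训练输入"的逻辑;TORCHDYNAMO_DISABLE=1 用于禁用 TorchDynamo(torch.compile)。
export DUMP_ON=1
export PROMPTS_ONLY=1
export TORCHDYNAMO_DISABLE=1
-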
保证训练和推理采集的数据在每张卡上一一对应,需关闭 balance_batch。balance_batch 会自动平衡、均分各卡上的 batch 数据,开启后训练侧各卡的数据可能与推理侧不再对应。
trainer.balance_batch=False
verl代码修改
为去掉训练输入中的 response,需要修改 verl/workers/actor/dp_actor.py。以 release/v0.6.1 为例,修改处如下(以 "+" 标记新增行,以 "-" 标记删除行):
...
...
def _forward_micro_batch(
self, micro_batch, temperature, calculate_entropy=False
) -> tuple[torch.Tensor, torch.Tensor]:
"""..."""
+ # _forward_micro_batch方法中修改
- response_length = micro_batch["responses"].size(-1)
+ if "responses" in micro_batch and micro_batch["responses"] is not None:
+ response_length = micro_batch["responses"].size(-1)
+ else:
+ response_length = 0
multi_modal_inputs = {}
...
@GPUMemoryLogger(role="dp actor", logger=logger)
def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Tensor:
"""..."""
# set to eval
self.actor_module.eval()
+ # compute_log_prob方法中修改
+ compute_prompts_only = int(os.getenv("PROMPTS_ONLY", "0"))
+ if compute_prompts_only:
+ if "responses" in data.batch:
+ responses_len = data.batch["responses"].size(1)
+ data.batch["input_ids"] = data.batch["input_ids"][:, :-responses_len]
+ data.batch["attention_mask"] = data.batch["attention_mask"][:, :-responses_len]
+ if data.batch["position_ids"].dim() == 3:
+ data.batch["position_ids"] = data.batch["position_ids"][:, :, :-responses_len]
+ else:
+ data.batch["position_ids"] = data.batch["position_ids"][:, :-responses_len]
+ # remove responses from batch
+ data.batch["responses"] = None
+ if "rollout_log_probs" in data.batch:
+ data.batch["rollout_log_probs"] = None
+ if "response_mask" in data.batch:
+ data.batch["response_mask"] = None
+
micro_batch_size = data.meta_info["micro_batch_size"]
...
@GPUMemoryLogger(role="dp actor", logger=logger)
def update_policy(self, data: DataProto):
# make sure we are in training mode
self.actor_module.train()
temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid silent error
+ # update_policy方法中修改
+ compute_prompts_only = int(os.getenv("PROMPTS_ONLY", "0"))
+ if compute_prompts_only:
+ if "responses" in data.batch:
+ responses_len = data.batch["responses"].size(1)
+ data.batch["input_ids"] = data.batch["input_ids"][:, :-responses_len]
+ data.batch["attention_mask"] = data.batch["attention_mask"][:, :-responses_len]
+ if data.batch["position_ids"].dim() == 3:
+ data.batch["position_ids"] = data.batch["position_ids"][:, :, :-responses_len]
+ else:
+ data.batch["position_ids"] = data.batch["position_ids"][:, :-responses_len]
+ # remove responses from batch
+ data.batch["responses"] = None
+ if "rollout_log_probs" in data.batch:
+ data.batch["rollout_log_probs"] = None
+ if "response_mask" in data.batch:
+ data.batch["response_mask"] = None
+
select_keys = [
"responses",
"response_mask",
"input_ids",
"attention_mask",
"position_ids",
"old_log_probs",
"advantages",
]
...
...
# Extract pre-computed rollout correction weights if present
# Weights are computed centrally in trainer and added when algorithm.rollout_is=True
rollout_is_weights = model_inputs.get("rollout_is_weights", None)
+ # update_policy方法中修改
+ if response_mask is None:
+ prompt_mask = torch.ones_like(log_prob, dtype=torch.bool)
+ response_mask = prompt_mask
+
# gpg -> verl.trainer.ppo.core_algos.compute_policy_loss_gpg
# clip_cov -> verl.trainer.ppo.core_algos.compute_policy_loss_clip_cov
policy_loss_fn = get_policy_loss_fn(loss_mode)
# Compute policy loss (any function is expected to return 2 values)
pg_loss, pg_metrics = policy_loss_fn(
old_log_prob=old_log_prob,
log_prob=log_prob,
advantages=advantages,
response_mask=response_mask,
loss_agg_mode=loss_agg_mode,
config=self.config,
rollout_is_weights=rollout_is_weights,
)
micro_batch_metrics.update(pg_metrics)
...
数据采集
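在 verl/workers/fsdp_workers.py 中调用 msProbe 工具的 PrecisionDebugger 接口执行 dump 操作。PrecisionDebugger 接口的更多介绍请参见《PyTorch场景精度数据采集》。
修改示例代码如下(以 "+" 标记新增行)。其中,推理侧 debugger.start 的 token_range=[0, 0] 表示仅采集第 0 个 token(即 prefill 阶段)的数据,从而与仅包含 prompt 的训练 forward 对齐: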
...
class ActorRolloutRefWorker(Worker, DistProfilerExtension):
"""
This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy
or a hybrid engine based on the config.rollout
"""
def __init__(self, config: DictConfig, role: str, **kwargs):
...
# normalize rollout config
if self._is_rollout and self.config.rollout.log_prob_micro_batch_size is not None:
self.config.rollout.log_prob_micro_batch_size //= (
self.device_mesh.size() // self.ulysses_sequence_parallel_size
)
self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size
# normalize ref config
if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None:
self.config.ref.log_prob_micro_batch_size //= self.device_mesh.size() // self.ulysses_sequence_parallel_size
self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size
+ # __init__方法中修改
+ # 实例化PrecisionDebugger
+ dump_flag = int(os.environ.get("DUMP_ON", 0)) # 设置环境变量DUMP_ON用于快速开关dump功能
+ if dump_flag:
+ from msprobe.pytorch import PrecisionDebugger, seed_all
+ seed_all(mode=True)
+ self.debugger = PrecisionDebugger(task='tensor', level='L0', dump_path='example_dump_path', step=[0])
+ self.dump_path_prefix = self.debugger.config.dump_path
+ else:
+ self.debugger = None
...
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
@DistProfiler.annotate(color="red", role="actor_update")
def update_actor(self, data: DataProto):
...
with self.ulysses_sharding_manager:
data = data.to("cpu") # data will to device with each micro batch on actor.update_policy
+ # update_actor方法中修改
+ if self.debugger:
+ self.debugger.service.config.dump_path = os.path.join(self.dump_path_prefix, 'update_actor') # 训练结果保存在update_actor文件夹
+ self.debugger.start(model=self.actor.actor_module)
# perform training
with Timer(name="update_policy", logger=None) as timer:
metrics = self.actor.update_policy(data=data)
+ if self.debugger:
+ self.debugger.stop()
+ self.debugger.step()
delta_time = timer.last
global_num_tokens = data.meta_info["global_token_num"]
...
@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout"))
@DistProfiler.annotate(color="red", role="rollout_generate")
def generate_sequences(self, prompts: DataProto):
...
with simple_timer("generate_sequences", timing_generate):
+ # generate_sequences方法中修改
+ if self.debugger:
+ self.debugger.service.config.dump_path = os.path.join(self.dump_path_prefix, 'generate_sequences') # 推理结果保存在generate_sequences文件夹
+ infer_model = self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.get_model()
+ self.debugger.start(model=infer_model, token_range=[0, 0])
output = self.rollout.generate_sequences(prompts=prompts)
+ if self.debugger:
+ self.debugger.stop()
+ self.debugger.service._reset_status()
if self._is_actor:
loop.run_until_complete(self.trainer_mode())
log_gpu_memory_usage("After switch to trainer mode", logger=logger)
...