Heuristic: Eric Mitchell's Direct Preference Optimization - Disable Sampling During Eval
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Debugging |
| Last Updated | 2026-02-08 02:00 GMT |
Overview
Disable text sampling during evaluation when using FSDPTrainer or TensorParallelTrainer to avoid extremely slow generation.
Description
Autoregressive text generation (sampling) is extremely slow when models are sharded across GPUs with FSDP or TensorParallel. For FSDP, each token generation step requires gathering the full model parameters, negating the sharding benefit. For TensorParallel, the sequential nature of generation combined with cross-GPU communication creates severe bottlenecks. The DPO codebase addresses this by providing a `sample_during_eval` flag that controls whether text samples are generated during evaluation checkpoints.
Usage
Always use this heuristic when training with FSDPTrainer or TensorParallelTrainer. Pass `sample_during_eval=false` on the command line. For BasicTrainer (single-GPU or naive multi-GPU), sampling during eval is fine and provides useful qualitative feedback.
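As a minimal sketch, the recommendation reduces to a simple guard. The helper name `wants_sampling` is illustrative, not from the repo (which only exposes the flag and warns in its README):

```python
def wants_sampling(trainer: str, sample_during_eval: bool) -> bool:
    """Return True only when eval-time generation is both requested and cheap.

    Sampling is prohibitively slow for sharded trainers, so this sketch
    forces it off for FSDPTrainer and TensorParallelTrainer regardless of
    the flag's value.
    """
    if trainer in ('FSDPTrainer', 'TensorParallelTrainer'):
        return False
    return sample_during_eval

print(wants_sampling('BasicTrainer', True))   # True: cheap, so honor the flag
print(wants_sampling('FSDPTrainer', True))    # False: sharded, skip sampling
```

Quantitative metrics are unaffected by this guard; it only decides whether the expensive `generate()` calls run during evaluation.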
The Insight (Rule of Thumb)
- Action: Pass `sample_during_eval=false` when using FSDPTrainer or TensorParallelTrainer.
- Value: Boolean flag, default `true` in config.
- Trade-off: Lose qualitative text generation samples during evaluation in exchange for dramatically faster evaluation. Quantitative metrics (loss, reward accuracy, log probabilities) are still computed.
- Compatibility: Only needed for FSDPTrainer and TensorParallelTrainer. BasicTrainer can safely use `sample_during_eval=true`.
Reasoning
The README explicitly warns: "Sampling may be very slow for `FSDPTrainer` and especially `TensorParallelTrainer`." This is because:
FSDP: Generation requires `FSDP.summon_full_params()` context manager to gather all sharded parameters onto each GPU, effectively undoing the memory benefit of sharding during generation. Code from `trainers.py:186-189`:
```python
# FSDP generation according to https://github.com/pytorch/pytorch/issues/100069
ctx = lambda: (FSDP.summon_full_params(self.policy, writeback=False, recurse=False) if 'FSDP' in self.config.trainer else contextlib.nullcontext())
with ctx():
    policy_output = self.policy.generate(
        batch['prompt_input_ids'], attention_mask=batch['prompt_attention_mask'], max_length=self.config.max_length, do_sample=True, pad_token_id=self.tokenizer.pad_token_id)
```
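The context-selection pattern above can be reproduced as a runnable sketch. `FakeSummonFullParams` is a stand-in for `FSDP.summon_full_params`, which needs a real distributed process group; the branching logic mirrors the `trainers.py` snippet:

```python
import contextlib

class FakeSummonFullParams:
    """Stand-in for FSDP.summon_full_params (real FSDP needs a process group)."""
    def __init__(self):
        self.gathered = False

    def __enter__(self):
        self.gathered = True  # simulate the expensive all-gather of shards
        return self

    def __exit__(self, *exc):
        self.gathered = False  # shards are released again on exit
        return False

def make_generation_ctx(trainer: str):
    # Same decision as trainers.py: gather full params only under FSDP,
    # otherwise fall back to a no-op context manager.
    if 'FSDP' in trainer:
        return FakeSummonFullParams()
    return contextlib.nullcontext()

with make_generation_ctx('FSDPTrainer'):
    pass  # model.generate(...) would run here with full params materialized
```

Because this gather happens around every `generate()` call, sampling at each eval checkpoint repeatedly pays the cost that sharding was meant to avoid, which is why disabling it is recommended.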
TensorParallel: The `TensorParallelTrainer` docstring notes (trainers.py:528-529): "Based on https://github.com/BlackSamorez/tensor_parallel. Note sampling is extremely slow, see https://github.com/BlackSamorez/tensor_parallel/issues/66."
The config comment also documents this in `config/config.yaml:37-39`:
```yaml
# whether or not to generate samples during evaluation; disable for FSDP/TensorParallel
# is recommended, because they are slow
sample_during_eval: true
```