Heuristic: ContextualAI HALOs FSDP Sampling Workaround
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Training, Debugging, LLM_Alignment |
| Last Updated | 2026-02-08 03:00 GMT |
Overview
When training with FSDP, sampling directly from the policy model produces gibberish; instead, sync the reference model with the policy and sample from the reference.
Description
Fully Sharded Data Parallel (FSDP) distributes model parameters across GPUs, meaning no single GPU holds the complete model weights during normal operation. Attempting to call `.generate()` on an FSDP-wrapped policy model produces incoherent output because each GPU only has its shard of parameters. The workaround is to: (1) sync the reference model with the policy weights via `sync_reference_with_policy()`, (2) use `FSDP.summon_full_params()` to temporarily gather all parameters on each GPU, and (3) sample from the reference model within that context. This pattern is critical for online training workflows where the policy must generate new text.
Usage
Apply this heuristic whenever you need to generate text from the policy during FSDP training (e.g., online PPO, on-policy sampling). Never call `policy.generate()` directly. Instead, use `sync_reference_with_policy()` followed by sampling from the reference model within the `FSDP.summon_full_params()` context manager. For large-scale sampling (e.g., generating thousands of outputs for labeling), use the separate `train.sample` script with vLLM, which loads the saved checkpoint outside of FSDP.
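The sync-then-sample flow can be sketched in miniature with plain PyTorch modules. This is an illustration, not the trainer's actual code: the real models are FSDP-wrapped and the `nullcontext()` below stands in for `FSDP.summon_full_params(reference)`, which gathers the sharded parameters.

```python
import torch
import torch.nn as nn
from contextlib import nullcontext

# Stand-ins for the policy and reference models (the real ones are
# large causal LMs wrapped in FSDP via accelerate).
policy = nn.Linear(4, 4)
reference = nn.Linear(4, 4)

def sync_reference_with_policy(policy, reference):
    # Copy the (unwrapped) policy weights into the reference model.
    reference.load_state_dict(policy.state_dict())

# Step 1: sync reference weights with the policy.
sync_reference_with_policy(policy, reference)

# Step 2+3: sample from the reference inside the gather context.
# Under FSDP this context would be FSDP.summon_full_params(reference).
gather_ctx = nullcontext()
with gather_ctx, torch.no_grad():
    out = reference(torch.randn(2, 4))  # stand-in for reference.generate(...)

# After syncing, the reference agrees exactly with the policy.
assert torch.equal(reference.weight, policy.weight)
```

The point of the sketch is the ordering: the weight copy happens first, and generation only ever touches the reference model inside the gather context.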
The Insight (Rule of Thumb)
- Action: Sync reference model weights with policy, then sample from reference within `FSDP.summon_full_params()`.
- Value: Produces correct, coherent text generation during FSDP training.
- Trade-off: Requires the full model to be gathered on each GPU temporarily, which increases peak memory usage during sampling. The reference model consumes GPU memory that could otherwise be freed.
- Alternative: For offline/batch sampling, use `train.sample` with vLLM (loads model without FSDP).
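To put a rough number on the memory trade-off (the model size here is hypothetical, not from the source): gathering full parameters for a 7B-parameter model stored in bf16 temporarily materializes about 13 GiB of weights on every GPU during sampling, on top of the sharded optimizer and reference-model state.

```python
# Back-of-the-envelope peak-memory cost of summon_full_params
# for a hypothetical 7B-parameter policy in bf16.
params = 7_000_000_000   # hypothetical parameter count
bytes_per_param = 2      # bf16 = 2 bytes per parameter
peak_extra_gib = params * bytes_per_param / 1024**3
print(round(peak_extra_gib, 1))  # → 13.0 (GiB gathered on each GPU)
```

This is why the vLLM-based `train.sample` path is preferred for large batch sampling: it avoids holding the gathered weights alongside the full training state.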
Reasoning
FSDP shards model parameters across GPUs for memory efficiency. During forward/backward passes, FSDP handles gathering and scattering parameters automatically. However, the `.generate()` method performs autoregressive decoding with multiple forward passes, and FSDP's parameter gathering does not cover this use case cleanly. The `summon_full_params()` context manager forces FSDP to gather all parameters, making the model complete on each GPU for the duration of generation.
Code Evidence
Sampling workaround with docstring warning in `train/trainers.py:460-511`:
```python
def sample(self, model, batch, temp=0.7):
    """
    Sample from the given model. NOTE: If the policy is being trained with FSDP,
    then sampling from it directly will produce gibberish. If you want to sample
    from the policy, you should sync the reference model with the policy via
    sync_reference_with_policy, then sample from reference model.
    """
    # ...
    if self.accelerator.state.fsdp_plugin is not None:
        context = FSDP.summon_full_params(model)
    else:
        context = nullcontext()
    with context:
        with torch.no_grad():
            batch_elements = self.train_iterator.collate(batch_elements)
            batch_completion_ids = model.generate(
                batch_elements['prompt_input_ids'].to(self.accelerator.device),
                attention_mask=batch_elements['prompt_attention_mask'].to(self.accelerator.device),
                generation_config=generation_config
            )
```
Reference sync method in `train/trainers.py:452-458`:
```python
def sync_reference_with_policy(self):
    """Update the reference model to have the policy weights."""
    state_dict = self.accelerator.unwrap_model(self.policy).state_dict()
    self.accelerator.unwrap_model(self.reference_model).load_state_dict(state_dict)
    self.accelerator.wait_for_everyone()
```