
Heuristic:ContextualAI HALOs FSDP Sampling Workaround

From Leeroopedia



Knowledge Sources
Domains Distributed_Training, Debugging, LLM_Alignment
Last Updated 2026-02-08 03:00 GMT

Overview

When a policy model is trained with FSDP, sampling from it directly produces gibberish; instead, sync the reference model with the policy and sample from the reference.

Description

Fully Sharded Data Parallel (FSDP) distributes model parameters across GPUs, meaning no single GPU holds the complete model weights during normal operation. Attempting to call `.generate()` on an FSDP-wrapped policy model produces incoherent output because each GPU only has its shard of parameters. The workaround is to: (1) sync the reference model with the policy weights via `sync_reference_with_policy()`, (2) use `FSDP.summon_full_params()` to temporarily gather all parameters on each GPU, and (3) sample from the reference model within that context. This pattern is critical for online training workflows where the policy must generate new text.
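The three-step pattern above can be sketched as follows. This is a minimal illustration, not the project's actual code: the trainer attributes (`policy`, `reference_model`, `accelerator`) and the method `sync_reference_with_policy()` are assumed to match the code evidence quoted later on this page.

```python
from contextlib import nullcontext

import torch


def sample_from_reference(trainer, batch, generation_config):
    """Hypothetical helper showing the sync-then-sample pattern."""
    # Step 1: copy the current policy weights into the reference model.
    trainer.sync_reference_with_policy()

    # Step 2: if FSDP is active, temporarily gather the full parameter
    # set onto every GPU; otherwise a no-op context is sufficient.
    if trainer.accelerator.state.fsdp_plugin is not None:
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        gather_ctx = FSDP.summon_full_params(trainer.reference_model)
    else:
        gather_ctx = nullcontext()

    # Step 3: generate from the (now complete) reference model,
    # never from trainer.policy directly.
    with gather_ctx, torch.no_grad():
        return trainer.reference_model.generate(
            batch["prompt_input_ids"].to(trainer.accelerator.device),
            attention_mask=batch["prompt_attention_mask"].to(trainer.accelerator.device),
            generation_config=generation_config,
        )
```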

Usage

Apply this heuristic whenever you need to generate text from the policy during FSDP training (e.g., online PPO, on-policy sampling). Never call `policy.generate()` directly. Instead, use `sync_reference_with_policy()` followed by sampling from the reference model within the `FSDP.summon_full_params()` context manager. For large-scale sampling (e.g., generating thousands of outputs for labeling), use the separate `train.sample` script with vLLM, which loads the saved checkpoint outside of FSDP.

The Insight (Rule of Thumb)

  • Action: Sync reference model weights with policy, then sample from reference within `FSDP.summon_full_params()`.
  • Value: Produces correct, coherent text generation during FSDP training.
  • Trade-off: Requires the full model to be gathered on each GPU temporarily, which increases peak memory usage during sampling. The reference model consumes GPU memory that could otherwise be freed.
  • Alternative: For offline/batch sampling, use `train.sample` with vLLM (loads model without FSDP).
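To make the memory trade-off concrete, here is a back-of-envelope estimate (a hypothetical helper, not part of the codebase): under FSDP each rank normally holds roughly `n_params / world_size` parameters, and `summon_full_params()` transiently materializes all `n_params` on every rank.

```python
def full_gather_overhead_gb(n_params, bytes_per_param=2, world_size=8):
    """Estimate the transient extra memory per GPU, in GB, when FSDP
    gathers the full parameter set onto every rank."""
    sharded = n_params * bytes_per_param / world_size  # steady-state shard
    full = n_params * bytes_per_param                  # gathered full copy
    return (full - sharded) / 1e9


# For a 7B-parameter model in bf16 (2 bytes/param) across 8 GPUs:
# each rank's shard is ~1.75 GB, and gathering adds ~12.25 GB transiently.
```

Note this counts parameters only; optimizer state, gradients, activations, and the KV cache used during generation come on top.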

Reasoning

FSDP shards model parameters across GPUs for memory efficiency. During forward/backward passes, FSDP handles gathering and scattering parameters automatically. However, the `.generate()` method performs autoregressive decoding with multiple forward passes, and FSDP's parameter gathering does not cover this use case cleanly. The `summon_full_params()` context manager forces FSDP to gather all parameters, making the model complete on each GPU for the duration of generation.

Code Evidence

Sampling workaround with docstring warning in `train/trainers.py:460-511`:

def sample(self, model, batch, temp=0.7):
    """
    Sample from the given model. NOTE: If the policy is being trained with FSDP,
    then sampling from it directly will produce gibberish. If you want to sample
    from the policy, you should sync the reference model with the policy via
    sync_reference_with_policy, then sample from reference model.
    """
    # ...
    if self.accelerator.state.fsdp_plugin is not None:
        context = FSDP.summon_full_params(model)
    else:
        context = nullcontext()  # no-op context when FSDP is not in use

    with context:
        with torch.no_grad():
            batch_elements = self.train_iterator.collate(batch_elements)
            batch_completion_ids = model.generate(
                batch_elements['prompt_input_ids'].to(self.accelerator.device),
                attention_mask=batch_elements['prompt_attention_mask'].to(self.accelerator.device),
                generation_config=generation_config
            )

Reference sync method in `train/trainers.py:452-458`:

def sync_reference_with_policy(self):
    """Update the reference model to have the policy weights."""
    state_dict = self.accelerator.unwrap_model(self.policy).state_dict()
    self.accelerator.unwrap_model(self.reference_model).load_state_dict(state_dict)
    self.accelerator.wait_for_everyone()
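Putting the two quoted methods together, an online-training step might look like the following. This is a hypothetical driver: the method names come from the code evidence above, while the wrapper itself is assumed.

```python
def online_sampling_step(trainer, batch):
    """Generate on-policy text during FSDP training.

    Sampling from trainer.policy directly would produce gibberish under
    FSDP, so we first mirror the policy weights into the reference model
    and sample from that instead.
    """
    trainer.sync_reference_with_policy()
    completions = trainer.sample(trainer.reference_model, batch)
    return completions
```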
