Principle: allenai/open-instruct Reference Logprob Caching
| Knowledge Sources | |
|---|---|
| Domains | Machine Learning, Optimization, Distributed Computing |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
Reference logprob caching is the technique of computing the reference model's log-probabilities for all training examples once before training begins, persisting them to disk, and reusing the cached values throughout DPO training to eliminate redundant forward passes through the reference model.
Description
The DPO loss function requires log-probabilities from two models: the policy model (which is being trained) and the reference model (which remains frozen). Naively, each training step would require a forward pass through both models for every batch, effectively doubling the compute cost.
Reference logprob caching exploits the fact that the reference model $\pi_{\text{ref}}$ is frozen throughout training. Since the reference model never changes, its log-probabilities for each training example are constant. By computing these values once in a dedicated pre-training pass and caching them, the training loop only needs to run forward passes through the policy model.
The caching process works as follows:
- Initialization: Check whether a cache file already exists at the configured path. If so, load it directly and skip computation.
- Computation: Iterate over the full training dataloader with the reference model in eval mode and `torch.no_grad()`. For each batch, compute chosen and rejected log-probabilities using the configured forward function.
- Aggregation: In distributed settings, each process computes log-probabilities for its data shard. After the pass, an `all_reduce` with the `MAX` operation merges all shards (uncomputed entries are initialized to $-\infty$, so `MAX` selects the computed value at every index).
- Validation: Verify that no entries remain at $-\infty$, confirming full dataset coverage.
- Persistence: The main process saves the cache to disk. A distributed barrier ensures all processes wait until the file is written.
- Reuse: During training, the cache is indexed by batch sample indices to retrieve precomputed reference log-probabilities with negligible overhead.
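The `MAX`-merge trick from the aggregation step can be illustrated without a real distributed setup: each rank fills only its own shard and leaves every other entry at $-\infty$, so an elementwise max across ranks recovers the full cache. A minimal pure-Python sketch (the strided shard layout and rank count are illustrative, not the library's actual sampler):

```python
import math

N = 6  # total number of training examples (illustrative)
NEG_INF = -math.inf

def shard_logps(rank, world_size, all_logps):
    """Simulate one rank: compute logprobs only for its shard, leave the rest at -inf."""
    out = [NEG_INF] * N
    for i in range(rank, N, world_size):  # strided sharding, as a DistributedSampler would assign
        out[i] = all_logps[i]
    return out

true_logps = [-1.2, -0.7, -3.4, -0.9, -2.1, -1.5]
rank0 = shard_logps(0, 2, true_logps)  # holds indices 0, 2, 4
rank1 = shard_logps(1, 2, true_logps)  # holds indices 1, 3, 5

# all_reduce with MAX is an elementwise max across ranks; every real
# log-probability is greater than -inf, so the computed value wins per index.
merged = [max(a, b) for a, b in zip(rank0, rank1)]
```

The subsequent validation step is simply a check that `-inf` no longer appears anywhere in the merged tensors.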
Usage
Use reference logprob caching when:
- Training with any DPO loss variant that requires a reference model (standard DPO, DPO-norm, WPO).
- The reference model is frozen and not updated during training.
- Training data is static (not changing between epochs).
- You want to reduce GPU memory and compute by avoiding a second model in the training loop.
Reference logprob caching is not needed for SimPO, which does not use a reference model.
Theoretical Basis
The DPO loss is defined as:

$$
\mathcal{L}_{\text{DPO}}(\pi_\theta; \pi_{\text{ref}}) = -\,\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\left[ \log \sigma\!\left( \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)} \right) \right]
$$

The terms $\log \pi_{\text{ref}}(y_w \mid x)$ and $\log \pi_{\text{ref}}(y_l \mid x)$ are constants with respect to $\theta$. The caching strategy precomputes these values:

$$
c_i = \log \pi_{\text{ref}}\big(y_w^{(i)} \mid x^{(i)}\big), \qquad r_i = \log \pi_{\text{ref}}\big(y_l^{(i)} \mid x^{(i)}\big)
$$

for each training example $i$. During training, these cached values are retrieved by index, reducing the per-step cost from two model forward passes to one.
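Concretely, a training step then needs only one policy forward pass plus two cached scalars per example. A minimal numeric sketch of the per-example loss (the $\beta$ value and all log-probabilities below are made up for illustration):

```python
import math

def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """Per-example DPO loss; ref_* come from the cache, not a forward pass."""
    logits = beta * ((policy_chosen - ref_chosen) - (policy_rejected - ref_rejected))
    return -math.log(1.0 / (1.0 + math.exp(-logits)))  # -log(sigmoid(logits))

# Retrieved from the pre-computed cache by sample index:
ref_c, ref_r = -12.0, -15.0
# Produced by the policy model's forward pass this step:
pol_c, pol_r = -11.0, -16.0

loss = dpo_loss(pol_c, pol_r, ref_c, ref_r)  # logits = 0.1 * (1 - (-1)) = 0.2
```

Because the policy assigns relatively more probability to the chosen response than the reference does (and less to the rejected one), the logits are positive and the loss sits below $\log 2$.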
Pseudocode:

```
# Pre-training cache construction
if cache file exists at cache_path:
    load it and return                  # skip computation entirely
chosen_logps   = tensor of shape [N], initialized to -inf
rejected_logps = tensor of shape [N], initialized to -inf
for batch in dataloader:
    with no_grad():
        c_logps, r_logps = reference_model.forward(batch)
    chosen_logps[batch.indices]   = c_logps
    rejected_logps[batch.indices] = r_logps
all_reduce(chosen_logps, op=MAX)        # merge shards across distributed processes
all_reduce(rejected_logps, op=MAX)
assert no entries remain at -inf        # full dataset coverage
if main process:
    save_to_disk(chosen_logps, rejected_logps)
barrier()                               # all processes wait for the file
```
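The single-process core of the pseudocode above can be made runnable. This sketch uses JSON persistence and a stand-in `ref_forward` callable only to stay dependency-free; a real implementation would use tensors and `torch.save`, and the dataloader/batch shapes here are illustrative assumptions:

```python
import json
import math
import os
import tempfile

def build_or_load_cache(path, dataloader, ref_forward):
    # Initialization: reuse an existing cache file if present.
    if os.path.exists(path):
        with open(path) as f:
            return json.load(f)
    n = sum(len(batch["indices"]) for batch in dataloader)
    chosen = [-math.inf] * n
    rejected = [-math.inf] * n
    # Computation: one pass over the data with the frozen reference model.
    for batch in dataloader:
        c_logps, r_logps = ref_forward(batch)
        for i, c, r in zip(batch["indices"], c_logps, r_logps):
            chosen[i], rejected[i] = c, r
    # Validation: every example must have been covered.
    assert -math.inf not in chosen and -math.inf not in rejected
    cache = {"chosen": chosen, "rejected": rejected}
    # Persistence: write once; later runs take the load path above.
    with open(path, "w") as f:
        json.dump(cache, f)
    return cache

# Usage with a toy two-batch dataloader and a fake reference model:
loader = [{"indices": [0, 1]}, {"indices": [2]}]
fake_ref = lambda b: ([-1.0 * (i + 1) for i in b["indices"]],
                      [-2.0 * (i + 1) for i in b["indices"]])
path = os.path.join(tempfile.mkdtemp(), "ref_logps.json")
cache = build_or_load_cache(path, loader, fake_ref)    # computes and saves
cache2 = build_or_load_cache(path, loader, fake_ref)   # hits the load path
```

The second call never touches `ref_forward`, which is exactly the property the cache provides across restarted training runs.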