Principle:Allenai Open instruct Reference Logprob Caching

From Leeroopedia


Knowledge Sources
Domains Machine Learning, Optimization, Distributed Computing
Last Updated 2026-02-07 00:00 GMT

Overview

Reference logprob caching is the technique of computing the reference model's log-probabilities for all training examples once before training begins, persisting them to disk, and reusing the cached values throughout DPO training to eliminate redundant forward passes through the reference model.

Description

The DPO loss function requires log-probabilities from two models: the policy model $\pi_\theta$ (which is being trained) and the reference model $\pi_{\text{ref}}$ (which remains frozen). Naively, every training step would run a forward pass through both models for each batch, doubling the number of forward passes per step and keeping both models resident in GPU memory.

Reference logprob caching exploits the fact that $\pi_{\text{ref}}$ is frozen throughout training: since the reference model never changes, its log-probabilities for each training example are constant. By computing these values once in a dedicated pre-training pass and caching them, the training loop only needs to run forward passes through the policy model.
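Because the reference model is frozen (and run in eval mode), the same example always yields the same log-probabilities, so they can be memoized. The toy sketch below counts reference calls under both strategies; CountingReference and toy_ref are hypothetical stand-ins for a real model, and a Python dict stands in for the on-disk cache.

```python
from typing import Callable, Dict, Tuple

LogpPair = Tuple[float, float]  # (chosen_logp, rejected_logp)


class CountingReference:
    """Hypothetical frozen reference model that counts its forward calls."""

    def __init__(self, fn: Callable[[int], LogpPair]) -> None:
        self.fn = fn
        self.calls = 0

    def __call__(self, idx: int) -> LogpPair:
        self.calls += 1
        return self.fn(idx)


def toy_ref(idx: int) -> LogpPair:
    # Same input -> same output, as for a frozen model in eval mode.
    return (-1.0 - 0.1 * idx, -2.0 - 0.1 * idx)


def naive_training(ref: CountingReference, n: int, epochs: int) -> None:
    # Re-run the reference every time an example is visited.
    for _ in range(epochs):
        for i in range(n):
            ref(i)


def cached_training(ref: CountingReference, n: int, epochs: int) -> Dict[int, LogpPair]:
    # One dedicated pass up front, then index lookups during training.
    cache = {i: ref(i) for i in range(n)}
    for _ in range(epochs):
        for i in range(n):
            _ = cache[i]
    return cache


naive = CountingReference(toy_ref)
naive_training(naive, n=100, epochs=3)
cached = CountingReference(toy_ref)
cache = cached_training(cached, n=100, epochs=3)
print(naive.calls, cached.calls)  # 300 100
```

With a single epoch both strategies call the reference the same number of times; the compute saving grows with each additional epoch, while the memory saving (no reference model inside the training loop) applies from the first step.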

The caching process works as follows:

  1. Initialization: Check whether a cache file already exists at the configured path. If so, load it directly and skip computation.
  2. Computation: Iterate over the full training dataloader with the reference model in eval mode and torch.no_grad(). For each batch, compute chosen and rejected log-probabilities using the configured forward function.
  3. Aggregation: In distributed settings, each process computes log-probabilities only for its own data shard. After the pass, an all_reduce with the MAX operation merges the shards; because uncomputed entries are initialized to -inf, MAX always selects the computed value.
  4. Validation: Verify that no entries remain at -inf, confirming that the whole dataset was covered.
  5. Persistence: The main process saves the cache to disk. A distributed barrier ensures all processes wait until the file is written.
  6. Reuse: During training, the cache is indexed by batch sample indices to retrieve precomputed reference log-probabilities with negligible overhead.
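The aggregation and validation steps (3 and 4) are the least obvious part of the process. The pure-Python simulation below emulates an all_reduce with MAX over three ranks; the round-robin sharding, the toy values, and the helper names are assumptions for illustration, not how the real dataloader shards data.

```python
from typing import List

NEG_INF = float("-inf")


def local_shard_logps(num_examples: int, rank: int, world_size: int) -> List[float]:
    # Each rank fills only the indices it processed; the rest stay at -inf.
    # Hypothetical round-robin sharding and toy values stand in for a real
    # sampler and reference model.
    logps = [NEG_INF] * num_examples
    for i in range(rank, num_examples, world_size):
        logps[i] = -0.5 * (i + 1)
    return logps


def all_reduce_max(per_rank: List[List[float]]) -> List[float]:
    # Emulates all_reduce(op=MAX): element-wise maximum across ranks.
    # Any finite log-prob is greater than -inf, so each slot takes the
    # value from the rank that actually computed it.
    return [max(values) for values in zip(*per_rank)]


def missing_indices(merged: List[float]) -> List[int]:
    # Coverage check: a surviving -inf means no rank computed that example.
    return [i for i, v in enumerate(merged) if v == NEG_INF]


shards = [local_shard_logps(10, rank, world_size=3) for rank in range(3)]
merged = all_reduce_max(shards)
print(missing_indices(merged))                      # []
print(missing_indices(all_reduce_max(shards[:2])))  # [2, 5, 8]
```

MAX is also idempotent, so an example processed by two ranks (for instance, when a sampler pads the last shard by repeating examples) still merges to a single value, whereas a SUM-based merge would double-count it.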

Usage

Use reference logprob caching when:

  • Training with any DPO loss variant that requires a reference model (standard DPO, DPO-norm, WPO).
  • The reference model is frozen and not updated during training.
  • Training data is static (the same examples, with the same indices, in every epoch), since cached values are looked up by example index.
  • You want to reduce GPU memory and compute by avoiding a second model in the training loop.

Reference logprob caching is not needed for SimPO, which does not use a reference model.

Theoretical Basis

The DPO loss is defined as:

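$$\mathcal{L}_{\text{DPO}}(\pi_\theta; \pi_{\text{ref}}) = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\left[\log \sigma\left(\beta\left(\log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}\right)\right)\right]$$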

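The terms $\log \pi_{\text{ref}}(y_w \mid x)$ and $\log \pi_{\text{ref}}(y_l \mid x)$ are constants with respect to $\theta$. The caching strategy precomputes these values:

$$\text{cache}[i] = \left(\log \pi_{\text{ref}}(y_w^{(i)} \mid x^{(i)}),\ \log \pi_{\text{ref}}(y_l^{(i)} \mid x^{(i)})\right)$$

for each training example $i$. During training, these cached values are retrieved by index, reducing the per-step cost from two model forward passes (policy and reference) to one (policy only).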
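Writing $(c_w^{(i)}, c_l^{(i)})$ for the two cached reference log-probabilities of example $i$, the argument of $\log\sigma$ becomes $\beta\left[\left(\log \pi_\theta(y_w^{(i)} \mid x^{(i)}) - c_w^{(i)}\right) - \left(\log \pi_\theta(y_l^{(i)} \mid x^{(i)}) - c_l^{(i)}\right)\right]$, so only the policy terms are computed live. The plain-Python sketch below computes the batch-averaged loss this way; the function names, the dict-based cache, and beta = 0.1 are illustrative assumptions, and a real trainer would operate on tensors.

```python
import math
from typing import Dict, Sequence, Tuple


def log_sigmoid(z: float) -> float:
    # Numerically stable log(sigmoid(z)).
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def dpo_loss_from_cache(
    policy_chosen: Sequence[float],
    policy_rejected: Sequence[float],
    indices: Sequence[int],
    cache: Dict[int, Tuple[float, float]],
    beta: float = 0.1,  # illustrative value, not a documented default
) -> float:
    # Average DPO loss over a batch; reference terms come from the cache,
    # looked up by each example's dataset index.
    total = 0.0
    for pc, pr, i in zip(policy_chosen, policy_rejected, indices):
        ref_c, ref_r = cache[i]
        margin = beta * ((pc - ref_c) - (pr - ref_r))
        total += -log_sigmoid(margin)
    return total / len(indices)


cache = {0: (-10.0, -12.0), 7: (-5.0, -6.0)}
# Policy identical to the reference: every margin is zero, so the loss is log(2).
same = dpo_loss_from_cache([-10.0, -5.0], [-12.0, -6.0], [0, 7], cache)
# Chosen log-probs raised and rejected ones lowered relative to the cache.
better = dpo_loss_from_cache([-9.0, -4.0], [-13.0, -7.0], [0, 7], cache)
print(same, better)
```

When the policy still equals the reference, the loss is $\log 2 \approx 0.693$; moving probability mass toward chosen responses and away from rejected ones, relative to the cached reference values, pushes it below that.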

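Pseudocode (PyTorch-flavored; cache_path, N, device, dataloader, forward_fn, and batch.indices are placeholders):

# Pre-training cache construction
if os.path.exists(cache_path):                        # step 1: reuse an existing cache
    cache = torch.load(cache_path)
else:
    reference_model.eval()
    chosen_logps = torch.full((N,), float("-inf"), device=device)
    rejected_logps = torch.full((N,), float("-inf"), device=device)

    for batch in dataloader:                          # each process sees only its shard
        with torch.no_grad():
            c_logps, r_logps = forward_fn(reference_model, batch)
        chosen_logps[batch.indices] = c_logps
        rejected_logps[batch.indices] = r_logps

    dist.all_reduce(chosen_logps, op=dist.ReduceOp.MAX)    # merge across processes
    dist.all_reduce(rejected_logps, op=dist.ReduceOp.MAX)
    assert not torch.isinf(chosen_logps).any()             # full dataset coverage
    assert not torch.isinf(rejected_logps).any()

    cache = {"chosen_logps": chosen_logps, "rejected_logps": rejected_logps}
    if dist.get_rank() == 0:                          # main process writes the file
        torch.save(cache, cache_path)
    dist.barrier()                                    # all processes wait for the write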
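A runnable, single-process version of the pseudocode above, under stated assumptions: batches arrive as (indices, examples) pairs, forward_fn is a hypothetical callable returning per-example chosen and rejected log-probabilities, and the cache is stored as JSON rather than a torch checkpoint so the sketch needs only the standard library. The distributed merge and barrier are omitted because only one process is involved.

```python
import json
import os
import tempfile
from typing import Callable, Iterable, List, Sequence, Tuple

NEG_INF = float("-inf")

# Assumed shapes for this sketch: a batch is (dataset_indices, examples), and the
# forward function returns per-example (chosen_logps, rejected_logps).
Batch = Tuple[Sequence[int], Sequence[float]]
ForwardFn = Callable[[Sequence[float]], Tuple[List[float], List[float]]]


def build_or_load_reference_cache(
    path: str, num_examples: int, batches: Iterable[Batch], forward_fn: ForwardFn
) -> Tuple[List[float], List[float]]:
    # Step 1: if a cache file already exists, load it and skip computation.
    if os.path.exists(path):
        with open(path) as f:
            data = json.load(f)
        return data["chosen"], data["rejected"]

    # Step 2: one pass over the data; slots start at -inf.
    chosen = [NEG_INF] * num_examples
    rejected = [NEG_INF] * num_examples
    for indices, examples in batches:
        c_logps, r_logps = forward_fn(examples)
        for i, c, r in zip(indices, c_logps, r_logps):
            chosen[i] = c
            rejected[i] = r

    # Step 3 (all_reduce with MAX) would go here in a multi-process run.
    # Step 4: every example must be covered before the cache is trusted.
    if NEG_INF in chosen or NEG_INF in rejected:
        raise RuntimeError("reference logprob cache is incomplete")

    # Step 5: persist; a single process needs no barrier.
    with open(path, "w") as f:
        json.dump({"chosen": chosen, "rejected": rejected}, f)
    return chosen, rejected


calls = 0


def fake_forward(examples: Sequence[float]) -> Tuple[List[float], List[float]]:
    # Hypothetical stand-in for the reference model's forward function.
    global calls
    calls += 1
    return [-1.0 - x for x in examples], [-2.0 - x for x in examples]


batches = [([0, 1], [0.0, 1.0]), ([2, 3], [2.0, 3.0])]
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "ref_logps.json")
    first = build_or_load_reference_cache(path, 4, batches, fake_forward)
    second = build_or_load_reference_cache(path, 4, batches, fake_forward)

print(calls)            # 2: both forward calls happened during the first build
print(first == second)  # True
```

The second call finds the file and returns the stored values without invoking fake_forward, which is the reuse behavior described in step 1.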

Related Pages

Implemented By
