
Implementation:Allenai Open instruct Build Reference Logprobs Cache

From Leeroopedia


Component Type: Function
Source open_instruct/dpo_utils.py (Lines 490-606)
Repository Open Instruct
Dependencies torch, torch.distributed, tqdm, open_instruct.model_utils, open_instruct.utils
Last Updated 2026-02-07 00:00 GMT

Overview

Concrete tool for precomputing and caching reference model log-probabilities for all training examples in a DPO pipeline, provided by the Open Instruct library.

Description

build_reference_logprobs_cache() computes log-probabilities from a frozen reference model for every example in the training set, stores them in a TensorCache, and persists the cache to disk. If a cache file already exists at the given path, it loads from disk instead of recomputing.

Key implementation details:

  • Distributed aggregation: Initializes the result tensors to a sentinel minimum value and uses dist.all_reduce(op=MAX) to merge partial results across processes, since each rank computes log-probabilities for only its shard of the data.
  • Validation: After aggregation, verifies that no indices remain at the sentinel value, raising a RuntimeError if any are missing.
  • LoRA support: When use_lora=True, the function disables the LoRA adapter via the provided context manager so that the base (reference) model is used.
  • Atomic writes: Saves the cache using TensorCache.to_disk(), which writes to a temporary file and renames atomically to prevent corruption.
  • Memory reporting: Logs GPU memory usage of the cache as both GiB and percentage of total device memory.
  • MFU tracking: Displays model FLOPs utilization during the caching pass via a progress bar.
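The aggregation and validation steps above can be illustrated in a single process. Assuming a -inf sentinel (any value below all real log-probabilities works with a MAX reduction), each rank holds a tensor filled at only the indices it processed, and an element-wise max merges the shards; merge_partial_logprobs is a hypothetical helper, not part of Open Instruct:

```python
import torch


def merge_partial_logprobs(partials: list[torch.Tensor]) -> torch.Tensor:
    """Simulate dist.all_reduce(op=MAX) across ranks: each tensor is -inf
    except at the indices that rank computed; an element-wise max recovers
    the full result."""
    merged = partials[0].clone()
    for t in partials[1:]:
        merged = torch.maximum(merged, t)  # stands in for all_reduce(op=MAX)
    if torch.isinf(merged).any():
        # Validation step: any index still at the sentinel was never computed.
        raise RuntimeError("some dataset indices were never computed")
    return merged
```

Because log-probabilities are always finite and the sentinel is strictly smaller, the max reduction is a safe way to combine disjoint shards without explicit index bookkeeping.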

Usage

Import and call build_reference_logprobs_cache() before the DPO training loop to precompute reference log-probabilities. The returned TensorCache is indexed by sample index during training.

Code Reference

Source Location

  • Repository: Open Instruct
  • File: open_instruct/dpo_utils.py (Lines 490-606)

Signature

def build_reference_logprobs_cache(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    average_log_prob: bool,
    forward_fn: Callable,
    full_dataset_size: int,
    device: torch.device,
    cache_path: pathlib.Path,
    is_main_process: bool,
    model_dims: utils.ModelDims,
    use_lora: bool = False,
    disable_adapter_context: Callable[[], contextlib.AbstractContextManager] | None = None,
) -> model_utils.TensorCache:

Import

from open_instruct.dpo_utils import build_reference_logprobs_cache

I/O Contract

Inputs

Parameter Type Description
model torch.nn.Module The reference model (or policy model with LoRA adapter to be disabled) used to compute log-probabilities.
dataloader torch.utils.data.DataLoader DataLoader providing batches with an index key mapping each sample to its position in the full dataset.
average_log_prob bool Whether to average log-probabilities over sequence length (True for DPO-norm, SimPO) or sum them.
forward_fn Callable Forward function computing (chosen_logps, rejected_logps, aux_loss) from a model and batch.
full_dataset_size int Total number of examples in the training dataset (used to allocate result tensors).
device torch.device Device to place the cache tensors on.
cache_path pathlib.Path File path where the cache is saved/loaded.
is_main_process bool Whether this is the main process (rank 0); only the main process writes the cache to disk.
model_dims utils.ModelDims Model dimension info used for MFU calculation in the progress bar.
use_lora bool Whether LoRA is enabled. If True, the adapter is disabled during reference logprob computation.
disable_adapter_context Callable or None Callable returning a context manager that disables the LoRA adapter. Required when use_lora=True.

Outputs

Output Type Description
Return value model_utils.TensorCache A TensorCache containing two tensors: 'chosen_logps' and 'rejected_logps', each of shape (full_dataset_size,).

Usage Examples

from open_instruct.dpo_utils import build_reference_logprobs_cache

reference_cache = build_reference_logprobs_cache(
    model=model,
    dataloader=train_dataloader,
    average_log_prob=args.loss_type.is_average_loss,
    forward_fn=args.forward_fn,
    full_dataset_size=len(train_dataset),
    device=accelerator.device,
    cache_path=pathlib.Path("/tmp/ref_cache.pt"),
    is_main_process=accelerator.is_main_process,
    model_dims=model_dims,
    use_lora=args.use_lora,
    disable_adapter_context=model.disable_adapter,
)

# During training, retrieve cached logprobs for a batch:
ref_logps = reference_cache[batch["index"]]
# ref_logps["chosen_logps"] -> tensor of shape (batch_size,)
# ref_logps["rejected_logps"] -> tensor of shape (batch_size,)

Related Pages

Implements Principle
