Implementation: Allenai Open Instruct Build Reference Logprobs Cache
| Component Type | Function |
|---|---|
| Source | open_instruct/dpo_utils.py (Lines 490-606) |
| Repository | Open Instruct |
| Dependencies | torch, torch.distributed, tqdm, open_instruct.model_utils, open_instruct.utils |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
Concrete tool for precomputing and caching reference model log-probabilities for all training examples in a DPO pipeline, provided by the Open Instruct library.
Description
build_reference_logprobs_cache() computes log-probabilities from a frozen reference model for every example in the training set, stores them in a TensorCache, and persists the cache to disk. If a cache file already exists at the given path, it loads from disk instead of recomputing.
Key implementation details:
- Distributed aggregation: Initializes result tensors to a sentinel value and uses `dist.all_reduce(op=MAX)` to merge partial results across processes.
- Validation: After aggregation, verifies that no indices remain at the sentinel value, raising a `RuntimeError` if any are missing.
- LoRA support: When `use_lora=True`, the function disables the LoRA adapter via the provided context manager so that the base (reference) model is used.
- Atomic writes: Saves the cache using `TensorCache.to_disk()`, which writes to a temporary file and renames it atomically to prevent corruption.
- Memory reporting: Logs GPU memory usage of the cache as both GiB and a percentage of total device memory.
- MFU tracking: Displays model FLOPs utilization during the caching pass via a progress bar.
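The merge-and-validate pattern described above can be sketched with plain tensors. This is a hypothetical illustration, not the library's code: each rank fills only the indices it processed, leaving the rest at a `-inf` sentinel (the actual sentinel value in the library is an assumption here), so an elementwise MAX merge recovers the full result. Real code would call `torch.distributed.all_reduce(t, op=dist.ReduceOp.MAX)`; for a runnable single-process example we emulate the reduce with `torch.maximum`.

```python
import torch

SENTINEL = float("-inf")  # assumed sentinel; loses to any real logprob under MAX
full_size = 6

# Per-rank partial results: rank 0 handled indices 0..2, rank 1 handled 3..5.
rank0 = torch.full((full_size,), SENTINEL)
rank0[:3] = torch.tensor([-1.2, -0.7, -3.4])
rank1 = torch.full((full_size,), SENTINEL)
rank1[3:] = torch.tensor([-0.9, -2.1, -0.5])

# Emulated all_reduce(op=MAX) across the two "ranks".
merged = torch.maximum(rank0, rank1)

# Validation: any index still at the sentinel was computed by no rank.
if merged.isinf().any():
    raise RuntimeError("missing reference logprobs for some indices")
```

Log-probabilities are always finite and at most 0, so a `-inf` sentinel can never win the MAX reduction against a real value.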
Usage
Import and call build_reference_logprobs_cache() before the DPO training loop to precompute reference log-probabilities. The returned TensorCache is indexed by sample index during training.
Code Reference
Source Location
- Repository: Open Instruct
- File: `open_instruct/dpo_utils.py` (Lines 490-606)
Signature
```python
def build_reference_logprobs_cache(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    average_log_prob: bool,
    forward_fn: Callable,
    full_dataset_size: int,
    device: torch.device,
    cache_path: pathlib.Path,
    is_main_process: bool,
    model_dims: utils.ModelDims,
    use_lora: bool = False,
    disable_adapter_context: Callable[[], contextlib.AbstractContextManager] | None = None,
) -> model_utils.TensorCache:
```
Import
```python
from open_instruct.dpo_utils import build_reference_logprobs_cache
```
I/O Contract
Inputs
| Parameter | Type | Description |
|---|---|---|
| `model` | `torch.nn.Module` | The reference model (or policy model with LoRA adapter to be disabled) used to compute log-probabilities. |
| `dataloader` | `torch.utils.data.DataLoader` | DataLoader providing batches with an `index` key mapping each sample to its position in the full dataset. |
| `average_log_prob` | `bool` | Whether to average log-probabilities over sequence length (True for DPO-norm, SimPO) or sum them. |
| `forward_fn` | `Callable` | Forward function computing `(chosen_logps, rejected_logps, aux_loss)` from a model and batch. |
| `full_dataset_size` | `int` | Total number of examples in the training dataset (used to allocate result tensors). |
| `device` | `torch.device` | Device to place the cache tensors on. |
| `cache_path` | `pathlib.Path` | File path where the cache is saved/loaded. |
| `is_main_process` | `bool` | Whether this is the main process (rank 0); only the main process writes the cache to disk. |
| `model_dims` | `utils.ModelDims` | Model dimension info used for MFU calculation in the progress bar. |
| `use_lora` | `bool` | Whether LoRA is enabled. If True, the adapter is disabled during reference logprob computation. |
| `disable_adapter_context` | `Callable` or `None` | Callable returning a context manager that disables the LoRA adapter. Required when `use_lora=True`. |
Outputs
| Output | Type | Description |
|---|---|---|
| Return value | `model_utils.TensorCache` | A TensorCache containing two tensors, `'chosen_logps'` and `'rejected_logps'`, each of shape `(full_dataset_size,)`. |
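The shape contract of the returned cache can be illustrated with a plain dict of tensors standing in for `TensorCache` (a hypothetical stand-in, not the actual class): two tensors of shape `(full_dataset_size,)`, indexed by a batch of dataset indices to yield per-example reference logprobs.

```python
import torch

full_dataset_size = 8

# Hypothetical stand-in for the returned TensorCache: two tensors keyed
# 'chosen_logps' and 'rejected_logps', each of shape (full_dataset_size,).
cache = {
    "chosen_logps": torch.randn(full_dataset_size),
    "rejected_logps": torch.randn(full_dataset_size),
}

# Indexing by a batch's dataset indices yields tensors of shape (batch_size,).
batch_indices = torch.tensor([0, 3, 5])
chosen = cache["chosen_logps"][batch_indices]
rejected = cache["rejected_logps"][batch_indices]
```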
Usage Examples
```python
import pathlib

from open_instruct.dpo_utils import build_reference_logprobs_cache

reference_cache = build_reference_logprobs_cache(
    model=model,
    dataloader=train_dataloader,
    average_log_prob=args.loss_type.is_average_loss,
    forward_fn=args.forward_fn,
    full_dataset_size=len(train_dataset),
    device=accelerator.device,
    cache_path=pathlib.Path("/tmp/ref_cache.pt"),
    is_main_process=accelerator.is_main_process,
    model_dims=model_dims,
    use_lora=args.use_lora,
    disable_adapter_context=model.disable_adapter,
)

# During training, retrieve cached logprobs for a batch:
ref_logps = reference_cache[batch["index"]]
# ref_logps["chosen_logps"] -> tensor of shape (batch_size,)
# ref_logps["rejected_logps"] -> tensor of shape (batch_size,)
```
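Two behaviors described above, loading from disk when the cache file already exists and writing atomically via a temporary file plus rename, can be sketched together. This is a hypothetical illustration, not the library's implementation: `load_or_build` and `compute_fn` are made-up names, and the temp-file-then-`os.replace` pattern is the standard analog of what the text attributes to `TensorCache.to_disk()`.

```python
import os
import pathlib
import tempfile

import torch

def load_or_build(cache_path: pathlib.Path, compute_fn):
    # Load from disk if a cache file already exists at the given path.
    if cache_path.exists():
        return torch.load(cache_path)
    cache = compute_fn()
    # Atomic write: serialize to a temp file in the same directory,
    # then rename into place so readers never see a partial file.
    fd, tmp = tempfile.mkstemp(dir=cache_path.parent, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            torch.save(cache, f)
        os.replace(tmp, cache_path)  # atomic within one filesystem
    except BaseException:
        os.unlink(tmp)
        raise
    return cache

path = pathlib.Path(tempfile.gettempdir()) / "ref_cache_demo.pt"
path.unlink(missing_ok=True)

calls = []
def compute():  # stands in for the full reference-logprob pass
    calls.append(1)
    return {"chosen_logps": torch.arange(3.0), "rejected_logps": -torch.arange(3.0)}

first = load_or_build(path, compute)   # computes and saves atomically
second = load_or_build(path, compute)  # loads from disk; compute not rerun
```

`os.replace` is atomic on POSIX filesystems, which is why the rename-over pattern prevents a crashed writer from leaving a corrupt cache behind.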