

Heuristic:ContextualAI HALOs Reference Logprob Caching





Knowledge Sources
Domains: Optimization, Memory_Management, LLM_Alignment
Last Updated: 2026-02-08 03:00 GMT

Overview

Memory optimization that precomputes reference model log probabilities and deletes the reference model from GPU, substantially reducing VRAM usage during alignment training.

Description

Before training begins, the `ReferenceModelWrapper` makes a single pass with the reference model to precompute token-level log probabilities for every training and evaluation example. The results are stored in a Python dictionary keyed by the tuple of each example's unpadded token IDs. Once caching is complete, the reference model and its dedicated Accelerator are deleted, which frees the GPU memory they occupied. During training, the wrapper answers each request with a dictionary lookup instead of a forward pass. The cached logprobs can also be serialized to a pickle file and reloaded by other jobs that share the same reference model.
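The following is a minimal, dependency-free sketch of the precompute-then-lookup idea. It assumes a hypothetical `toy_reference_logprobs` function in place of the reference model's forward pass; the real wrapper runs an actual model under its own Accelerator. The sketch computes per-token log probabilities once per example, stores them in a dict keyed by the token-ID tuple, and serves later requests by lookup.

```python
import math
from typing import Dict, List, Sequence, Tuple

VOCAB_SIZE = 8  # hypothetical toy vocabulary size


def toy_reference_logprobs(tokens: Sequence[int]) -> List[float]:
    """Hypothetical stand-in for a frozen reference model's forward pass.

    Returns log p(token_i | token_{i-1}) for i >= 1 under a fixed bigram
    distribution, so repeated calls on the same input return identical values.
    """
    logprobs = []
    for prev, cur in zip(tokens[:-1], tokens[1:]):
        # Deterministic logits that depend only on the previous token.
        logits = [((prev * 3 + v) % VOCAB_SIZE) / 2.0 for v in range(VOCAB_SIZE)]
        log_z = math.log(sum(math.exp(x) for x in logits))  # log-sum-exp normalizer
        logprobs.append(logits[cur] - log_z)  # log-softmax at the observed token
    return logprobs


def precompute(examples: Sequence[Sequence[int]]) -> Dict[Tuple[int, ...], List[float]]:
    """Single pass over all examples before training; keys are token-ID tuples."""
    cache: Dict[Tuple[int, ...], List[float]] = {}
    for tokens in examples:
        key = tuple(tokens)  # tuples are hashable, lists are not
        if key not in cache:  # identical sequences share one entry
            cache[key] = toy_reference_logprobs(tokens)
    return cache


examples = [[1, 4, 2, 7], [3, 3, 5], [1, 4, 2, 7]]
cache = precompute(examples)

# During training, a dictionary lookup replaces the reference forward pass.
print(len(cache))  # 2
print(cache[(3, 3, 5)] == toy_reference_logprobs([3, 3, 5]))  # True
```

Keying by the exact token sequence means a lookup succeeds only for inputs that were seen during the caching pass. That fits a fixed training and evaluation set.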

Usage

Use this heuristic when you are VRAM-constrained and the reference model does not need to change during training (i.e., `humanline=false` and `sync_reference=false`). Enable it by setting `++cache_reference_logprobs=true`. To reuse cached logprobs across jobs, set `++load_reference_logprobs=PATH` to point at a previously saved pickle file. Do not combine it with humanline alignment or `sync_reference`. Both options update the reference model during training, which would make the cached values stale.
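The compatibility rule above can be expressed as a small pre-launch check. This is a hypothetical helper and is not part of HALOs. It assumes a config object with boolean `cache_reference_logprobs`, `humanline`, and `sync_reference` fields, named after the options on this page.

```python
from dataclasses import dataclass


@dataclass
class RefConfig:
    """Hypothetical mirror of the launch options that matter here."""
    cache_reference_logprobs: bool = False
    humanline: bool = False
    sync_reference: bool = False


def use_cached_reference(cfg: RefConfig) -> bool:
    """Return True if cached logprobs may replace the reference model.

    Raises ValueError if caching is requested alongside an option that
    updates the reference model during training, since the cache would go stale.
    """
    if not cfg.cache_reference_logprobs:
        return False
    if cfg.humanline or cfg.sync_reference:
        raise ValueError(
            "cache_reference_logprobs=true is incompatible with humanline "
            "or sync_reference: the reference model changes during training."
        )
    return True


print(use_cached_reference(RefConfig(cache_reference_logprobs=True)))  # True
```

Failing fast at launch is cheaper than discovering, after the caching pass, that the run needed a live reference model.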

The Insight (Rule of Thumb)

  • Action: Set `++cache_reference_logprobs=true` in the launch command.
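  • Value: Frees all GPU memory held by the reference model after the caching pass. If the policy and reference have the same size and precision, this is one full copy of the model weights, or about half of the memory used by weights alone. Optimizer states and activations are not affected.
  • Trade-off: Adds upfront compute time for the caching pass before training begins. The cache lives in CPU RAM as a Python dict whose size grows with the total number of tokens in the training and evaluation sets, so you need enough CPU memory to hold it.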
  • Reusability: Cached logprobs are saved to `cached_reference_logprobs.pkl` in the run directory. Reuse with `++load_reference_logprobs=<path>` for subsequent jobs using the same reference model and data.
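Here is a rough sizing sketch for the Value and Trade-off bullets. All the numbers are illustrative assumptions: a 7B-parameter reference model in bf16 (2 bytes per parameter), and 20,000 examples averaging 512 tokens. The CPU figure is only a lower bound. It counts one 4-byte float32 value per token and ignores the token-tuple keys and Python object overhead, which in practice add considerably more.

```python
def reference_weights_gib(num_params: float, bytes_per_param: int = 2) -> float:
    """GPU memory held by the reference model's weights alone (bf16 by default)."""
    return num_params * bytes_per_param / 2**30


def cache_lower_bound_gib(num_examples: int, avg_tokens: int, bytes_per_value: int = 4) -> float:
    """Lower bound on CPU RAM for cached logprobs: one float32 per token.

    Ignores dict keys (the token-ID tuples), per-object overhead, and containers.
    """
    return num_examples * avg_tokens * bytes_per_value / 2**30


print(f"VRAM freed:   ~{reference_weights_gib(7e9):.1f} GiB")  # ~13.0 GiB
print(f"CPU RAM used: >={cache_lower_bound_gib(20_000, 512):.3f} GiB")  # >=0.038 GiB
```

With sharded training, the freed weight memory is spread across GPUs, so the per-device saving is correspondingly smaller.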

Reasoning

In standard alignment training (DPO, KTO, etc.), the reference model runs a forward pass on every batch to compute log probabilities, which means a full copy of the model must stay in GPU memory. When humanline and sync_reference are not used, the reference model is frozen, so its outputs for a given example never change and can be computed once. The `ReferenceModelWrapper` exploits this by caching all logprobs upfront and then deleting the model. This is especially valuable when you run several alignment experiments that use the same SFT checkpoint as the reference, because the pickle file can be shared across runs.
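Because a frozen reference gives the same outputs for the same inputs, the cache can be written once and reused. Below is a minimal load-or-compute sketch. It assumes a hypothetical `compute_fn` in place of the reference forward pass. The `cached_reference_logprobs.pkl` filename matches the one given on this page; the helper names and temporary directories are illustrative only.

```python
import os
import pickle
import tempfile
from typing import Callable, Dict, List, Optional, Sequence, Tuple

Cache = Dict[Tuple[int, ...], List[float]]


def load_or_compute(
    examples: Sequence[Sequence[int]],
    compute_fn: Callable[[Sequence[int]], List[float]],
    run_dir: str,
    load_path: Optional[str] = None,
) -> Cache:
    """Reuse a pickled cache if given; otherwise compute it and save it in run_dir."""
    if load_path is not None:
        # Only unpickle files you trust: pickle.load can execute arbitrary code.
        with open(load_path, "rb") as f:
            return pickle.load(f)
    cache: Cache = {tuple(t): compute_fn(t) for t in examples}
    with open(os.path.join(run_dir, "cached_reference_logprobs.pkl"), "wb") as f:
        pickle.dump(cache, f)
    return cache


calls: List[Tuple[int, ...]] = []


def fake_reference(tokens: Sequence[int]) -> List[float]:
    """Hypothetical reference model that records every 'forward pass'."""
    calls.append(tuple(tokens))
    return [-0.5 * len(tokens)] * (len(tokens) - 1)


data = [[1, 2, 3], [4, 5]]
with tempfile.TemporaryDirectory() as run_a, tempfile.TemporaryDirectory() as run_b:
    first = load_or_compute(data, fake_reference, run_a)  # job 1: computes and saves
    saved = os.path.join(run_a, "cached_reference_logprobs.pkl")
    second = load_or_compute(data, fake_reference, run_b, saved)  # job 2: reuses
    print(first == second, len(calls))  # True 2
```

Suppose the pickle was built from different data or a different tokenizer. The keys are exact token sequences, so a lookup for an unseen sequence raises `KeyError` rather than silently returning the wrong values.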

Code Evidence

ReferenceModelWrapper precompute and free pattern from `train/models.py:425-536`:

import pickle

import torch.nn as nn


class ReferenceModelWrapper(nn.Module):
    """
    A wrapper around the reference model that precomputes the logprobs and saves them
    in a local dict, after which the reference model and accelerator are deleted to
    save GPU memory.
    """
    def __init__(self, reference_accelerator, reference_model, tokenizer, config, iterators):
        super().__init__()
        # ... setup ...
        if config.load_reference_logprobs:
            # Reuse logprobs pickled by an earlier job with the same reference model and data.
            with open(config.load_reference_logprobs, 'rb') as f:
                self.logprobs = pickle.load(f)
        else:
            # Run the reference model over every train/eval example and fill self.logprobs.
            self._precompute_log_probs()
        self._free_memory()  # delete the reference model and accelerator

Memory cleanup in `train/models.py:533-536`:

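import torch

def _free_memory(self):
    # Dropping these references lets the weights be garbage-collected,
    # provided nothing else still holds a reference to them.
    del self.reference_accelerator
    del self.reference_model
    # Release now-unused blocks cached by PyTorch's CUDA allocator.
    torch.cuda.empty_cache()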

Forward pass returns cached values from `train/models.py:538-543`:

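import torch
from torch.nn.utils.rnn import pad_sequence

def forward(self, input_ids, *args, **kwargs):
    # Key each row by its unpadded token IDs; this must match the key used when caching.
    batch_logprobs = [torch.Tensor(self.logprobs[tuple(self._remove_padding(k))])
                      for k in input_ids.tolist()]
    # Right-pad the per-example logprobs back into a (batch, max_len) tensor.
    batch_logprobs = pad_sequence(batch_logprobs, batch_first=True, padding_value=0)
    return batch_logprobs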
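The lookup in `forward` only works if the key built at training time matches the key stored at precompute time. That is why padding is stripped before the tuple is formed. The following dependency-free sketch shows that round trip under stated assumptions: right padding with a hypothetical `PAD_ID = 0`, and a hypothetical `remove_padding` that strips only trailing pad tokens. The real `_remove_padding` is not shown on this page and may behave differently. `pad_batch` mimics `pad_sequence(..., batch_first=True, padding_value=0)`.

```python
from typing import Dict, List, Sequence, Tuple

PAD_ID = 0  # hypothetical pad token id


def remove_padding(ids: Sequence[int]) -> List[int]:
    """Hypothetical stand-in for _remove_padding: strip trailing pad tokens only."""
    end = len(ids)
    while end > 0 and ids[end - 1] == PAD_ID:
        end -= 1
    return list(ids[:end])


def pad_batch(seqs: List[List[float]], value: float = 0.0) -> List[List[float]]:
    """Right-pad each sequence to the longest one (like batch_first=True)."""
    width = max(len(s) for s in seqs)
    return [s + [value] * (width - len(s)) for s in seqs]


def cached_forward(cache: Dict[Tuple[int, ...], List[float]],
                   input_ids: List[List[int]]) -> List[List[float]]:
    """Replace the reference forward pass with per-row dictionary lookups."""
    rows = [cache[tuple(remove_padding(row))] for row in input_ids]
    return pad_batch(rows)


cache = {(5, 6, 7): [-0.1, -0.2], (8, 9): [-0.3]}
batch = [[5, 6, 7], [8, 9, PAD_ID]]  # second row was right-padded by the collator
print(cached_forward(cache, batch))  # [[-0.1, -0.2], [-0.3, 0.0]]
```

If the cache were keyed on padded sequences instead, the same example could produce different keys in different batches, and the lookup would fail.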

Config options in `config/config.yaml:88-92`:

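cache_reference_logprobs: false   # true: precompute reference logprobs, then delete the reference model
load_reference_logprobs: null     # path to a previously saved cached_reference_logprobs.pkl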
