Heuristic: ContextualAI HALOs Reference Logprob Caching
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Memory_Management, LLM_Alignment |
| Last Updated | 2026-02-08 03:00 GMT |
Overview
A memory optimization that precomputes the reference model's log probabilities in a single pass, then deletes the reference model from GPU memory, substantially reducing VRAM usage during alignment training.
Description
The `ReferenceModelWrapper` precomputes the token-level log probabilities for all training and evaluation examples using the reference model in a single pass before training begins. These log probabilities are stored in a Python dictionary keyed by token ID tuples. After caching is complete, the reference model and its dedicated Accelerator are deleted from GPU memory, freeing substantial VRAM. During training, the wrapper returns cached log probabilities via dictionary lookup instead of running a forward pass. Cached logprobs can also be serialized to a pickle file and reloaded across jobs sharing the same reference model.
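The core pattern described above can be sketched in a few lines. This is an illustrative sketch, not the actual `ReferenceModelWrapper` API: `build_logprob_cache` and the toy reference function are hypothetical names, and real logprobs would come from a model forward pass.

```python
# Minimal sketch of the caching pattern: precompute token-level logprobs
# once, store them in a dict keyed by the tuple of token IDs, and answer
# training-time requests with a dictionary lookup instead of a forward pass.

def build_logprob_cache(examples, compute_logprobs):
    """Precompute logprobs for every example in one pass (illustrative)."""
    cache = {}
    for token_ids in examples:
        # tuples are hashable, so they can serve as dict keys
        cache[tuple(token_ids)] = compute_logprobs(token_ids)
    return cache

# Toy stand-in for the frozen reference model: every token gets logprob -1.0.
fake_reference = lambda ids: [-1.0] * len(ids)

cache = build_logprob_cache([[5, 7, 9], [3, 1]], fake_reference)

# During training, a lookup replaces the reference forward pass.
assert cache[(5, 7, 9)] == [-1.0, -1.0, -1.0]
```

Because the reference model is frozen, this lookup is exact, not an approximation: the cached values are byte-for-byte what a forward pass would return.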
Usage
Use this heuristic when you are VRAM constrained and the reference model does not need to change during training (i.e., `humanline=false` and `sync_reference=false`). Enable by setting `++cache_reference_logprobs=true`. To reuse cached logprobs across jobs, set `++load_reference_logprobs=PATH` to a previously saved pickle file. Do not use this with humanline alignment or sync_reference, as those require the reference model to be updated during training.
The Insight (Rule of Thumb)
- Action: Set `++cache_reference_logprobs=true` in the launch command.
- Value: Frees ~50% of GPU memory (the entire reference model) after the caching pass.
- Trade-off: Adds upfront compute time for the caching pass before training begins. The cached data is stored in CPU RAM (as a Python dict), so sufficient CPU memory is needed.
- Reusability: Cached logprobs are saved to `cached_reference_logprobs.pkl` in the run directory. Reuse with `++load_reference_logprobs=<path>` for subsequent jobs using the same reference model and data.
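The save/reload cycle behind the reusability bullet is plain `pickle` serialization. A minimal sketch, assuming a dict-of-tuples cache as described above; the filename mirrors the one named in the bullet, but the surrounding save/load code is illustrative, not the repo's actual implementation:

```python
import os
import pickle
import tempfile

# Job 1: after the caching pass, persist the cache to the run directory.
cache = {(5, 7, 9): [-1.2, -0.3, -0.8]}
path = os.path.join(tempfile.mkdtemp(), "cached_reference_logprobs.pkl")
with open(path, "wb") as f:
    pickle.dump(cache, f)

# Job 2: reload instead of recomputing (++load_reference_logprobs=<path>).
with open(path, "rb") as f:
    reloaded = pickle.load(f)

assert reloaded == cache  # identical logprobs, no reference forward passes
```

Note that the cache is only valid for jobs that share the same reference model and tokenization; a different reference checkpoint would produce different logprobs for the same token IDs.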
Reasoning
In standard alignment training (DPO, KTO, etc.), the reference model runs a forward pass on every batch to compute log probabilities, consuming GPU memory equal to the full model. Since the reference model is frozen (when not using humanline/sync_reference), its outputs are deterministic and can be precomputed once. The `ReferenceModelWrapper` exploits this by caching all logprobs upfront, then deleting the model. This is especially valuable when running multiple alignment experiments with the same SFT checkpoint as the reference, as the pickle file can be shared across runs.
Code Evidence
ReferenceModelWrapper precompute and free pattern from `train/models.py:425-536`:
```python
class ReferenceModelWrapper(nn.Module):
    """
    A wrapper around the reference model that precomputes the logprobs and saves them
    in a local dict, after which the reference model and accelerator are deleted to
    save GPU memory.
    """
    def __init__(self, reference_accelerator, reference_model, tokenizer, config, iterators):
        super().__init__()
        # ... setup ...
        if config.load_reference_logprobs:
            self.logprobs = pickle.load(open(config.load_reference_logprobs, 'rb'))
        else:
            self._precompute_log_probs()
        self._free_memory()  # delete the reference model and accelerator
```
Memory cleanup in `train/models.py:533-536`:
```python
def _free_memory(self):
    del self.reference_accelerator
    del self.reference_model
    torch.cuda.empty_cache()
```
Forward pass returns cached values from `train/models.py:538-543`:
```python
def forward(self, input_ids, *args, **kwargs):
    batch_logprobs = [torch.Tensor(self.logprobs[tuple(self._remove_padding(k))])
                      for k in input_ids.tolist()]
    batch_logprobs = pad_sequence(batch_logprobs, batch_first=True, padding_value=0)
    return batch_logprobs
```
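The forward lookup above does three things: strip padding from each sequence, key into the cache, and right-pad the returned logprobs back into a rectangular batch. A pure-Python sketch of that logic (the function and pad-ID are illustrative stand-ins for `_remove_padding` and `pad_sequence`, not the repo's code):

```python
PAD_ID = 0  # hypothetical pad token ID for illustration

def lookup_batch(cache, batch_ids, pad_id=PAD_ID):
    rows = []
    for ids in batch_ids:
        # analogue of _remove_padding: drop pad tokens before keying the cache
        trimmed = tuple(t for t in ids if t != pad_id)
        rows.append(list(cache[trimmed]))
    # analogue of pad_sequence(..., batch_first=True, padding_value=0)
    width = max(len(r) for r in rows)
    return [r + [0.0] * (width - len(r)) for r in rows]

cache = {(5, 7, 9): [-1.2, -0.3, -0.8], (3, 1): [-0.5, -0.9]}
batch = [[5, 7, 9], [3, 1, 0]]  # second sequence is padded
out = lookup_batch(cache, batch)
assert out == [[-1.2, -0.3, -0.8], [-0.5, -0.9, 0.0]]
```

Stripping padding before the lookup is what lets one cache entry serve the same sequence regardless of how it was padded within a batch.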
Config options in `config/config.yaml:88-92`:
```yaml
cache_reference_logprobs: false
load_reference_logprobs: null
```