Implementation: Allenai Open Instruct TensorCache
| Component Type | Dataclass |
|---|---|
| Source | `open_instruct/model_utils.py` (Lines 48-72) |
| Repository | Open Instruct |
| Dependencies | torch, pathlib, tempfile |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
Concrete tool for storing, indexing, and persisting named tensors to disk, provided by the Open Instruct library.
Description
TensorCache is a Python dataclass that wraps a dictionary of named torch.Tensor objects and provides three core capabilities:
- Index-based retrieval (`__getitem__`): Accepts a `torch.Tensor` of indices and returns a dictionary of sliced tensors, selecting the same indices from every stored tensor. This enables batch-level lookups during training.
- Atomic disk persistence (`to_disk`): Saves all tensors to a file using `torch.save()`. The write is atomic: data is first written to a temporary file in the same directory, then renamed to the target path. Tensors are moved to CPU before saving to avoid device-specific serialization issues.
- Disk loading (`from_disk`): A classmethod that loads a previously saved cache from disk, placing the tensors on the specified device. Uses `weights_only=True` for safe deserialization.
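The temp-file-and-rename pattern described above can be sketched as follows. This is a minimal illustration, not the library's actual implementation; the class name `AtomicTensorStore` is hypothetical, and the exact on-disk format of the real cache is an assumption (a plain dict of CPU tensors):

```python
from __future__ import annotations

import os
import pathlib
import tempfile
from dataclasses import dataclass

import torch


@dataclass
class AtomicTensorStore:
    """Hypothetical sketch of the atomic-persistence behavior described above."""

    tensors: dict[str, torch.Tensor]

    def to_disk(self, path: str | pathlib.Path) -> None:
        path = pathlib.Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        # Move tensors to CPU so the file is device-agnostic.
        cpu_tensors = {k: v.cpu() for k, v in self.tensors.items()}
        # Write to a temp file in the same directory, then rename over the
        # target; os.replace is atomic on POSIX within one filesystem, so
        # readers never observe a partially written file.
        fd, tmp_name = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
        try:
            with os.fdopen(fd, "wb") as f:
                torch.save(cpu_tensors, f)
            os.replace(tmp_name, path)
        except BaseException:
            os.unlink(tmp_name)
            raise
```

Writing to a temp file in the *same* directory (rather than, say, `/tmp`) matters: a rename is only atomic when source and destination live on the same filesystem.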
Usage
Import TensorCache when you need to cache precomputed tensor values (such as reference model log-probabilities) for reuse during a training loop, or when persisting such caches across training runs.
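A sketch of that usage pattern, using a stand-in dict and a hypothetical `lookup` helper in place of the real class (so the example is self-contained without `open_instruct` installed):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical setup: reference log-probs precomputed once, reused each epoch.
dataset_size = 100
cache = {"ref_logps": torch.randn(dataset_size)}  # stand-in for a model pass


def lookup(cache: dict, indices: torch.Tensor) -> dict:
    # Mirrors __getitem__: slice the same indices out of every stored tensor.
    return {k: v[indices.long()] for k, v in cache.items()}


# Iterate over dataset indices the way a training loop would.
loader = DataLoader(TensorDataset(torch.arange(dataset_size)), batch_size=8)
for (batch_indices,) in loader:
    batch = lookup(cache, batch_indices)
    # Each batch gets exactly the cached values for its dataset rows.
    assert batch["ref_logps"].shape == batch_indices.shape
```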
Code Reference
Source Location
- Repository: Open Instruct
- File: `open_instruct/model_utils.py` (Lines 48-72)
Signature
```python
@dataclass
class TensorCache:
    """A cache for tensors indexed by dataset indices."""

    tensors: dict[str, torch.Tensor]

    def __getitem__(self, indices: torch.Tensor) -> dict[str, torch.Tensor]:
        """Get cached tensors for the given indices."""
        return {k: v[indices.long()] for k, v in self.tensors.items()}

    def to_disk(self, path: str | pathlib.Path) -> None:
        """Save the cache to disk atomically using temp file and rename."""
        ...

    @classmethod
    def from_disk(cls, path: str | pathlib.Path, device: torch.device) -> "TensorCache":
        """Load a cache from disk."""
        ...
```
Import
```python
from open_instruct.model_utils import TensorCache
```
I/O Contract
Constructor
| Parameter | Type | Description |
|---|---|---|
| `tensors` | `dict[str, torch.Tensor]` | Dictionary mapping string keys to tensors. All tensors must have the same size along the first dimension (the dataset size dimension). |
__getitem__
| Parameter | Type | Description |
|---|---|---|
| `indices` | `torch.Tensor` | 1D tensor of dataset indices to retrieve. |

| Returns | Type | Description |
|---|---|---|
| Sliced tensors | `dict[str, torch.Tensor]` | Dictionary with the same keys as `self.tensors`, each value sliced to the given indices. |
to_disk
| Parameter | Type | Description |
|---|---|---|
| `path` | `str` or `pathlib.Path` | File path to save the cache. Parent directories are created automatically. |
from_disk
| Parameter | Type | Description |
|---|---|---|
| `path` | `str` or `pathlib.Path` | File path from which to load the cache. |
| `device` | `torch.device` | Device to place the loaded tensors on. |

| Returns | Type | Description |
|---|---|---|
| Cache | `TensorCache` | A new `TensorCache` instance with tensors on the specified device. |
Usage Examples
```python
import torch
from open_instruct.model_utils import TensorCache

# Create a cache with two tensors
cache = TensorCache(tensors={
    "chosen_logps": torch.randn(1000),
    "rejected_logps": torch.randn(1000),
})

# Retrieve values for a batch of indices
batch_indices = torch.tensor([0, 5, 42, 99])
batch_logps = cache[batch_indices]
# batch_logps["chosen_logps"] -> tensor of shape (4,)
# batch_logps["rejected_logps"] -> tensor of shape (4,)

# Save to disk
cache.to_disk("/tmp/ref_logprobs_cache.pt")

# Load from disk
loaded_cache = TensorCache.from_disk("/tmp/ref_logprobs_cache.pt", device=torch.device("cuda:0"))
```
Related Pages
Implements Principle