

Implementation:Allenai Open instruct TensorCache

From Leeroopedia


Component Type Dataclass
Source open_instruct/model_utils.py (Lines 48-72)
Repository Open Instruct
Dependencies torch, pathlib, tempfile
Last Updated 2026-02-07 00:00 GMT

Overview

A concrete utility for storing named tensors, retrieving them by dataset index, and persisting them atomically to disk, provided by the Open Instruct library.

Description

TensorCache is a Python dataclass that wraps a dictionary of named torch.Tensor objects and provides three core capabilities:

  • Index-based retrieval (__getitem__): Accepts a torch.Tensor of indices and returns a dictionary of sliced tensors, selecting the same indices from every stored tensor. This enables batch-level lookups during training.
  • Atomic disk persistence (to_disk): Saves all tensors to a file using torch.save(). The write is atomic: data is first written to a temporary file in the same directory, then renamed to the target path. Tensors are moved to CPU before saving to avoid device-specific serialization issues.
  • Disk loading (from_disk): A classmethod that loads a previously saved cache from disk, placing the tensors on the specified device. Uses weights_only=True for safe deserialization.

Usage

Import TensorCache when you need to cache precomputed tensor values (such as reference model log-probabilities) for reuse during a training loop, or when persisting such caches across training runs.

Code Reference

Source Location

  • Repository: Open Instruct
  • File: open_instruct/model_utils.py (Lines 48-72)

Signature

@dataclass
class TensorCache:
    """A cache for tensors indexed by dataset indices."""

    tensors: dict[str, torch.Tensor]

    def __getitem__(self, indices: torch.Tensor) -> dict[str, torch.Tensor]:
        """Get cached tensors for the given indices."""
        return {k: v[indices.long()] for k, v in self.tensors.items()}

    def to_disk(self, path: str | pathlib.Path) -> None:
        """Save the cache to disk atomically using temp file and rename."""
        ...

    @classmethod
    def from_disk(cls, path: str | pathlib.Path, device: torch.device) -> "TensorCache":
        """Load a cache from disk."""
        ...

Import

from open_instruct.model_utils import TensorCache

I/O Contract

Constructor

Parameter Type Description
tensors dict[str, torch.Tensor] Dictionary mapping string keys to tensors. All tensors must have the same size along the first dimension (the dataset size dimension).
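The signature shown above does not enforce the equal-first-dimension requirement. A hypothetical pre-construction check could look like this (the function name and its existence are illustrative assumptions, not part of the library; it operates on shape tuples such as `tensor.shape`):

```python
def check_first_dims(shapes: dict[str, tuple[int, ...]]) -> int:
    """Return the shared first-dimension size, or raise if the shapes disagree.

    `shapes` maps each tensor name to its shape tuple (e.g. tuple(t.shape)).
    Hypothetical helper -- not part of Open Instruct.
    """
    sizes = {name: shape[0] for name, shape in shapes.items()}
    unique = set(sizes.values())
    if len(unique) != 1:
        raise ValueError(f"First dimensions differ across tensors: {sizes}")
    return unique.pop()
```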

__getitem__

Parameter Type Description
indices torch.Tensor 1D tensor of dataset indices to retrieve.

Returns Type Description
Sliced tensors dict[str, torch.Tensor] Dictionary with the same keys as self.tensors, each value sliced to the given indices.

to_disk

Parameter Type Description
path str or pathlib.Path File path to save the cache. Parent directories are created automatically.

from_disk

Parameter Type Description
path str or pathlib.Path File path from which to load the cache.
device torch.device Device to place the loaded tensors on.

Returns Type Description
Cache TensorCache A new TensorCache instance with tensors on the specified device.

Usage Examples

import torch
from open_instruct.model_utils import TensorCache

# Create a cache with two tensors
cache = TensorCache(tensors={
    "chosen_logps": torch.randn(1000),
    "rejected_logps": torch.randn(1000),
})

# Retrieve values for a batch of indices
batch_indices = torch.tensor([0, 5, 42, 99])
batch_logps = cache[batch_indices]
# batch_logps["chosen_logps"] -> tensor of shape (4,)
# batch_logps["rejected_logps"] -> tensor of shape (4,)

# Save to disk
cache.to_disk("/tmp/ref_logprobs_cache.pt")

# Load from disk
loaded_cache = TensorCache.from_disk("/tmp/ref_logprobs_cache.pt", device=torch.device("cuda:0"))
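When persisting caches across training runs, a common companion pattern is load-if-present, otherwise build and save. A minimal stdlib sketch of that pattern, using pickle as a stand-in for torch.save/torch.load (the function name is illustrative, not part of the library):

```python
import os
import pathlib
import pickle
import tempfile


def load_or_build(path, build_fn):
    """Load a pickled cache from `path` if it exists; otherwise build, save, return it.

    Hypothetical helper illustrating cache reuse across runs; pickle stands in
    for torch.save/torch.load here.
    """
    path = pathlib.Path(path)
    if path.exists():
        with open(path, "rb") as f:
            return pickle.load(f)
    data = build_fn()
    path.parent.mkdir(parents=True, exist_ok=True)
    # Same atomic write-then-rename pattern that to_disk uses.
    fd, tmp_path = tempfile.mkstemp(dir=path.parent)
    with os.fdopen(fd, "wb") as f:
        pickle.dump(data, f)
    os.replace(tmp_path, path)
    return data
```

On a second run with the same path, `build_fn` is skipped entirely and the cached values are read back from disk.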
