Implementation:Allenai Open instruct HFDataLoader

From Leeroopedia


Type: Class
Source: open_instruct/data_loader.py:L67-278
Dependencies: datasets, torch, numpy, olmo_core.data.data_loader
Last Updated: 2026-02-07 00:00 GMT

Overview

Concrete tool for distributed, checkpointable iteration over HuggingFace datasets with strided sharding, provided by the Open Instruct library.

Description

HFDataLoader is a data loader class that wraps a HuggingFace Dataset and implements the olmo_core.data.DataLoaderBase interface. It provides:

  • Strided sharding across data-parallel ranks for diverse per-rank batches.
  • Epoch-based reshuffling with deterministic seeding (base seed + epoch number).
  • Index exclusion to remove mastered prompts from future iterations.
  • Checkpointing and resumption via state_dict() and load_state_dict().
  • Automatic reshuffling at epoch boundaries when configured.
  • Optional collation with a custom collator function.
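
The first two bullets, strided sharding and epoch-seeded reshuffling, can be sketched as follows. This is an illustrative reconstruction rather than the class's actual code; `strided_shard` is a hypothetical helper:

```python
import numpy as np

def strided_shard(num_examples: int, dp_rank: int, dp_world_size: int,
                  seed: int, epoch: int) -> list[int]:
    # Deterministic per-epoch shuffle: base seed + epoch number.
    rng = np.random.default_rng(seed + epoch)
    order = rng.permutation(num_examples)
    # Strided sharding: rank r takes positions r, r + world_size, r + 2 * world_size, ...
    return order[dp_rank::dp_world_size].tolist()

# Each rank sees a disjoint slice; together the ranks cover the whole dataset.
shards = [strided_shard(10, r, 2, seed=42, epoch=0) for r in range(2)]
```

Because the permutation depends only on `seed + epoch`, every rank computes the same global order and carves out its own stride without communication.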

The dataset must contain an index column for tracking examples across epochs. This is typically added by get_cached_dataset_tulu().

Usage

Import and instantiate this class when you need a distributed data loader for prompt iteration in the GRPO pipeline. It is primarily used inside the DataPreparationActor to feed prompts to vLLM generation engines.

Code Reference

Source Location

Signature

class HFDataLoader(data_loader.DataLoaderBase):
    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        seed: int,
        dp_rank: int,
        dp_world_size: int,
        work_dir: str,
        automatic_reshuffle: bool = False,
        collator: Callable[[list[dict[str, Any]]], dict[str, Any]] | None = None,
        device: torch.device | None = None,
        drop_last: bool = True,
        fs_local_rank: int | None = None,
    ) -> None:

Import

from open_instruct.data_loader import HFDataLoader

I/O Contract

Inputs

Name Type Description
dataset Dataset HuggingFace Dataset with an index column.
batch_size int Global batch size (divided by dp_world_size to get the per-rank batch size).
seed int Base random seed for deterministic shuffling.
dp_rank int Rank of the current data-parallel process.
dp_world_size int Total number of data-parallel processes.
work_dir str Working directory (required by DataLoaderBase).
automatic_reshuffle bool Whether to reshuffle automatically when an epoch ends.
collator Callable | None Optional function that collates a list of example dicts into a batch dict.
device torch.device | None Optional device to move tensors to.
drop_last bool Whether to drop the last incomplete batch.
fs_local_rank int | None Optional filesystem-local rank passed through to DataLoaderBase.

Outputs

Name Type Description
Batches (via iteration) dict[str, Any] Collated batch dictionaries, each containing a prompt_id field of the form "{epoch}_{index}".
state_dict() dict Checkpoint state containing epoch, batches_processed, and excluded_indices.
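
Since each prompt_id encodes the epoch and dataset index, downstream code can split it back apart. The helper below is hypothetical, not part of the class:

```python
def parse_prompt_id(prompt_id: str) -> tuple[int, int]:
    # prompt_id has the form "{epoch}_{index}", per the table above.
    epoch_str, index_str = prompt_id.split("_", 1)
    return int(epoch_str), int(index_str)
```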

Key Methods

Method Description
__next__() Returns next batch; handles epoch rollover and automatic reshuffling.
reshuffle(epoch=None) Reshuffles and reshards the dataset for a new epoch.
exclude_index(index) Excludes a dataset index from future iterations (for mastered prompts).
state_dict() Returns checkpoint state for saving.
load_state_dict(state_dict) Restores data loader state from checkpoint.
get_mock_batch() Returns a dummy batch for dry-run testing.
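
The effect of exclude_index() on a subsequent reshuffle can be pictured as filtering the index set before sharding. This is a sketch of the behavior, and `filter_excluded` is a hypothetical helper:

```python
def filter_excluded(indices: list[int], excluded: set[int]) -> list[int]:
    # Mastered prompts recorded via exclude_index() are dropped from future epochs.
    return [i for i in indices if i not in excluded]

remaining = filter_excluded([0, 1, 2, 3], excluded={2})
```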

Usage Examples

from datasets import Dataset
from open_instruct.data_loader import HFDataLoader

# Create a simple dataset with an index column
dataset = Dataset.from_dict({
    "input_ids_prompt": [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]],
    "index": [0, 1, 2, 3],
})

loader = HFDataLoader(
    dataset=dataset,
    batch_size=2,
    seed=42,
    dp_rank=0,
    dp_world_size=1,
    work_dir="/tmp/work",
    automatic_reshuffle=True,
    collator=lambda examples: {"examples": examples},
)

# Iterate over one epoch
for batch in loader:
    print(batch)

# Exclude a mastered prompt
loader.exclude_index(2)
loader.reshuffle()

# Checkpoint and resume
state = loader.state_dict()
# ... later ...
loader.load_state_dict(state)

Related Pages

Implements Principle
