# Implementation: Allenai Open Instruct HFDataLoader
| Type | Class |
|---|---|
| Source | `open_instruct/data_loader.py:L67-278` |
| Dependencies | `datasets`, `torch`, `numpy`, `olmo_core.data.data_loader` |
| Last Updated | 2026-02-07 00:00 GMT |
## Overview
Concrete tool for distributed, checkpointable iteration over HuggingFace datasets with strided sharding, provided by the Open Instruct library.
## Description
`HFDataLoader` is a data loader class that wraps a HuggingFace `Dataset` and implements the `olmo_core.data.DataLoaderBase` interface. It provides:
- Strided sharding across data-parallel ranks for diverse per-rank batches.
- Epoch-based reshuffling with deterministic seeding (base seed + epoch number).
- Index exclusion to remove mastered prompts from future iterations.
- Checkpointing and resumption via `state_dict()` and `load_state_dict()`.
- Automatic reshuffling at epoch boundaries when configured.
- Optional collation with a custom collator function.
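The first two properties above can be sketched in isolation. The following is a minimal illustration of the strided-sharding and per-epoch seeding ideas, not the library's actual code; `shuffle_for_epoch` and `strided_shard` are hypothetical helpers:

```python
import numpy as np

def shuffle_for_epoch(n: int, seed: int, epoch: int) -> list[int]:
    # Every rank derives the same permutation from (base seed + epoch),
    # so the strided shards below never overlap and resume deterministically.
    rng = np.random.default_rng(seed + epoch)
    return rng.permutation(n).tolist()

def strided_shard(indices: list[int], dp_rank: int, dp_world_size: int) -> list[int]:
    # Rank r takes positions r, r + world_size, r + 2*world_size, ...
    # of the shuffled order, giving each rank a diverse slice.
    return indices[dp_rank::dp_world_size]

order = shuffle_for_epoch(8, seed=42, epoch=0)
shard0 = strided_shard(order, dp_rank=0, dp_world_size=2)
shard1 = strided_shard(order, dp_rank=1, dp_world_size=2)
```

Because every rank computes the same permutation, the shards partition the dataset: together they cover every index exactly once.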
The dataset must contain an `index` column for tracking examples across epochs. This is typically added by `get_cached_dataset_tulu()`.
## Usage

Import and instantiate this class when you need a distributed data loader for prompt iteration in the GRPO pipeline. It is primarily used inside the `DataPreparationActor` to feed prompts to vLLM generation engines.
## Code Reference

### Source Location

- Repository: Open Instruct
- File: `open_instruct/data_loader.py`
### Signature

```python
class HFDataLoader(data_loader.DataLoaderBase):
    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        seed: int,
        dp_rank: int,
        dp_world_size: int,
        work_dir: str,
        automatic_reshuffle: bool = False,
        collator: Callable[[list[dict[str, Any]]], dict[str, Any]] | None = None,
        device: torch.device | None = None,
        drop_last: bool = True,
        fs_local_rank: int | None = None,
    ) -> None:
```
### Import

```python
from open_instruct.data_loader import HFDataLoader
```
## I/O Contract

### Inputs
| Name | Type | Description |
|---|---|---|
| `dataset` | `Dataset` | HuggingFace Dataset with an `index` column. |
| `batch_size` | `int` | Global batch size (divided by `dp_world_size` per rank). |
| `seed` | `int` | Base random seed for deterministic shuffling. |
| `dp_rank` | `int` | Rank of the current data-parallel process. |
| `dp_world_size` | `int` | Total number of data-parallel processes. |
| `work_dir` | `str` | Working directory (required by `DataLoaderBase`). |
| `automatic_reshuffle` | `bool` | Whether to automatically reshuffle when an epoch ends. |
| `collator` | `Callable[[list[dict[str, Any]]], dict[str, Any]] \| None` | Optional function to collate batches of examples. |
| `device` | `torch.device \| None` | Device to move tensors to. |
| `drop_last` | `bool` | Whether to drop the last incomplete batch. |
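The interaction between `batch_size`, `dp_world_size`, and `drop_last` can be illustrated with simple arithmetic. This is a hypothetical helper for intuition, assuming examples are split evenly across ranks before batching:

```python
def per_rank_batches(n_examples: int, batch_size: int,
                     dp_world_size: int, drop_last: bool = True) -> int:
    # Each rank draws batch_size // dp_world_size examples per step.
    per_rank_bs = batch_size // dp_world_size
    per_rank_examples = n_examples // dp_world_size
    if drop_last:
        # Incomplete final batch is dropped.
        return per_rank_examples // per_rank_bs
    # Ceiling division: the partial final batch is kept.
    return -(-per_rank_examples // per_rank_bs)

per_rank_batches(10, 4, 2)  # 5 examples/rank, 2 per batch -> 2 full batches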
### Outputs

| Name | Type | Description |
|---|---|---|
| Batches (via iteration) | `dict[str, Any]` | Collated batch dictionaries, each containing a `prompt_id` field of the form `"{epoch}_{index}"`. |
| `state_dict()` | `dict` | Checkpoint state containing `epoch`, `batches_processed`, and `excluded_indices`. |
## Key Methods

| Method | Description |
|---|---|
| `__next__()` | Returns the next batch; handles epoch rollover and automatic reshuffling. |
| `reshuffle(epoch=None)` | Reshuffles and reshards the dataset for a new epoch. |
| `exclude_index(index)` | Excludes a dataset index from future iterations (for mastered prompts). |
| `state_dict()` | Returns checkpoint state for saving. |
| `load_state_dict(state_dict)` | Restores data loader state from a checkpoint. |
| `get_mock_batch()` | Returns a dummy batch for dry-run testing. |
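The exclusion mechanism can be thought of as a growing set of indices filtered out when the dataset is reshuffled. A hypothetical set-based sketch, not the library's actual internals:

```python
class ExclusionFilter:
    """Sketch of index exclusion: mastered prompts drop out of future epochs."""

    def __init__(self) -> None:
        self.excluded: set[int] = set()

    def exclude_index(self, index: int) -> None:
        self.excluded.add(index)

    def filter(self, indices: list[int]) -> list[int]:
        # Applied to the shuffled order when a new epoch begins.
        return [i for i in indices if i not in self.excluded]

f = ExclusionFilter()
f.exclude_index(2)
remaining = f.filter([0, 1, 2, 3])  # [0, 1, 3]
```

Because the excluded set is part of `state_dict()` (as `excluded_indices`), exclusions survive checkpoint-and-resume.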
## Usage Examples

```python
from datasets import Dataset

from open_instruct.data_loader import HFDataLoader

# Create a simple dataset with an index column
dataset = Dataset.from_dict({
    "input_ids_prompt": [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]],
    "index": [0, 1, 2, 3],
})

loader = HFDataLoader(
    dataset=dataset,
    batch_size=2,
    seed=42,
    dp_rank=0,
    dp_world_size=1,
    work_dir="/tmp/work",
    automatic_reshuffle=True,
    collator=lambda examples: {"examples": examples},
)

# Iterate over one epoch
for batch in loader:
    print(batch)

# Exclude a mastered prompt
loader.exclude_index(2)
loader.reshuffle()

# Checkpoint and resume
state = loader.state_dict()
# ... later ...
loader.load_state_dict(state)
```