Implementation: Eric Mitchell Direct Preference Optimization Get Batch Iterator
| Knowledge Sources | Details |
|---|---|
| Domains | Data_Pipeline, Preprocessing, NLP |
| Last Updated | 2026-02-08 02:00 GMT |
Overview
A concrete tool from the direct-preference-optimization repository for iterating over tokenized, batched preference data.
Description
The get_batch_iterator function is the main data-pipeline entry point. It loads one or more datasets, flattens their prompt-response structures, tokenizes each element, groups examples into batches, and yields collated tensor dictionaries. It supports both SFT mode (a single target response per prompt) and DPO mode (chosen/rejected preference pairs), with configurable epoch and example limits.
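The control flow is roughly load, flatten, tokenize, batch, collate, yield. The sketch below compresses tokenization and collation into a comment and uses illustrative stand-ins (sketch_batch_iterator over pre-flattened flat_data); it is not the repository's code, only an outline of the epoch/example accounting:
import random
from typing import Dict, Iterator, List, Optional

def sketch_batch_iterator(
    flat_data: List[Dict],
    batch_size: int = 1,
    shuffle: bool = True,
    n_epochs: Optional[int] = None,
    n_examples: Optional[int] = None,
    seed: int = 0,
) -> Iterator[List[Dict]]:
    """Illustrative stand-in: shuffle each epoch, stop at n_epochs or n_examples."""
    assert n_epochs is not None or n_examples is not None, 'need a stopping condition'
    rng = random.Random(seed)
    epoch, seen = 0, 0
    while n_epochs is None or epoch < n_epochs:
        if shuffle:
            rng.shuffle(flat_data)
        for start in range(0, len(flat_data), batch_size):
            if n_examples is not None and seen >= n_examples:
                return
            # The real function tokenizes each element (truncating prompts to
            # max_prompt_length and sequences to max_length) and collates the
            # batch into padded tensors before yielding it.
            batch = flat_data[start:start + batch_size]
            seen += len(batch)
            yield batch
        epoch += 1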
Usage
Import this function when setting up data iterators in BasicTrainer.__init__. One iterator is created for training (with epoch limits and shuffling) and one for evaluation (with example limits, no shuffling).
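For the evaluation side, the call looks roughly like this (a sketch with illustrative hyperparameter values; tokenizer is assumed to be a HuggingFace tokenizer already in scope):
from preference_datasets import get_batch_iterator

eval_iterator = get_batch_iterator(
    names=["hh"],
    tokenizer=tokenizer,
    split='test',
    batch_size=8,     # illustrative value
    shuffle=False,    # deterministic evaluation order
    n_examples=256,   # stop after a fixed number of examples, not epochs
)
eval_batches = list(eval_iterator)  # materialize once for repeated evaluation passes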
Code Reference
Source Location
- Repository: direct-preference-optimization
- File: preference_datasets.py
- Lines: 280-371
Signature
def get_batch_iterator(
names: List[str],
tokenizer,
split: str = 'train',
batch_size: int = 1,
shuffle: bool = True,
max_length: int = 512,
max_prompt_length: int = 128,
sft_mode: bool = False,
n_epochs: Optional[int] = None,
n_examples: Optional[int] = None,
seed: int = 0,
silent: bool = False,
cache_dir: Optional[str] = None,
) -> Iterator[Dict]:
"""Get an iterator over batches of data.
Stops after n_epochs or n_examples, whichever comes first.
"""
Import
from preference_datasets import get_batch_iterator
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| names | List[str] | Yes | Dataset names (e.g., ["hh"], ["hh", "shp"]) |
| tokenizer | PreTrainedTokenizer | Yes | HuggingFace tokenizer for encoding text |
| split | str | No | "train" or "test" (default "train") |
| batch_size | int | No | Number of examples per batch (default 1) |
| shuffle | bool | No | Shuffle data each epoch (default True) |
| max_length | int | No | Max combined prompt+response length (default 512) |
| max_prompt_length | int | No | Max prompt length for truncation (default 128) |
| sft_mode | bool | No | If True, yield SFT batches built from the dataset's sft_target only (default False) |
| n_epochs | Optional[int] | Conditional | Number of epochs; must specify this or n_examples |
| n_examples | Optional[int] | Conditional | Number of examples; must specify this or n_epochs |
| seed | int | No | Random seed for shuffling (default 0) |
| silent | bool | No | Suppress progress bars (default False) |
| cache_dir | Optional[str] | No | Directory for dataset caching |
Outputs
| Name | Type | Description |
|---|---|---|
| batch | Iterator[Dict] | Yields dicts with keys: chosen_input_ids, chosen_attention_mask, chosen_labels, prompt_input_ids, prompt_attention_mask, prompt (text). In DPO mode, also includes rejected_input_ids, rejected_attention_mask, rejected_labels, rejected (text), chosen (text). |
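To confirm this contract in practice, pull one batch and print each field's shape or type (illustrative snippet; train_iterator is built as in the examples below):
import torch

batch = next(train_iterator)
for key, value in batch.items():
    if isinstance(value, torch.Tensor):
        print(f"{key}: {tuple(value.shape)}")    # e.g. chosen_input_ids: (batch_size, seq_len)
    else:
        print(f"{key}: {type(value).__name__}")  # raw text fields such as prompt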
Usage Examples
SFT Training Iterator
from preference_datasets import get_batch_iterator
train_iterator = get_batch_iterator(
names=["hh"],
tokenizer=tokenizer,
split='train',
batch_size=4,
shuffle=True,
max_length=512,
max_prompt_length=256,
sft_mode=True,
n_epochs=1,
seed=0,
cache_dir=".cache/user",
)
for batch in train_iterator:
# batch['chosen_input_ids'] shape: (4, seq_len)
    loss = compute_sft_loss(model, batch)  # placeholder helper, sketched below
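compute_sft_loss above is not part of the repository's API; it stands in for a standard language-modeling loss. A minimal sketch, assuming a causal LM whose output exposes .logits and labels that mark prompt and padding tokens with -100:
import torch.nn.functional as F

def compute_sft_loss(model, batch):
    """Hypothetical helper: next-token cross-entropy on the chosen target."""
    outputs = model(input_ids=batch['chosen_input_ids'],
                    attention_mask=batch['chosen_attention_mask'])
    # Shift so that position t predicts token t+1; -100 labels are ignored.
    logits = outputs.logits[:, :-1, :]
    labels = batch['chosen_labels'][:, 1:]
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                           labels.reshape(-1),
                           ignore_index=-100)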
DPO Training Iterator
train_iterator = get_batch_iterator(
names=["hh"],
tokenizer=tokenizer,
split='train',
batch_size=4,
shuffle=True,
max_length=512,
max_prompt_length=256,
sft_mode=False, # DPO mode: iterate over preference pairs
n_epochs=1,
seed=0,
cache_dir=".cache/user",
)
for batch in train_iterator:
# batch has both chosen_* and rejected_* keys
    loss = compute_dpo_loss(policy, reference, batch)  # placeholder helper, sketched below
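compute_dpo_loss is likewise a placeholder. A minimal sketch of the standard DPO objective, -log sigmoid(beta * (policy log-prob margin - reference log-prob margin)), under the same -100 label convention:
import torch
import torch.nn.functional as F

def sequence_logps(model, input_ids, attention_mask, labels):
    """Hypothetical helper: summed log-probability of the label tokens."""
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits[:, :-1, :]
    labels = labels[:, 1:]
    mask = labels != -100
    safe_labels = labels.masked_fill(~mask, 0)  # avoid gathering at index -100
    per_token = torch.gather(logits.log_softmax(-1), 2,
                             safe_labels.unsqueeze(-1)).squeeze(-1)
    return (per_token * mask).sum(-1)

def compute_dpo_loss(policy, reference, batch, beta=0.1):
    """Hypothetical helper: -log sigmoid(beta * (policy margin - reference margin))."""
    pi_chosen = sequence_logps(policy, batch['chosen_input_ids'],
                               batch['chosen_attention_mask'], batch['chosen_labels'])
    pi_rejected = sequence_logps(policy, batch['rejected_input_ids'],
                                 batch['rejected_attention_mask'], batch['rejected_labels'])
    with torch.no_grad():  # the reference model is frozen
        ref_chosen = sequence_logps(reference, batch['chosen_input_ids'],
                                    batch['chosen_attention_mask'], batch['chosen_labels'])
        ref_rejected = sequence_logps(reference, batch['rejected_input_ids'],
                                      batch['rejected_attention_mask'], batch['rejected_labels'])
    margins = (pi_chosen - pi_rejected) - (ref_chosen - ref_rejected)
    return -F.logsigmoid(beta * margins).mean()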