
Implementation:Eric Mitchell Direct preference optimization Get Batch Iterator

From Leeroopedia


Knowledge Sources
Domains Data_Pipeline, Preprocessing, NLP
Last Updated 2026-02-08 02:00 GMT

Overview

Concrete tool, provided by the direct-preference-optimization repository, for iterating over tokenized, batched preference data.

Description

The get_batch_iterator function is the main data pipeline entry point. It loads one or more datasets, flattens prompt-response structures, tokenizes each element, groups into batches, and yields collated tensor dictionaries. It supports both SFT mode (single target) and DPO mode (preference pairs), with configurable epoch and example limits.
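The final collation step can be illustrated with a minimal, self-contained sketch. This is not the repository's actual collate function (which pads PyTorch tensors and masks prompt tokens in the labels with -100); it only shows the core idea of right-padding every sequence in a batch to the batch maximum and building matching attention masks:

```python
def collate(token_seqs, pad_token_id=0):
    """Right-pad a list of token-id lists to the batch maximum and
    build matching attention masks (1 = real token, 0 = padding).

    Toy stand-in for the real collation step, which does the same
    thing on tensors for each of the prompt/chosen/rejected fields.
    """
    max_len = max(len(s) for s in token_seqs)
    input_ids = [s + [pad_token_id] * (max_len - len(s)) for s in token_seqs]
    attention_mask = [[1] * len(s) + [0] * (max_len - len(s)) for s in token_seqs]
    return {"input_ids": input_ids, "attention_mask": attention_mask}
```
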

Usage

Import this function when setting up data iterators in BasicTrainer.__init__. One iterator is created for training (with epoch limits and shuffling) and one for evaluation (with example limits, no shuffling).

Code Reference

Source Location

preference_datasets.py
Signature

def get_batch_iterator(
    names: List[str],
    tokenizer,
    split: str = 'train',
    batch_size: int = 1,
    shuffle: bool = True,
    max_length: int = 512,
    max_prompt_length: int = 128,
    sft_mode: bool = False,
    n_epochs: Optional[int] = None,
    n_examples: Optional[int] = None,
    seed: int = 0,
    silent: bool = False,
    cache_dir: Optional[str] = None,
) -> Iterator[Dict]:
    """Get an iterator over batches of data.

    Stops after n_epochs or n_examples, whichever comes first.
    """
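The docstring's stopping rule ("whichever comes first") can be sketched as a standalone generator. This is an illustrative reconstruction, not the repository's code; `limited_iter` and its parameters are hypothetical names:

```python
def limited_iter(examples, n_epochs=None, n_examples=None):
    """Sketch of the stopping rule: yield items until n_epochs passes
    over the data or n_examples total items, whichever comes first.
    At least one of the two limits must be given.
    """
    assert n_epochs is not None or n_examples is not None
    produced, epoch = 0, 0
    while n_epochs is None or epoch < n_epochs:
        for ex in examples:
            if n_examples is not None and produced >= n_examples:
                return  # example limit hit mid-epoch
            yield ex
            produced += 1
        epoch += 1
```
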

Import

from preference_datasets import get_batch_iterator

I/O Contract

Inputs

Name Type Required Description
names List[str] Yes Dataset names (e.g., ["hh"], ["hh", "shp"])
tokenizer PreTrainedTokenizer Yes HuggingFace tokenizer for encoding text
split str No "train" or "test" (default "train")
batch_size int No Number of examples per batch (default 1)
shuffle bool No Shuffle data each epoch (default True)
max_length int No Max combined prompt+response length (default 512)
max_prompt_length int No Max prompt length for truncation (default 128)
sft_mode bool No If True, use sft_target only (default False)
n_epochs Optional[int] Conditional Number of epochs; must specify this or n_examples
n_examples Optional[int] Conditional Number of examples; must specify this or n_epochs
seed int No Random seed for shuffling (default 0)
silent bool No Suppress progress bars (default False)
cache_dir Optional[str] No Directory for dataset caching

Outputs

Name Type Description
batch Iterator[Dict] Yields dicts with keys: chosen_input_ids, chosen_attention_mask, chosen_labels, prompt_input_ids, prompt_attention_mask, prompt (text). In DPO mode, also includes rejected_input_ids, rejected_attention_mask, rejected_labels, rejected (text), chosen (text).
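The key sets described above can be captured in a small helper, useful as a sanity check on yielded batches. This helper is not part of the repository; it simply restates the output contract from the table:

```python
def expected_batch_keys(sft_mode):
    """Return the set of keys a yielded batch should contain,
    per the I/O contract: prompt/chosen fields always, with the
    rejected fields and raw text pairs added in DPO mode."""
    base = {
        "prompt", "prompt_input_ids", "prompt_attention_mask",
        "chosen_input_ids", "chosen_attention_mask", "chosen_labels",
    }
    if sft_mode:
        return base
    return base | {
        "chosen", "rejected", "rejected_input_ids",
        "rejected_attention_mask", "rejected_labels",
    }
```
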

Usage Examples

SFT Training Iterator

from preference_datasets import get_batch_iterator

train_iterator = get_batch_iterator(
    names=["hh"],
    tokenizer=tokenizer,
    split='train',
    batch_size=4,
    shuffle=True,
    max_length=512,
    max_prompt_length=256,
    sft_mode=True,
    n_epochs=1,
    seed=0,
    cache_dir=".cache/user",
)

for batch in train_iterator:
    # batch['chosen_input_ids'] shape: (4, seq_len)
    loss = compute_sft_loss(model, batch)
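`compute_sft_loss` above is a placeholder, not a repository function. Conceptually it is token-level cross-entropy over `chosen_labels`, where prompt positions are masked with -100 and excluded from the loss. A toy scalar version of that idea, operating on plain Python lists rather than tensors:

```python
import math

def sft_loss(logits, labels, ignore_index=-100):
    """Mean negative log-likelihood over non-masked label positions.

    logits: one list of vocabulary scores per token position.
    labels: target ids; positions equal to ignore_index (the masked
            prompt tokens) are excluded from the loss.
    """
    total, count = 0.0, 0
    for scores, target in zip(logits, labels):
        if target == ignore_index:
            continue
        log_z = math.log(sum(math.exp(s) for s in scores))  # log-sum-exp
        total += log_z - scores[target]                     # -log p(target)
        count += 1
    return total / count
```

In practice this is a single `F.cross_entropy(..., ignore_index=-100)` call on tensors; the sketch just makes the masking explicit.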

DPO Training Iterator

train_iterator = get_batch_iterator(
    names=["hh"],
    tokenizer=tokenizer,
    split='train',
    batch_size=4,
    shuffle=True,
    max_length=512,
    max_prompt_length=256,
    sft_mode=False,  # DPO mode: iterate over preference pairs
    n_epochs=1,
    seed=0,
    cache_dir=".cache/user",
)

for batch in train_iterator:
    # batch has both chosen_* and rejected_* keys
    loss = compute_dpo_loss(policy, reference, batch)
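`compute_dpo_loss` above is likewise a placeholder. The DPO objective itself is standard: given the summed log-probabilities of each response under the policy and the frozen reference model, the per-pair loss is the negative log-sigmoid of the scaled log-ratio difference. A minimal sketch for one preference pair:

```python
import math

def dpo_loss(pi_chosen, pi_rejected, ref_chosen, ref_rejected, beta=0.1):
    """DPO loss for one pair: -log sigmoid(beta * (policy log-ratio
    minus reference log-ratio)). Inputs are summed response log-probs."""
    margin = (pi_chosen - pi_rejected) - (ref_chosen - ref_rejected)
    return -math.log(1.0 / (1.0 + math.exp(-beta * margin)))
```

When the policy matches the reference the margin is zero and the loss is log 2; it falls as the policy favors the chosen response more than the reference does.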

Related Pages

Implements Principle

Requires Environment
