
Implementation:Allenai Open instruct SimplePreferenceCollator

From Leeroopedia


Knowledge Sources
Domains Reinforcement Learning from Human Feedback, Reward Modeling, Data Processing
Last Updated 2026-02-07 00:00 GMT

Overview

A concrete tool, provided by Open Instruct, for batching and padding variable-length chosen/rejected preference pairs into uniform tensors for reward model training.

Description

The SimplePreferenceCollator is a callable class that serves as the collate_fn for PyTorch DataLoaders when training or evaluating reward models on preference data. It takes a list of individual preference examples (each containing chosen and rejected token ID lists) and produces a single batch dictionary with padded tensors.

The collator enforces a uniform padding length across both chosen and rejected sequences within a batch, computed as the global maximum over all sequence lengths. This design is critical because the training loop concatenates chosen and rejected tensors along the batch dimension for a single efficient forward pass.
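The concatenation pattern described above can be sketched as follows. This is a hypothetical illustration, not the actual Open Instruct training loop; the `reward_model` name is a placeholder.

```python
import torch

# Why chosen/rejected must share one padded length: the training loop
# stacks them along the batch dimension so a single forward pass scores
# both halves. torch.cat requires identical trailing dimensions, hence
# the shared max_length.
batch_size, max_length = 4, 128
chosen = torch.randint(0, 1000, (batch_size, max_length))
rejected = torch.randint(0, 1000, (batch_size, max_length))

joint = torch.cat([chosen, rejected], dim=0)  # shape: (2 * batch_size, max_length)

# One forward pass, then split the scores back into the two halves
# (reward_model is a placeholder for the actual scoring model):
# rewards = reward_model(joint)                          # (2 * batch_size,)
# chosen_rewards, rejected_rewards = rewards.chunk(2, dim=0)
```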

All padding is applied from the right (post-padding) using the specified padding token ID. This ensures compatibility with causal transformer architectures that expect real tokens to start from position 0.

Usage

Use this collator whenever creating a DataLoader for reward model training or evaluation. It is passed as the collate_fn parameter to torch.utils.data.DataLoader.

Code Reference

Source Location

  • Repository: Open Instruct
  • File: open_instruct/dataset_transformation.py, lines 1508-1545

Class Definition

class SimplePreferenceCollator:
    def __init__(self, pad_token_id: int):
        """Simple collator for preference dataset (always pad from the RIGHT)"""
        self.pad_token_id = pad_token_id

    def __call__(self, batch: list[dict[str, list[int]]]):
        """the input will have input_ids_chosen, input_ids_rejected"""
        # Find max length in the batch
        max_length_chosen = -1
        max_length_rejected = -1
        for i in range(len(batch)):
            max_length_chosen = max(max_length_chosen, len(batch[i][CHOSEN_INPUT_IDS_KEY]))
            max_length_rejected = max(max_length_rejected, len(batch[i][REJECTED_INPUT_IDS_KEY]))
        max_length = max(max_length_chosen, max_length_rejected)
        assert max_length > 0, "the dataset is empty"

        # Initialize lists to store padded sequences and attention masks
        padded_sequences_chosen = []
        padded_sequences_rejected = []

        for i in range(len(batch)):
            # Calculate padding length
            pad_length_chosen = max_length - len(batch[i][CHOSEN_INPUT_IDS_KEY])
            pad_length_rejected = max_length - len(batch[i][REJECTED_INPUT_IDS_KEY])

            # Pad from the right
            padding_chosen = [self.pad_token_id] * pad_length_chosen
            padding_rejected = [self.pad_token_id] * pad_length_rejected
            padded_sequence_chosen = batch[i][CHOSEN_INPUT_IDS_KEY] + padding_chosen
            padded_sequence_rejected = batch[i][REJECTED_INPUT_IDS_KEY] + padding_rejected
            padded_sequences_chosen.append(padded_sequence_chosen)
            padded_sequences_rejected.append(padded_sequence_rejected)

        # Convert to tensors
        padded_sequences_chosen = torch.tensor(padded_sequences_chosen)
        padded_sequences_rejected = torch.tensor(padded_sequences_rejected)

        return {
            CHOSEN_INPUT_IDS_KEY: padded_sequences_chosen,
            REJECTED_INPUT_IDS_KEY: padded_sequences_rejected,
        }

Import

from open_instruct.dataset_transformation import SimplePreferenceCollator

I/O Contract

Inputs

Name Type Required Description
pad_token_id int Yes (constructor) The token ID to use for padding. Obtained from the tokenizer via tokenizer.pad_token_id.
batch list[dict[str, list[int]]] Yes (call) A list of preference examples. Each example is a dictionary with keys input_ids_chosen (list of int) and input_ids_rejected (list of int), representing the tokenized chosen and rejected completions respectively.
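Some tokenizers ship without a dedicated pad token. A common fallback, shown here with a hypothetical stand-in tokenizer object (this pattern is an assumption about typical setup, not taken from the Open Instruct source), is to reuse the EOS token ID:

```python
# Stand-in for a Hugging Face-style tokenizer with no pad token configured.
class _Tok:
    pad_token_id = None
    eos_token_id = 2

tokenizer = _Tok()

# Fall back to the EOS token so right-padding uses a valid vocabulary ID.
pad_token_id = (
    tokenizer.pad_token_id
    if tokenizer.pad_token_id is not None
    else tokenizer.eos_token_id
)
# collator = SimplePreferenceCollator(pad_token_id=pad_token_id)
```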

Outputs

Name Type Description
result dict[str, torch.Tensor] A dictionary with two keys: input_ids_chosen of shape (batch_size, max_length) and input_ids_rejected of shape (batch_size, max_length). Both tensors use the same max_length, which is the longest sequence across all chosen and rejected sequences in the batch.

Usage Examples

Basic Usage

from torch.utils.data import DataLoader
from open_instruct.dataset_transformation import SimplePreferenceCollator

# Create collator with the tokenizer's pad token
collator = SimplePreferenceCollator(pad_token_id=tokenizer.pad_token_id)

# Use with a DataLoader
dataloader = DataLoader(
    preference_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collator,
)

# Each batch is a dict with padded tensors
for batch in dataloader:
    chosen_ids = batch["input_ids_chosen"]      # shape: (4, max_len)
    rejected_ids = batch["input_ids_rejected"]   # shape: (4, max_len)
    # Both have the same max_len
    assert chosen_ids.shape == rejected_ids.shape

In Reward Model Training

from torch.utils.data import DataLoader
from open_instruct.dataset_transformation import SimplePreferenceCollator

# Create training dataloader
data_collator = SimplePreferenceCollator(pad_token_id=tokenizer.pad_token_id)
dataloader = DataLoader(
    train_dataset,
    batch_size=args.per_device_train_batch_size,
    shuffle=True,
    collate_fn=data_collator,
)

# Create evaluation dataloader
eval_dataloader = DataLoader(
    eval_dataset,
    batch_size=args.per_device_eval_batch_size,
    shuffle=False,
    collate_fn=data_collator,
)

Examining Collator Output

# Example with sequences of different lengths
batch = [
    {
        "input_ids_chosen": [1, 42, 100, 200, 2],     # length 5
        "input_ids_rejected": [1, 42, 300, 2],          # length 4
    },
    {
        "input_ids_chosen": [1, 55, 2],                  # length 3
        "input_ids_rejected": [1, 55, 600, 700, 800, 2], # length 6
    },
]

collator = SimplePreferenceCollator(pad_token_id=0)
result = collator(batch)

# max_length = max(5, 4, 3, 6) = 6
# result["input_ids_chosen"]:
#   [[1, 42, 100, 200, 2, 0],   # padded from 5 to 6
#    [1, 55,   2,   0, 0, 0]]   # padded from 3 to 6
# result["input_ids_rejected"]:
#   [[1, 42, 300,   2,   0, 0],         # padded from 4 to 6
#    [1, 55, 600, 700, 800, 2]]          # no padding needed (length 6)

Dependencies

Package Module Purpose
torch torch.tensor Converting padded Python lists to PyTorch tensors
open_instruct dataset_transformation Provides the dataset key constants CHOSEN_INPUT_IDS_KEY ("input_ids_chosen") and REJECTED_INPUT_IDS_KEY ("input_ids_rejected")

Implementation Details

The collator implements a straightforward two-pass algorithm:

Pass 1: Length computation

  • Iterates over all examples to find the maximum chosen length and maximum rejected length.
  • Takes the global maximum across both to ensure uniform tensor dimensions.
  • Asserts that the maximum length is positive (the dataset is non-empty).

Pass 2: Padding and collection

  • For each example, computes the padding length as max_length - current_length.
  • Creates a padding list of [pad_token_id] * pad_length.
  • Concatenates the original tokens with the padding (right-padding).
  • Appends to the batch lists.

Tensor conversion

  • Converts the list of lists to torch.Tensor using torch.tensor().
  • Returns a dictionary with the two padded tensors.

The implementation is intentionally simple and does not handle:

  • Attention masks (these are computed later in get_reward() based on the padding token ID).
  • Left-padding (always right-pads, as required by the causal reward model architecture).
  • Truncation (sequences are assumed to already be within length limits from the dataset transformation step).
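Since the collator omits attention masks, the downstream mask derivation can be sketched as below. The exact get_reward() code is not reproduced here; this only illustrates the pad-token test the text describes.

```python
import torch

# Derive an attention mask from right-padded input IDs by comparing
# against the padding token ID: 1 for real tokens, 0 for padding.
pad_token_id = 0
input_ids = torch.tensor([
    [1, 42, 100, 200, 2, 0],   # padded from length 5 to 6
    [1, 55,   2,   0, 0, 0],   # padded from length 3 to 6
])
attention_mask = (input_ids != pad_token_id).long()
# attention_mask:
#   [[1, 1, 1, 1, 1, 0],
#    [1, 1, 1, 0, 0, 0]]
```

Note that this convention assumes the pad token ID never appears among the real tokens of a sequence; otherwise those positions would be masked as well.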
