

Implementation:Huggingface Trl RewardTrainer Prepare Dataset

From Leeroopedia


| Property | Value |
|---|---|
| Implementation Name | RewardTrainer Prepare Dataset |
| Technology | Huggingface TRL |
| Type | API Doc |
| Workflow | Reward Model Training |
| Principle | Principle:Huggingface_Trl_Reward_Preference_Dataset_Loading |

Overview

Description

The _prepare_dataset method of RewardTrainer converts a raw preference dataset into tokenized examples (lists of chosen and rejected token IDs) ready for training. The DataCollatorForPreference class then handles dynamic padding and batch assembly, stacking the chosen and rejected sequences of each batch into a single padded tensor. Together, these components form the complete data pipeline for reward model training.

Usage

RewardTrainer.__init__ calls _prepare_dataset automatically for both the training and evaluation datasets. Pre-tokenized datasets (those containing both chosen_input_ids and rejected_input_ids columns) skip EOS insertion and tokenization but are still filtered by max_length.
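A minimal sketch of the pre-tokenized check, assuming it inspects only the column names and requires both token-ID columns to be present. `is_pretokenized` is a hypothetical helper for illustration, not a TRL function.

```python
def is_pretokenized(column_names: list[str]) -> bool:
    """Hypothetical helper: True only when both token-ID columns are present."""
    # A dataset with just one of the two columns is not treated as pre-tokenized.
    return "chosen_input_ids" in column_names and "rejected_input_ids" in column_names


print(is_pretokenized(["chosen", "rejected"]))                      # False
print(is_pretokenized(["prompt", "chosen", "rejected"]))            # False
print(is_pretokenized(["chosen_input_ids", "rejected_input_ids"]))  # True
```

In either case the max_length filter runs afterwards, so over-long pre-tokenized pairs are still dropped.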

Code Reference

Source Location

  • _prepare_dataset: trl/trainer/reward_trainer.py lines 477-561
  • DataCollatorForPreference: trl/trainer/reward_trainer.py lines 89-177

Signature

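from dataclasses import dataclass
from typing import Any

from datasets import Dataset, IterableDataset
from transformers import PreTrainedTokenizerBase
from transformers.data.data_collator import DataCollatorMixin

from trl import RewardConfig


def _prepare_dataset(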
    self,
    dataset: Dataset | IterableDataset,
    processing_class: PreTrainedTokenizerBase,
    args: RewardConfig,
    dataset_name: str,
) -> Dataset | IterableDataset:
    """
    Prepare a preference dataset for reward model training.

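    Steps:
    1. Check whether the dataset is already tokenized (has both chosen_input_ids and rejected_input_ids).
    2. If not tokenized:
       a. Append the EOS token to non-conversational chosen/rejected text, unless it is already present.
       b. Tokenize the chosen and rejected sequences.
    3. If max_length is set, drop samples whose chosen or rejected sequence is longer than max_length.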

    Returns the processed dataset.
    """
@dataclass
class DataCollatorForPreference(DataCollatorMixin):
    """
    Data collator for preference data. Stacks chosen and rejected
    sequences into a single padded batch tensor.

    The first half of the batch corresponds to chosen sequences,
    the second half to rejected sequences.
    """
    pad_token_id: int
    pad_to_multiple_of: int | None = None
    return_tensors: str = "pt"

    def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
        # Returns: {"input_ids": Tensor, "attention_mask": Tensor, "margin": Tensor (optional)}
        ...  # body omitted in this signature listing
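The _prepare_dataset steps listed in the docstring above can be sketched in plain Python over a list of row dicts. This is a simplified model, not TRL's implementation: it covers only the standard (non-conversational, no-prompt) text format, swaps in a hypothetical word-level `toy_tokenize` for a real tokenizer, and `prepare_preference_rows` is likewise a hypothetical name.

```python
from typing import Callable, Optional

EOS = "<eos>"  # hypothetical EOS string; real tokenizers expose it as tokenizer.eos_token
VOCAB: dict[str, int] = {EOS: 1}  # toy vocabulary, grown on demand


def toy_tokenize(text: str) -> list[int]:
    """Hypothetical word-level tokenizer that treats EOS as its own word."""
    words = text.replace(EOS, " " + EOS).split()
    return [VOCAB.setdefault(word, len(VOCAB) + 1) for word in words]


def prepare_preference_rows(
    rows: list[dict],
    tokenize: Callable[[str], list[int]],
    eos_token: str,
    max_length: Optional[int],
) -> list[dict]:
    """Sketch of the _prepare_dataset steps for standard (non-conversational) text rows."""
    # Step 1: decide once, from the first row, whether the data is already tokenized.
    is_processed = "chosen_input_ids" in rows[0] and "rejected_input_ids" in rows[0]
    if not is_processed:
        tokenized = []
        for row in rows:
            chosen, rejected = row["chosen"], row["rejected"]
            # Step 2a: append EOS unless the text already ends with it.
            if not chosen.endswith(eos_token):
                chosen += eos_token
            if not rejected.endswith(eos_token):
                rejected += eos_token
            # Step 2b: tokenize both sides of the pair.
            tokenized.append(
                {"chosen_input_ids": tokenize(chosen), "rejected_input_ids": tokenize(rejected)}
            )
        rows = tokenized
    # Step 3: drop pairs where either side exceeds max_length (skipped when it is None).
    if max_length is None:
        return rows
    return [
        row
        for row in rows
        if len(row["chosen_input_ids"]) <= max_length and len(row["rejected_input_ids"]) <= max_length
    ]


rows = [
    {"chosen": "The answer is 42.", "rejected": "I don't know."},
    {"chosen": "a b c d e f g", "rejected": "x"},  # 8 tokens with EOS: filtered out
]
print(prepare_preference_rows(rows, toy_tokenize, EOS, max_length=5))
# [{'chosen_input_ids': [2, 3, 4, 5, 1], 'rejected_input_ids': [6, 7, 8, 1]}]
```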

Import

from trl.trainer.reward_trainer import DataCollatorForPreference

I/O Contract

_prepare_dataset Inputs

| Parameter | Type | Description |
|---|---|---|
| dataset | Dataset or IterableDataset | Raw preference dataset with "chosen"/"rejected" columns (optionally with "prompt"), or a pre-tokenized dataset with "chosen_input_ids"/"rejected_input_ids" columns |
| processing_class | PreTrainedTokenizerBase | Tokenizer used to encode text into token IDs and to supply the EOS token |
| args | RewardConfig | Training configuration; supplies max_length, dataset_num_proc, etc. |
| dataset_name | str | Name used in progress-bar descriptions (e.g. "train" or "eval") |

_prepare_dataset Outputs

| Output | Type | Description |
|---|---|---|
| processed_dataset | Dataset or IterableDataset | Dataset with "chosen_input_ids" and "rejected_input_ids" columns; pairs in which either sequence exceeds max_length are removed |

DataCollatorForPreference Inputs

| Parameter | Type | Description |
|---|---|---|
| examples | list[dict] | List of dicts, each containing "chosen_input_ids", "rejected_input_ids", and optionally "margin" |

DataCollatorForPreference Outputs

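| Output | Type | Shape | Description |
|---|---|---|---|
| input_ids | torch.Tensor | (2*N, max_seq_len) | Right-padded token IDs; the first N rows are chosen sequences, the last N are rejected. max_seq_len is the longest sequence in the batch, rounded up to a multiple of pad_to_multiple_of when set |
| attention_mask | torch.Tensor | (2*N, max_seq_len) | 1 for real tokens, 0 for padding |
| margin | torch.Tensor | (N,) | Per-pair preference margins; present only when the examples include "margin" |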
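To make this layout concrete, here is a pure-Python model of the collator's output, using nested lists in place of tensors. It assumes right padding and that pad_to_multiple_of rounds the padded length up to the next multiple; `collate_preferences` is a hypothetical name, not the TRL class.

```python
from typing import Any, Optional


def collate_preferences(
    examples: list[dict[str, Any]],
    pad_token_id: int,
    pad_to_multiple_of: Optional[int] = None,
) -> dict[str, list]:
    """List-based model of the preference batch layout (not the TRL implementation)."""
    # Stack all chosen sequences first, then all rejected ones: rows 0..N-1 vs N..2N-1.
    sequences = [ex["chosen_input_ids"] for ex in examples] + [ex["rejected_input_ids"] for ex in examples]
    width = max(len(seq) for seq in sequences)
    if pad_to_multiple_of is not None:
        # Round the padded width up to the next multiple (ceiling division).
        width = -(-width // pad_to_multiple_of) * pad_to_multiple_of
    # Right padding: real tokens first, pad_token_id after; the mask marks real tokens with 1.
    input_ids = [seq + [pad_token_id] * (width - len(seq)) for seq in sequences]
    attention_mask = [[1] * len(seq) + [0] * (width - len(seq)) for seq in sequences]
    batch: dict[str, list] = {"input_ids": input_ids, "attention_mask": attention_mask}
    # Margins stay per pair (length N), one per example rather than one per row.
    if "margin" in examples[0]:
        batch["margin"] = [float(ex["margin"]) for ex in examples]
    return batch


examples = [
    {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]},
    {"chosen_input_ids": [6, 7], "rejected_input_ids": [8]},
]
print(collate_preferences(examples, pad_token_id=0)["input_ids"])
# [[1, 2, 3], [6, 7, 0], [4, 5, 0], [8, 0, 0]]
print(collate_preferences(examples, pad_token_id=0, pad_to_multiple_of=4)["input_ids"][0])
# [1, 2, 3, 0]
```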

Usage Examples

Using DataCollatorForPreference Directly

from trl.trainer.reward_trainer import DataCollatorForPreference

collator = DataCollatorForPreference(pad_token_id=0)

examples = [
    {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]},
    {"chosen_input_ids": [6, 7], "rejected_input_ids": [8]},
]
batch = collator(examples)
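# batch["input_ids"] (shape (4, 3): 2 chosen + 2 rejected, right-padded to length 3):
# tensor([[1, 2, 3],    <- chosen 0
#         [6, 7, 0],    <- chosen 1
#         [4, 5, 0],    <- rejected 0
#         [8, 0, 0]])   <- rejected 1
# batch["attention_mask"]:
# tensor([[1, 1, 1],
#         [1, 1, 0],
#         [1, 1, 0],
#         [1, 0, 0]])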
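Because the first N rows are chosen and the last N are rejected, any per-row output computed on this batch (for example one scalar reward per row) can be split back into pairs by halving it. A minimal list-based sketch follows; with tensors, `torch.chunk(scores, 2)` performs the same split along the first dimension. `split_chosen_rejected` and the score values are hypothetical, for illustration only.

```python
def split_chosen_rejected(rows: list) -> tuple[list, list]:
    """Split a stacked preference batch (or per-row scores) into chosen and rejected halves."""
    if len(rows) % 2 != 0:
        raise ValueError("a stacked preference batch must have an even number of rows")
    n = len(rows) // 2
    # Row i of the first half is paired with row i of the second half.
    return rows[:n], rows[n:]


scores = [2.0, 0.4, 0.5, 1.1]  # hypothetical per-row rewards for the 4-row batch above
chosen, rejected = split_chosen_rejected(scores)
accuracy = sum(c > r for c, r in zip(chosen, rejected)) / len(chosen)
print(chosen, rejected, accuracy)  # [2.0, 0.4] [0.5, 1.1] 0.5
```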

With Margin Annotations

# Reuses `collator` from the previous example
examples = [
    {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5], "margin": 0.5},
    {"chosen_input_ids": [6, 7], "rejected_input_ids": [8], "margin": 0.0},
]
batch = collator(examples)
# batch["margin"]: tensor([0.5, 0.0])
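This page does not spell out how margin is consumed downstream. One common choice, assumed here rather than taken from the TRL source, is a margin-adjusted pairwise (Bradley-Terry style) loss, -log sigmoid(r_chosen - r_rejected - margin), which asks the chosen reward to beat the rejected one by at least the margin. A dependency-free sketch with plain floats; `pairwise_loss` and `softplus` are hypothetical helpers:

```python
import math
from typing import Optional


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def pairwise_loss(
    chosen: list[float], rejected: list[float], margin: Optional[list[float]] = None
) -> float:
    """Mean of -log(sigmoid(r_chosen - r_rejected - margin)) over the pairs."""
    if margin is None:
        margin = [0.0] * len(chosen)
    # -log(sigmoid(z)) == log(1 + exp(-z)) == softplus(-z)
    losses = [softplus(-(c - r - m)) for c, r, m in zip(chosen, rejected, margin)]
    return sum(losses) / len(losses)


rewards_chosen, rewards_rejected = [1.0, 0.2], [0.0, 0.1]
print(pairwise_loss(rewards_chosen, rewards_rejected))              # without margin
print(pairwise_loss(rewards_chosen, rewards_rejected, [0.5, 0.0]))  # margin raises the first pair's loss
```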

Expected Dataset Formats

from datasets import Dataset

# Standard text format
dataset_standard = Dataset.from_dict({
    "chosen": ["The answer is 42.", "Paris is the capital."],
    "rejected": ["I don't know.", "London is the capital."],
})

# With explicit prompt
dataset_explicit_prompt = Dataset.from_dict({
    "prompt": ["What is the answer? ", "What is the capital of France? "],
    "chosen": ["The answer is 42.", "Paris is the capital."],
    "rejected": ["I don't know.", "London is the capital."],
})

# Pre-tokenized format (skips EOS insertion and tokenization; max_length filtering still applies)
dataset_tokenized = Dataset.from_dict({
    "chosen_input_ids": [[1, 2, 3], [4, 5, 6]],
    "rejected_input_ids": [[7, 8], [9, 10, 11]],
})
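The prompts in the explicit-prompt example end with a trailing space. A minimal sketch, assuming the prompt is joined to each completion by plain string concatenation with no separator inserted (an assumption here, consistent with those trailing spaces), shows why: without the space, question and answer would run together. `combine_prompt` is a hypothetical helper.

```python
def combine_prompt(row: dict[str, str]) -> dict[str, str]:
    """Hypothetical helper: prepend an optional explicit prompt to both completions."""
    prompt = row.get("prompt", "")  # standard-format rows have no prompt
    # No separator is added, so the prompt must carry its own trailing space or newline.
    return {"chosen": prompt + row["chosen"], "rejected": prompt + row["rejected"]}


row = {
    "prompt": "What is the capital of France? ",
    "chosen": "Paris is the capital.",
    "rejected": "London is the capital.",
}
print(combine_prompt(row)["chosen"])  # What is the capital of France? Paris is the capital.
```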
