Implementation:Huggingface Trl PPO Dataset Tokenization
| Property | Value |
|---|---|
| Implementation Name | PPO Dataset Tokenization |
| Technology | Huggingface TRL, Datasets |
| Type | Pattern Doc |
| Workflow | PPO RLHF Training |
| Principle | Principle:Huggingface_Trl_PPO_Prompt_Dataset_Preparation |
Overview
Description
The PPO dataset tokenization pattern converts raw text prompts into token ID sequences, which the trainer later feeds to the policy for online response generation. The dataset is split into training and evaluation subsets. Each subset is tokenized without padding, and every column except input_ids is removed. Padding is deferred to batching time: DataCollatorWithPadding pads each batch to its longest prompt on the tokenizer's padding side (left). This leaves a minimal dataset that the collator can batch efficiently during training.
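To make the map step concrete, here is a minimal pure-Python sketch of what `dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)` does to one batch. It makes three stand-in assumptions:

- A hypothetical whitespace tokenizer (`toy_tokenize`) with a hypothetical growing vocabulary replaces a real Hugging Face tokenizer.
- A batch is modeled as a dict that maps column names to lists, which is the shape `batched=True` passes to the function.
- `map_batched` is a hypothetical approximation of the library behavior, not the `datasets` implementation.

```python
from typing import Callable, Dict, List

# Hypothetical vocabulary standing in for a real tokenizer's vocab.
TOY_VOCAB: Dict[str, int] = {}


def toy_tokenize(text: str) -> List[int]:
    """Hypothetical whitespace tokenizer: assigns IDs in order of first appearance."""
    ids = []
    for word in text.split():
        if word not in TOY_VOCAB:
            TOY_VOCAB[word] = len(TOY_VOCAB)
        ids.append(TOY_VOCAB[word])
    return ids


def tokenize_batch(batch: Dict[str, list], text_field: str = "prompt") -> Dict[str, list]:
    """Mirrors tokenize() in the pattern: no padding, only input_ids returned."""
    return {"input_ids": [toy_tokenize(text) for text in batch[text_field]]}


def map_batched(batch: Dict[str, list], fn: Callable, remove_columns: List[str]) -> Dict[str, list]:
    """Approximates map(batched=True, remove_columns=...) on a single batch."""
    new_columns = fn(batch)  # fn sees the original columns
    kept = {k: v for k, v in batch.items() if k not in remove_columns}
    kept.update(new_columns)
    return kept


raw = {"prompt": ["the cat sat", "the dog"], "source": ["a", "b"]}
out = map_batched(raw, tokenize_batch, remove_columns=list(raw.keys()))
print(out)  # {'input_ids': [[0, 1, 2], [0, 3]]}
```

The two rows keep different lengths because no padding is applied at this stage. Only input_ids survives, so the collator has nothing else to handle.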
Usage
This pattern is implemented in the PPO training script before PPOTrainer initialization. The resulting datasets are passed as train_dataset and eval_dataset to the trainer.
Code Reference
Source Location
examples/scripts/ppo/ppo.py lines 128-157
Pattern
from datasets import load_dataset
from accelerate import PartialState
# script_args, training_args, and tokenizer are defined earlier in ppo.py
# (parsed CLI arguments and a tokenizer loaded with padding_side="left")
# Load the dataset
dataset = load_dataset(
    script_args.dataset_name,
    name=script_args.dataset_config,
    split=script_args.dataset_train_split,
)
# Split into train and eval: the last 100 rows are held out for evaluation
eval_samples = 100
train_dataset = dataset.select(range(len(dataset) - eval_samples))
eval_dataset = dataset.select(range(len(dataset) - eval_samples, len(dataset)))
dataset_text_field = "prompt"
def prepare_dataset(dataset, tokenizer):
    """Pre-tokenize the dataset before training; only collate during training."""
    def tokenize(element):
        # No padding here; DataCollatorWithPadding pads each batch later
        outputs = tokenizer(
            element[dataset_text_field],
            padding=False,
        )
        return {"input_ids": outputs["input_ids"]}
    return dataset.map(
        tokenize,
        batched=True,
        remove_columns=dataset.column_names,
        num_proc=training_args.dataset_num_proc,
    )
# Run on the local main process first so it populates the datasets cache;
# the other processes then reuse the cached result instead of re-tokenizing
with PartialState().local_main_process_first():
    train_dataset = prepare_dataset(train_dataset, tokenizer)
    eval_dataset = prepare_dataset(eval_dataset, tokenizer)
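The split is purely positional. No shuffling happens before `dataset.select`, so the eval set is always the last 100 rows in load order and inherits whatever ordering the source split has.

The helper below, `tail_split`, is a hypothetical name. It reproduces the index ranges passed to `select` so they can be checked without loading a dataset. Its guard against datasets of 100 rows or fewer is an added assumption; the script itself performs no such check.

```python
from typing import List, Tuple


def tail_split(n_rows: int, eval_samples: int = 100) -> Tuple[List[int], List[int]]:
    """Return (train_indices, eval_indices) using the same ranges as the script:
    the first n_rows - eval_samples rows for training, the last eval_samples for eval."""
    # Added guard (not in the original script): the ranges only make sense
    # when the dataset is larger than the eval hold-out.
    if not 0 < eval_samples < n_rows:
        raise ValueError("eval_samples must be positive and smaller than the dataset")
    train_idx = list(range(n_rows - eval_samples))
    eval_idx = list(range(n_rows - eval_samples, n_rows))
    return train_idx, eval_idx


train_idx, eval_idx = tail_split(1000)
print(len(train_idx), len(eval_idx))  # 900 100
print(eval_idx[0], eval_idx[-1])      # 900 999
```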
Import
from datasets import load_dataset
from accelerate import PartialState
I/O Contract
Inputs
| Parameter | Type | Description |
|---|---|---|
| dataset | Dataset | Raw dataset with a text column (e.g., "prompt" or "query") |
| tokenizer | PreTrainedTokenizerBase | Tokenizer with padding_side="left" and a pad token set (needed later by the collator) |
| dataset_text_field | str | Name of the text column to tokenize (default: "prompt") |
| dataset_num_proc | int or None | Number of parallel processes for dataset.map |
Outputs
| Output | Type | Columns | Description |
|---|---|---|---|
| train_dataset | Dataset | input_ids | Tokenized prompts for training (all original columns removed) |
| eval_dataset | Dataset | input_ids | Tokenized prompts for evaluation (last 100 samples) |
Data Flow
| Step | Input | Output | Description |
|---|---|---|---|
| 1. Load | Dataset name | Raw Dataset | Load from Huggingface Hub or local path |
| 2. Split | Raw Dataset | train + eval | Select last 100 samples for eval, rest for train |
| 3. Tokenize | Text strings | Token ID lists | Convert prompt text to input_ids without padding |
| 4. Clean | All columns | input_ids only | Remove all original columns (done in the same map call via remove_columns) |
Usage Examples
Basic Tokenization
from datasets import load_dataset
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1b-deduped", padding_side="left")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
dataset = load_dataset(
    "trl-internal-testing/descriptiveness-sentiment-trl-style",
    split="descriptiveness",
)
def tokenize(element):
    outputs = tokenizer(element["prompt"], padding=False)
    return {"input_ids": outputs["input_ids"]}
tokenized = dataset.map(
    tokenize,
    batched=True,
    remove_columns=dataset.column_names,
)
# tokenized[0] = {"input_ids": [123, 456, 789, ...]}
With DataCollatorWithPadding
from transformers import DataCollatorWithPadding
from torch.utils.data import DataLoader
# Reuses tokenizer and tokenized from the Basic Tokenization example
collator = DataCollatorWithPadding(tokenizer)
dataloader = DataLoader(
    tokenized,
    batch_size=64,
    shuffle=True,
    collate_fn=collator,
    drop_last=True,
)
batch = next(iter(dataloader))
# batch["input_ids"].shape: (64, max_seq_len_in_batch)
# batch["attention_mask"].shape: (64, max_seq_len_in_batch)
# Left-padded: padding tokens appear on the left side
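The left-padding behavior can be illustrated without a tokenizer. The function below, `left_pad_collate`, is hypothetical: a pure-Python sketch of padding to the longest sequence in a batch when `padding_side="left"`. It returns nested lists rather than tensors and assumes pad ID 0 purely for the example. The real `DataCollatorWithPadding` delegates to `tokenizer.pad` and returns PyTorch tensors by default.

```python
from typing import Dict, List


def left_pad_collate(features: List[Dict[str, List[int]]], pad_token_id: int) -> Dict[str, List[List[int]]]:
    """Sketch of left padding: pad each input_ids list on the left to the
    longest sequence in the batch and build a matching attention_mask."""
    max_len = max(len(f["input_ids"]) for f in features)
    input_ids, attention_mask = [], []
    for f in features:
        ids = f["input_ids"]
        n_pad = max_len - len(ids)
        # Pad tokens go first; real tokens are right-aligned
        input_ids.append([pad_token_id] * n_pad + ids)
        # 0 marks padding, 1 marks real tokens
        attention_mask.append([0] * n_pad + [1] * len(ids))
    return {"input_ids": input_ids, "attention_mask": attention_mask}


batch = left_pad_collate([{"input_ids": [5, 6, 7]}, {"input_ids": [8]}], pad_token_id=0)
print(batch["input_ids"])       # [[5, 6, 7], [0, 0, 8]]
print(batch["attention_mask"])  # [[1, 1, 1], [0, 0, 1]]
```

With left padding, the last position of every row holds a real prompt token. During generation, new tokens are therefore appended directly after each prompt rather than after a run of pad tokens.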