Implementation: CarperAI trlx PromptPipeline
| Knowledge Sources | |
|---|---|
| Domains | Data_Pipeline, NLP, Tokenization |
| Last Updated | 2026-02-07 16:00 GMT |
Overview
A concrete pipeline class, provided by the trlx pipeline system, for tokenizing and batching text prompts for generation.
Description
PromptPipeline is a registered data pipeline class that tokenizes raw text prompts and creates DataLoaders for efficient batch generation. It supports both simple string lists and dict-format prompts that carry metadata through to reward functions. The pipeline handles truncation, attention mask creation, and custom collation with left-padding for generation compatibility.
Usage
PromptPipeline is automatically instantiated by trlx.train() for online PPO prompts and evaluation prompts. It is registered via the @register_datapipeline decorator and dispatched by string name from TRLConfig.train.pipeline.
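The decorator-plus-string-name dispatch described above can be illustrated with a minimal registry sketch. The registry dict and `get_datapipeline` helper here are illustrative stand-ins, not trlx's exact internals:

```python
# Hedged sketch of the string-name dispatch pattern behind
# @register_datapipeline; names are illustrative, not trlx's exact code.
_DATAPIPELINE_REGISTRY = {}

def register_datapipeline(cls):
    """Register a pipeline class under its class name."""
    _DATAPIPELINE_REGISTRY[cls.__name__] = cls
    return cls

@register_datapipeline
class PromptPipeline:
    def __init__(self, prompts):
        self.prompts = prompts

def get_datapipeline(name):
    """Resolve the string stored in the config to a pipeline class."""
    return _DATAPIPELINE_REGISTRY[name]

# TRLConfig.train.pipeline would hold the string "PromptPipeline"
pipeline_cls = get_datapipeline("PromptPipeline")
```

This keeps the config serializable: the config carries only a string, and the registry resolves it to a class at runtime.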
Code Reference
Source Location
- Repository: trlx
- File: trlx/pipeline/offline_pipeline.py
- Lines: L118-188
Signature
@register_datapipeline
class PromptPipeline(BasePipeline):
    def __init__(
        self,
        prompts: Union[List[Dict[str, Any]], List[str]],
        max_prompt_length: int,
        tokenizer: PreTrainedTokenizer,
        add_special_tokens: bool = False,
    ):
        """
        Args:
            prompts: List of raw text prompts or dicts with required "prompt" key
                and optional extra metadata keys.
            max_prompt_length: Maximum prompt length in tokens. Prompts exceeding
                this are truncated.
            tokenizer: HuggingFace tokenizer for encoding prompts.
            add_special_tokens: Whether to add special tokens (True for seq2seq).
        """
        ...

    def __getitem__(self, ix: int) -> dict:
        """Returns tokenized prompt with input_ids, attention_mask, and metadata."""
        ...

    def __len__(self) -> int:
        ...

    def create_loader(
        self,
        batch_size: int,
        shuffle: bool = False,
        sampler=None,
        drop_last: bool = False,
    ) -> DataLoader:
        """Create a DataLoader with left-padding collation for generation."""
        ...
Import
from trlx.pipeline.offline_pipeline import PromptPipeline
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| prompts | Union[List[Dict], List[str]] | Yes | Raw text prompts or dicts with "prompt" key and metadata |
| max_prompt_length | int | Yes | Maximum token length for prompts (seq_length - max_new_tokens) |
| tokenizer | PreTrainedTokenizer | Yes | HuggingFace tokenizer for encoding |
| add_special_tokens | bool | No | Whether to add special tokens (default False, True for seq2seq) |
Outputs
| Name | Type | Description |
|---|---|---|
| __getitem__ | dict | Dict with input_ids, attention_mask, and any extra metadata keys |
| create_loader() | DataLoader | Batched DataLoader with left-padded tokenized prompts |
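The left-padding behavior in the `create_loader()` output can be sketched without torch: pad tokens are prepended so that every prompt ends at the same position and generated tokens continue directly after it. The helper name and `pad_id` default below are illustrative, not trlx's exact collation code:

```python
# Minimal sketch of left-padding collation for generation. Padding goes on
# the LEFT so that the last real token of every prompt is aligned, which is
# what decoder-only generation expects. pad_id is an assumed placeholder.
def left_pad_collate(batch, pad_id=0):
    max_len = max(len(ex["input_ids"]) for ex in batch)
    input_ids, attention_mask = [], []
    for ex in batch:
        pad = max_len - len(ex["input_ids"])
        input_ids.append([pad_id] * pad + ex["input_ids"])
        attention_mask.append([0] * pad + [1] * len(ex["input_ids"]))
    return {"input_ids": input_ids, "attention_mask": attention_mask}

batch = left_pad_collate([
    {"input_ids": [5, 6, 7]},
    {"input_ids": [8]},
])
# batch["input_ids"]      -> [[5, 6, 7], [0, 0, 8]]
# batch["attention_mask"] -> [[1, 1, 1], [0, 0, 1]]
```

Right-padding would instead leave pad tokens between the prompt and its continuation, which is why generation pipelines pad on the left.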
Usage Examples
Direct Instantiation
from trlx.pipeline.offline_pipeline import PromptPipeline
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
prompts = ["Once upon a time", "The quick brown fox", "In a galaxy far"]
pipeline = PromptPipeline(
    prompts=prompts,
    max_prompt_length=64,
    tokenizer=tokenizer,
)
loader = pipeline.create_loader(batch_size=2, shuffle=True)
for batch in loader:
    print(batch["input_ids"].shape)
    print(batch["attention_mask"].shape)
With Metadata for Reward Function
# Dict-format prompts pass extra keys to reward_fn
prompts = [
    {"prompt": "Summarize: ...", "original_output": "Reference summary"},
    {"prompt": "Summarize: ...", "original_output": "Another reference"},
]
pipeline = PromptPipeline(
    prompts=prompts,
    max_prompt_length=500,
    tokenizer=tokenizer,
)
# During PPO, reward_fn receives:
# reward_fn(samples, prompts, outputs, original_output=[...])
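The metadata passthrough above can be sketched as a regrouping step: per-example extra keys are collected into parallel lists and forwarded as keyword arguments. The helper name `collect_metadata` and the `reserved` key set are illustrative assumptions, not trlx's exact internals:

```python
# Hedged sketch: gather non-tensor metadata keys from a batch of per-example
# dicts into parallel lists, one list per key, for reward_fn(**kwargs).
def collect_metadata(examples, reserved=("input_ids", "attention_mask")):
    """Group extra keys into lists ordered like the batch."""
    extras = {}
    for ex in examples:
        for key, value in ex.items():
            if key not in reserved:
                extras.setdefault(key, []).append(value)
    return extras

examples = [
    {"input_ids": [1, 2], "original_output": "Reference summary"},
    {"input_ids": [3], "original_output": "Another reference"},
]
kwargs = collect_metadata(examples)
# reward_fn(samples, prompts, outputs, **kwargs) then receives
# original_output=["Reference summary", "Another reference"]
```

Because the lists stay in batch order, each reward computation can line up its metadata with the corresponding sample by index.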