Implementation:CarperAI Trlx PromptPipeline

From Leeroopedia


Knowledge Sources
Domains Data_Pipeline, NLP, Tokenization
Last Updated 2026-02-07 16:00 GMT

Overview

A concrete data pipeline in trlx's pipeline system that tokenizes text prompts and batches them for generation.

Description

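PromptPipeline is a registered data pipeline class that tokenizes raw text prompts and creates DataLoaders for batched generation. It accepts either a list of strings or a list of dicts, each with a required "prompt" key plus extra metadata keys that are carried through to the reward function. Prompts are tokenized once at construction, truncated to max_prompt_length, and stored with their attention masks but without padding. Padding is applied per batch by the DataLoader's collate function through tokenizer.pad, so batches are left-padded, as decoder-only generation requires, when the tokenizer's padding_side is "left" (trlx trainers set it from TRLConfig.tokenizer.padding_side).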
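The construction step can be sketched without any dependencies. This is a minimal illustration, not trlx's code: a hypothetical whitespace "tokenizer" (build_items and its on-the-fly vocabulary are assumptions) stands in for the HuggingFace tokenizer, and the sketch shows how dict prompts are split into text plus metadata, truncated, and stored as one dict per prompt with no padding yet.

```python
from typing import Any, Dict, List, Union

# A prompt is either plain text or a dict with a required "prompt" key plus metadata.
Prompt = Union[str, Dict[str, Any]]


def build_items(prompts: List[Prompt], max_prompt_length: int) -> List[Dict[str, Any]]:
    """Tokenize each prompt (toy whitespace tokenizer), truncate it, keep metadata.

    Hypothetical sketch: ids come from a vocabulary built on the fly, not a real tokenizer.
    """
    vocab: Dict[str, int] = {}
    items: List[Dict[str, Any]] = []
    for p in prompts:
        if isinstance(p, dict):
            text = p["prompt"]
            # Copy every non-"prompt" key so it can travel with the tokens.
            meta = {k: v for k, v in p.items() if k != "prompt"}
        else:
            text, meta = p, {}
        # Assign ids to every word, then keep only the first max_prompt_length (right truncation).
        ids = [vocab.setdefault(word, len(vocab) + 1) for word in text.split()]
        ids = ids[:max_prompt_length]
        # No padding here: each item keeps its own length until collation.
        items.append({"input_ids": ids, "attention_mask": [1] * len(ids), **meta})
    return items


items = build_items(
    ["Once upon a time", {"prompt": "a time", "original_output": "ref"}],
    max_prompt_length=3,
)
```

Because padding is deferred to collation, each stored item keeps its true length, and a batch is only padded to its own longest member. This sketch copies the metadata instead of modifying the caller's dicts.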

Usage

PromptPipeline is automatically instantiated by trlx.train() for online PPO prompts and evaluation prompts. It is registered via the @register_datapipeline decorator and dispatched by string name from TRLConfig.train.pipeline.
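The register-then-look-up-by-name pattern can be sketched as follows. The names _PIPELINES, register_pipeline, lookup_pipeline and ToyPromptPipeline are hypothetical and only illustrate the idea behind @register_datapipeline and the string in TRLConfig.train.pipeline; the case-insensitive lookup is a choice of this sketch.

```python
from typing import Dict

# Hypothetical registry: lowercased class name -> class.
_PIPELINES: Dict[str, type] = {}


def register_pipeline(cls: type) -> type:
    """Class decorator that records the class under its lowercased name."""
    _PIPELINES[cls.__name__.lower()] = cls
    return cls


def lookup_pipeline(name: str) -> type:
    """Resolve a config string to the registered class."""
    try:
        return _PIPELINES[name.lower()]
    except KeyError:
        raise ValueError(f"unknown pipeline {name!r}; known: {sorted(_PIPELINES)}") from None


@register_pipeline
class ToyPromptPipeline:
    def __init__(self, prompts, max_prompt_length):
        self.prompts = list(prompts)
        self.max_prompt_length = max_prompt_length


# The config carries only a string; the class is resolved at run time.
config_pipeline = "ToyPromptPipeline"
pipe = lookup_pipeline(config_pipeline)(["Once upon a time"], max_prompt_length=64)
```

Dispatching by name keeps the config serializable: swapping in a custom pipeline only requires registering a new class and changing one string.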

Code Reference

Source Location

  • Repository: trlx
  • File: trlx/pipeline/offline_pipeline.py
  • Lines: L118-188

Signature

```python
@register_datapipeline
class PromptPipeline(BasePipeline):
    def __init__(
        self,
        prompts: Union[List[Dict[str, Any]], List[str]],
        max_prompt_length: int,
        tokenizer: PreTrainedTokenizer,
        add_special_tokens: bool = False,
    ):
        """
        Args:
            prompts: List of raw text prompts, or dicts with a required "prompt" key
                and optional extra metadata keys.
            max_prompt_length: Maximum prompt length in tokens. Longer prompts are
                truncated according to the tokenizer's truncation settings.
            tokenizer: HuggingFace tokenizer used to encode the prompts.
            add_special_tokens: Whether to add the tokenizer's special tokens
                (trlx.train() passes True for seq2seq models).
        """
        ...

    def __getitem__(self, ix: int) -> dict:
        """Returns tokenized prompt with input_ids, attention_mask, and metadata."""
        ...

    def __len__(self) -> int:
        ...

    def create_loader(
        self,
        batch_size: int,
        shuffle: bool = False,
        sampler=None,
        drop_last: bool = False,
    ) -> DataLoader:
        """Create a DataLoader whose collate_fn pads each batch with tokenizer.pad
        (left-padded when tokenizer.padding_side == "left")."""
        ...
```

Import

```python
from trlx.pipeline.offline_pipeline import PromptPipeline
```

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| prompts | Union[List[Dict], List[str]] | Yes | Raw text prompts, or dicts with a "prompt" key plus metadata |
| max_prompt_length | int | Yes | Maximum prompt length in tokens; trlx.train() passes seq_length - max_new_tokens |
| tokenizer | PreTrainedTokenizer | Yes | HuggingFace tokenizer used for encoding |
| add_special_tokens | bool | No | Whether to add special tokens (default False; trlx.train() passes True for seq2seq models) |
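The max_prompt_length budget is simple arithmetic: whatever room is reserved for generation is subtracted from the model's sequence length. The numbers below are illustrative, and prompt_budget and truncate_right are hypothetical helpers; truncate_right keeps the first tokens, which is what a tokenizer with truncation_side="right" does.

```python
from typing import List


def prompt_budget(seq_length: int, max_new_tokens: int) -> int:
    """Tokens left for the prompt after reserving room for generation."""
    budget = seq_length - max_new_tokens
    if budget <= 0:
        raise ValueError("max_new_tokens must be smaller than seq_length")
    return budget


def truncate_right(ids: List[int], limit: int) -> List[int]:
    """Keep the first `limit` tokens (right-side truncation)."""
    return ids[:limit]


# Illustrative values, not defaults from any config.
seq_length, max_new_tokens = 1024, 40
budget = prompt_budget(seq_length, max_new_tokens)  # 1024 - 40 = 984
long_prompt = list(range(1000))
kept = truncate_right(long_prompt, budget)
# The truncated prompt plus the generated tokens fit within seq_length.
fits = len(kept) + max_new_tokens <= seq_length
```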

Outputs

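| Name | Type | Description |
|---|---|---|
| __getitem__(ix) | dict | input_ids, attention_mask, and any extra metadata keys for one prompt |
| create_loader() | DataLoader | Batches of padded input_ids and attention_mask tensors (left-padded when tokenizer.padding_side is "left"); metadata keys are collected into per-batch lists |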
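The collation step behind create_loader() can be sketched with plain lists. In this illustration left_pad_collate is a hypothetical helper that returns lists rather than tensors; it mirrors the behavior in the table: input_ids are left-padded to the longest prompt in the batch, the attention mask marks padding with 0, and every metadata key is gathered into a list aligned with the batch rows.

```python
from typing import Any, Dict, List


def left_pad_collate(items: List[Dict[str, Any]], pad_id: int) -> Dict[str, Any]:
    """Left-pad input_ids to the batch maximum and collect metadata into lists."""
    width = max(len(x["input_ids"]) for x in items)
    batch: Dict[str, Any] = {"input_ids": [], "attention_mask": []}
    for x in items:
        n_pad = width - len(x["input_ids"])
        # Padding goes first so every row ends with its last real token.
        batch["input_ids"].append([pad_id] * n_pad + list(x["input_ids"]))
        batch["attention_mask"].append([0] * n_pad + [1] * len(x["input_ids"]))
    # Every other key is metadata: one value per row, in batch order.
    for key in items[0]:
        if key not in ("input_ids", "attention_mask"):
            batch[key] = [x[key] for x in items]
    return batch


batch = left_pad_collate(
    [
        {"input_ids": [5, 6, 7], "attention_mask": [1, 1, 1], "original_output": "a"},
        {"input_ids": [8], "attention_mask": [1], "original_output": "b"},
    ],
    pad_id=0,
)
```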

Usage Examples

Direct Instantiation

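```python
from trlx.pipeline.offline_pipeline import PromptPipeline
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
# GPT-2 ships without a pad token; reuse EOS for padding.
tokenizer.pad_token = tokenizer.eos_token
# GPT-2 pads on the right by default; generation needs left padding.
tokenizer.padding_side = "left"

prompts = ["Once upon a time", "The quick brown fox", "In a galaxy far"]

pipeline = PromptPipeline(
    prompts=prompts,
    max_prompt_length=64,
    tokenizer=tokenizer,
)

# 3 prompts with batch_size=2 -> one batch of 2 and one batch of 1
loader = pipeline.create_loader(batch_size=2, shuffle=True)
for batch in loader:
    print(batch["input_ids"].shape)       # (batch size, longest prompt in batch)
    print(batch["attention_mask"].shape)  # same shape as input_ids
```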
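Setting padding_side matters because a decoder-only model such as GPT-2 appends each generated token after the last position of the row. The sketch below (pad_batch and last_column_is_real are hypothetical helpers, and the token ids are arbitrary) checks that property: with left padding every row ends in a real token, while with right padding shorter prompts end in padding, so generated text would follow pad tokens instead of the prompt.

```python
from typing import List, Tuple


def pad_batch(rows: List[List[int]], pad_id: int, side: str) -> Tuple[List[List[int]], List[List[int]]]:
    """Pad token rows to a common width on `side` ("left" or "right"); return ids and mask."""
    if side not in ("left", "right"):
        raise ValueError("side must be 'left' or 'right'")
    width = max(len(r) for r in rows)
    ids, mask = [], []
    for r in rows:
        fill = width - len(r)
        if side == "left":
            ids.append([pad_id] * fill + r)
            mask.append([0] * fill + [1] * len(r))
        else:
            ids.append(r + [pad_id] * fill)
            mask.append([1] * len(r) + [0] * fill)
    return ids, mask


def last_column_is_real(mask: List[List[int]]) -> bool:
    """True when every row ends in a real token, so new tokens follow the prompt directly."""
    return all(row[-1] == 1 for row in mask)


rows = [[11, 12, 13, 14], [21, 22]]  # arbitrary token ids of unequal length
_, left_mask = pad_batch(rows, pad_id=0, side="left")
_, right_mask = pad_batch(rows, pad_id=0, side="right")
```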

With Metadata for Reward Function

```python
# Dict-format prompts pass extra keys to reward_fn
# (reuses the tokenizer from the previous example)
prompts = [
    {"prompt": "Summarize: ...", "original_output": "Reference summary"},
    {"prompt": "Summarize: ...", "original_output": "Another reference"},
]

pipeline = PromptPipeline(
    prompts=prompts,
    max_prompt_length=500,
    tokenizer=tokenizer,
)

# During PPO, each extra key reaches reward_fn as a keyword argument holding
# one value per sample in the batch:
# reward_fn(samples, prompts, outputs, original_output=[...])
```
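A reward function that consumes this metadata might look like the following. The word-overlap score is a toy stand-in chosen only to show how original_output lines up with outputs; it is not a recommended reward, and the **kwargs parameter is an assumption meant to absorb any other keyword arguments the trainer passes.

```python
from typing import List


def reward_fn(samples: List[str], prompts: List[str], outputs: List[str],
              original_output: List[str], **kwargs) -> List[float]:
    """Toy reward: fraction of reference words that also appear in the output."""
    scores = []
    # original_output[i] is the metadata value for the same sample as outputs[i].
    for out, ref in zip(outputs, original_output):
        ref_words = set(ref.lower().split())
        out_words = set(out.lower().split())
        scores.append(len(ref_words & out_words) / len(ref_words) if ref_words else 0.0)
    return scores


scores = reward_fn(
    samples=["Summarize: ... a reference summary", "Summarize: ... nothing"],
    prompts=["Summarize: ...", "Summarize: ..."],
    outputs=["a reference summary", "nothing"],
    original_output=["Reference summary", "Another reference"],
)
```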

