
Implementation:CarperAI Trlx Trlx Train Online

From Leeroopedia


Knowledge Sources
Domains Reinforcement_Learning, NLP, Training
Last Updated 2026-02-07 16:00 GMT

Overview

Concrete tool for launching online PPO training of language models, provided by the trlx.train() API.

Description

The trlx.train() function is the unified entry point for all trlx training methods. When called with a reward_fn argument, it enters the online RL path: it creates an AcceleratePPOTrainer, sets up a PromptPipeline for generating completions, and starts the PPO training loop via trainer.learn(). The function dispatches to the correct trainer class based on config.train.trainer and handles the complete lifecycle from setup to training completion.
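The dispatch described above can be sketched as a small selector (an illustrative simplification, not trlx's actual code; the function and path names below are hypothetical, and the real function also validates argument combinations and reads config.train.trainer):

```python
# Simplified sketch of how trlx.train() picks a training path based on
# which arguments are provided (illustration only, not the real implementation).

def select_training_path(reward_fn=None, samples=None, rewards=None, dataset=None):
    if reward_fn is not None:
        # Online RL: generate completions from prompts and score them live
        return "online_ppo"
    if samples is not None and rewards is not None:
        # Offline path: learn from pre-collected (sample, reward) pairs
        return "offline"
    if dataset is not None:
        # Legacy dataset argument also takes the offline path
        return "offline"
    raise ValueError("Provide either reward_fn or samples+rewards")

print(select_training_path(reward_fn=lambda s, **kw: [0.0] * len(s)))  # online_ppo
print(select_training_path(samples=["a"], rewards=[1.0]))              # offline
```

Passing a reward_fn is therefore what selects the online PPO path documented on this page.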

Usage

Call trlx.train() with a reward_fn and a list of prompts to run online PPO training. The config's train.trainer field should name AcceleratePPOTrainer and its method section should be a PPOConfig. The function returns the trained trainer instance.

Code Reference

Source Location

  • Repository: trlx
  • File: trlx/trlx.py
  • Lines: L15-143

Signature

def train(
    model_path: Optional[str] = None,
    reward_fn: Optional[Callable[[List[str], List[str], List[str]], List[float]]] = None,
    dataset: Optional[Iterable[Tuple[str, float]]] = None,
    samples: Optional[List[str]] = None,
    rewards: Optional[List[float]] = None,
    prompts: Optional[List[str]] = None,
    eval_prompts: Optional[List[str]] = None,
    metric_fn: Optional[Callable[[List[str], List[str], List[str]], Dict[str, List[float]]]] = None,
    config: Optional[TRLConfig] = None,
    stop_sequences: Optional[List[str]] = [],
) -> AcceleratePPOTrainer:
    """
    Runs online RL training when reward_fn is provided.

    Args:
        model_path: Path to HuggingFace model (overrides config.model.model_path).
        reward_fn: Callable scoring generated text batches.
        prompts: Prompts for generation during training.
        eval_prompts: Prompts for periodic evaluation.
        metric_fn: Optional evaluation metrics function.
        config: TRLConfig with PPOConfig method.
        stop_sequences: Strings to trim generations at.

    Returns:
        Trained AcceleratePPOTrainer instance.
    """

Import

import trlx

I/O Contract

Inputs

Name            Type                      Required          Description
reward_fn       Callable                  Yes (for online)  Function scoring batches of generated text
prompts         List[str] or List[Dict]   No                Training prompts (defaults to BOS tokens)
eval_prompts    List[str] or List[Dict]   No                Evaluation prompts (defaults to a subset of prompts)
config          TRLConfig                 Yes               Configuration with a PPOConfig method
metric_fn       Callable                  No                Optional evaluation metrics function
stop_sequences  List[str]                 No                Strings at which to trim generations
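The effect of stop_sequences can be illustrated with a small trimming helper (a sketch of the behavior, not trlx's internal code; trim_at_stop is a hypothetical name): each generation is cut at the earliest occurrence of any stop string.

```python
from typing import List

def trim_at_stop(text: str, stop_sequences: List[str]) -> str:
    """Cut `text` at the earliest occurrence of any stop sequence."""
    cut = len(text)
    for stop in stop_sequences:
        idx = text.find(stop)
        if idx != -1:
            cut = min(cut, idx)
    return text[:cut]

print(repr(trim_at_stop("Sure.\nHuman: next question", ["Human:", "Assistant:"])))
# -> 'Sure.\n'
```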

Outputs

Name         Type                  Description
return       AcceleratePPOTrainer  Trained trainer with model, tokenizer, and training state
checkpoints  Files                 Saved to config.train.checkpoint_dir every config.train.checkpoint_interval steps
logs         Dict                  Training metrics logged to wandb/tensorboard

Usage Examples

PPO Sentiment Training

import trlx
from trlx.data.default_configs import default_ppo_config
from transformers import pipeline
from datasets import load_dataset

# 1. Configure
config = default_ppo_config()
config.model.model_path = "lvwerra/gpt2-imdb"

# 2. Define reward function
sentiment_fn = pipeline("sentiment-analysis", "lvwerra/distilbert-imdb", device=-1)

def reward_fn(samples, **kwargs):
    # With top_k=None the pipeline returns scores for every label;
    # select the POSITIVE logit by label rather than by position
    outputs = sentiment_fn(samples, top_k=None, function_to_apply="none", batch_size=16)
    return [
        next(o["score"] for o in output if o["label"] == "POSITIVE")
        for output in outputs
    ]

# 3. Load prompts
imdb = load_dataset("imdb", split="train+test")
prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]

# 4. Launch training
trainer = trlx.train(
    reward_fn=reward_fn,
    prompts=prompts,
    eval_prompts=["I don't know much about"] * 64,
    config=config,
)

PPO with Dict Prompts (Delta Rewards)

import trlx

# Dict prompts pass metadata to reward_fn
prompts = [
    {"prompt": "Summarize: ...", "original_output": "reference text"},
]

def reward_fn(samples, prompts, original_output, **kwargs):
    # Extra keys in the prompt dicts ("original_output" here) arrive as
    # keyword arguments. `compute_reward` is a placeholder scorer that
    # returns a numpy array of per-sample scores.
    rewards = compute_reward(samples)
    original_rewards = compute_reward(
        [p + o for p, o in zip(prompts, original_output)]
    )
    # Delta reward: how much the generation improves on the reference output
    return (rewards - original_rewards).tolist()

trainer = trlx.train(
    reward_fn=reward_fn,
    prompts=prompts,
    config=config,
    stop_sequences=["Human:", "Assistant:"],
)

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
