Implementation: CarperAI trlx, trlx.train() Online Training
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement_Learning, NLP, Training |
| Last Updated | 2026-02-07 16:00 GMT |
Overview
Concrete entry point for launching online PPO training of language models, provided by the trlx.train() API.
Description
The trlx.train() function is the unified entry point for all trlx training methods. When called with a reward_fn argument, it enters the online RL path: it creates an AcceleratePPOTrainer, sets up a PromptPipeline for generating completions, and starts the PPO training loop via trainer.learn(). The function dispatches to the correct trainer class based on config.train.trainer and handles the complete lifecycle from setup to training completion.
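The dispatch described above can be sketched in a few lines of plain Python. This is an illustrative simplification, not the real trlx source: a reward_fn selects the online RL path, while samples (with rewards) select the offline path.

```python
def pick_training_path(reward_fn=None, samples=None, rewards=None):
    # Simplified sketch of trlx.train() dispatch: a reward_fn routes to the
    # online PPO path; samples (with rewards) route to offline training.
    if reward_fn is not None:
        return "online"
    if samples is not None:
        return "offline"
    raise ValueError("either reward_fn or samples must be provided")
```

In the real function, the chosen path determines which trainer class is instantiated from config.train.trainer before trainer.learn() is called.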
Usage
Call trlx.train() with a reward_fn and a prompts list to run online PPO training. The config should specify AcceleratePPOTrainer as the trainer and PPOConfig as the method config. The function returns the trained trainer instance.
Code Reference
Source Location
- Repository: trlx
- File: trlx/trlx.py
- Lines: L15-143
Signature
def train(
model_path: Optional[str] = None,
reward_fn: Optional[Callable[[List[str], List[str], List[str]], List[float]]] = None,
dataset: Optional[Iterable[Tuple[str, float]]] = None,
samples: Optional[List[str]] = None,
rewards: Optional[List[float]] = None,
prompts: Optional[List[str]] = None,
eval_prompts: Optional[List[str]] = None,
metric_fn: Optional[Callable[[List[str], List[str], List[str]], Dict[str, List[float]]]] = None,
config: Optional[TRLConfig] = None,
stop_sequences: Optional[List[str]] = [],
) -> AcceleratePPOTrainer:
"""
Runs online RL training when reward_fn is provided.
Args:
model_path: Path to HuggingFace model (overrides config.model.model_path).
reward_fn: Callable scoring generated text batches.
prompts: Prompts for generation during training.
eval_prompts: Prompts for periodic evaluation.
metric_fn: Optional evaluation metrics function.
config: TRLConfig with PPOConfig method.
stop_sequences: Strings to trim generations at.
Returns:
Trained AcceleratePPOTrainer instance.
"""
Import
import trlx
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| reward_fn | Callable | Yes (for online) | Function scoring batches of generated text |
| prompts | List[str] or List[Dict] | No | Training prompts (defaults to BOS tokens) |
| eval_prompts | List[str] or List[Dict] | No | Evaluation prompts (defaults to subset of prompts) |
| config | TRLConfig | Yes | Configuration with PPOConfig method |
| metric_fn | Callable | No | Optional evaluation metrics function |
| stop_sequences | List[str] | No | Strings to trim generations at |
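The reward_fn contract above can be sketched as follows: trlx calls it with keyword arguments samples, prompts, and outputs (plus any dict-prompt metadata), and expects one float per sample. The length-based reward here is a toy stand-in, not a recommended reward.

```python
from typing import List

def reward_fn(samples: List[str], prompts: List[str], outputs: List[str], **kwargs) -> List[float]:
    # Toy reward favoring longer completions; a real reward function would
    # score samples with a classifier or preference model.
    return [float(len(out.split())) for out in outputs]

scores = reward_fn(
    samples=["Hello world foo"],
    prompts=["Hello"],
    outputs=[" world foo"],
)
```

Accepting **kwargs keeps the function forward-compatible with extra metadata fields passed through from dict prompts.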
Outputs
| Name | Type | Description |
|---|---|---|
| return | AcceleratePPOTrainer | Trained trainer with model, tokenizer, and training state |
| checkpoints | Files | Saved to config.train.checkpoint_dir at config.train.checkpoint_interval |
| logs | Dict | Training metrics logged to wandb/tensorboard |
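The stop_sequences behavior can be illustrated with a small sketch (pure Python, not the actual trlx implementation): each generation is cut at the earliest occurrence of any stop string.

```python
def trim_at_stop(text: str, stop_sequences) -> str:
    # Cut the generation at the earliest occurrence of any stop string;
    # iteratively slicing converges on the earliest match overall.
    for stop in stop_sequences:
        idx = text.find(stop)
        if idx != -1:
            text = text[:idx]
    return text
```

Trimming happens after generation, so rewards are computed on the trimmed text.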
Usage Examples
PPO Sentiment Training
import trlx
from trlx.data.default_configs import default_ppo_config
from transformers import pipeline
from datasets import load_dataset
# 1. Configure
config = default_ppo_config()
config.model.model_path = "lvwerra/gpt2-imdb"
# 2. Define reward function
sentiment_fn = pipeline("sentiment-analysis", "lvwerra/distilbert-imdb", device=-1)
def reward_fn(samples, **kwargs):
    outputs = sentiment_fn(samples, top_k=None, function_to_apply="none", batch_size=16)
    # top_k=None returns all labels sorted by score, so select the POSITIVE
    # logit by label rather than by position.
    return [
        next(s["score"] for s in output if s["label"] == "POSITIVE")
        for output in outputs
    ]
# 3. Load prompts
imdb = load_dataset("imdb", split="train+test")
prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]
# 4. Launch training
trainer = trlx.train(
reward_fn=reward_fn,
prompts=prompts,
eval_prompts=["I don't know much about"] * 64,
config=config,
)
PPO with Dict Prompts (Delta Rewards)
import trlx
# Dict prompts pass metadata to reward_fn
prompts = [
{"prompt": "Summarize: ...", "original_output": "reference text"},
]
def reward_fn(samples, prompts, original_output, **kwargs):
    # compute_reward is a user-supplied scorer returning a numpy array
    rewards = compute_reward(samples)
    original_rewards = compute_reward(
        [p + o for p, o in zip(prompts, original_output)]
    )
    return (rewards - original_rewards).tolist()
trainer = trlx.train(
reward_fn=reward_fn,
prompts=prompts,
config=config,
stop_sequences=["Human:", "Assistant:"],
)
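The delta-reward arithmetic from the example above, written as a standalone helper using plain lists instead of numpy arrays (a sketch, not trlx code):

```python
from typing import List

def delta_rewards(rewards: List[float], original_rewards: List[float]) -> List[float]:
    # Positive delta: the policy's sample scored higher than the reference
    # output; negative: it scored lower.
    return [r - o for r, o in zip(rewards, original_rewards)]
```

Scoring relative to a reference output normalizes away prompt difficulty, so the policy is rewarded only for improving on the original.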