Implementation:CarperAI Trlx Trlx Train Offline
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement_Learning, Offline_RL, Training |
| Last Updated | 2026-02-07 16:00 GMT |
Overview
The trlx.train() API is the concrete entry point for launching offline ILQL (Implicit Language Q-Learning) training of language models.
Description
When trlx.train() is called with samples and rewards arguments (and no reward_fn), it takes the offline RL path. It creates an AccelerateILQLTrainer, calls trainer.make_experience() to convert the samples and rewards into an ILQLRolloutStorage of tokenized sequences annotated with per-token rewards and state/action indices, then starts training via trainer.learn(). The ILQL trainer optimizes a combined Q-value, value, and policy loss over this stored experience.
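A condensed, paraphrased sketch of this dispatch (the offline branch referenced under Source Location below); helper wiring is simplified and this is not the verbatim source:
from trlx.data.default_configs import default_ilql_config
from trlx.pipeline.offline_pipeline import PromptPipeline
from trlx.trainer.accelerate_ilql_trainer import AccelerateILQLTrainer

def train_offline_sketch(samples, rewards, eval_prompts, config=None, metric_fn=None):
    """Paraphrased sketch of trlx.train()'s offline branch; not verbatim source."""
    if len(samples) != len(rewards):
        raise ValueError("len(samples) must match len(rewards)")
    if config is None:
        config = default_ilql_config()  # ILQL default when rewards are given
    trainer = AccelerateILQLTrainer(config=config, metric_fn=metric_fn)
    # Tokenize samples; annotate with rewards and state/action indices
    trainer.make_experience(samples, rewards, config.train.seq_length)
    # Evaluation prompts are truncated so completions fit within seq_length
    max_prompt_length = config.train.seq_length - config.method.gen_kwargs["max_new_tokens"]
    trainer.add_eval_pipeline(PromptPipeline(eval_prompts, max_prompt_length, trainer.tokenizer))
    trainer.learn()  # combined Q-value, value, and policy losses
    return trainer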
Usage
Call trlx.train() with samples, rewards, and optionally eval_prompts and metric_fn for offline RL training. The config should specify AccelerateILQLTrainer and ILQLConfig.
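A minimal configuration sketch; the field names follow trlx's default_ilql_config(), though exact default values vary by version:
from trlx.data.default_configs import default_ilql_config

config = default_ilql_config()
# The default already selects the offline trainer and method:
#   config.train.trainer == "AccelerateILQLTrainer"
#   config.method is an ILQLConfig
config.train.seq_length = 1024         # max tokenized sample length
config.train.checkpoint_dir = "ckpts"  # where checkpoints are written
config.method.tau = 0.7                # expectile for the value loss
config.method.two_qs = True            # train two Q-heads (clipped double Q)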
Code Reference
Source Location
- Repository: trlx
- File: trlx/trlx.py
- Lines: L15-143 (train function, offline branch at L119-131)
Signature
def train(
    model_path: Optional[str] = None,
    reward_fn: Optional[Callable] = None,   # Must be None for offline
    samples: Optional[List[str]] = None,    # Required for offline
    rewards: Optional[List[float]] = None,  # Required for offline
    eval_prompts: Optional[List[str]] = None,
    metric_fn: Optional[Callable] = None,
    config: Optional[TRLConfig] = None,
    stop_sequences: Optional[List[str]] = [],
) -> AccelerateILQLTrainer:
    """
    Runs offline RL training when samples and rewards are provided.
    """
Import
import trlx
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| samples | List[str] or List[List[str]] | Yes | Text samples or [prompt, completion] pairs |
| rewards | List[float] | Yes | Scalar reward per sample (must match len(samples)) |
| eval_prompts | List[str] | No | Prompts for periodic Q-value guided generation evaluation |
| metric_fn | Callable | No | Evaluation metrics function |
| config | TRLConfig | No | Configuration whose method is an ILQLConfig; falls back to default_ilql_config() if omitted |
Outputs
| Name | Type | Description |
|---|---|---|
| return | AccelerateILQLTrainer | Trained trainer with ILQL model (Q/V heads) |
| checkpoints | Files | Saved to config.train.checkpoint_dir |
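Both accepted samples layouts, illustrated with toy (hypothetical) data; in the pair form, trlx attributes the reward to the completion tokens:
samples = ["the movie was great", "the movie was awful"]  # plain text samples
# ...or [prompt, completion] pairs
samples = [
    ["Review:", " the movie was great"],
    ["Review:", " the movie was awful"],
]
rewards = [1.0, 0.0]                 # one scalar per sample
assert len(samples) == len(rewards)  # trlx.train() raises ValueError otherwise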
Usage Examples
ILQL Sentiment Training
import trlx
from trlx.data.default_configs import default_ilql_config
from datasets import load_dataset
from transformers import pipeline

# 1. Configure
config = default_ilql_config()
config.train.batch_size = 32

# 2. Prepare reward-labeled dataset
imdb = load_dataset("imdb", split="train")
samples = imdb["text"]
rewards = [float(label) for label in imdb["label"]]  # 0.0 or 1.0

# 3. Define evaluation metrics
sentiment_fn = pipeline("sentiment-analysis", "lvwerra/distilbert-imdb", device=-1)

def metric_fn(samples, **kwargs):
    outputs = sentiment_fn(samples, batch_size=16, top_k=None)
    # Report the probability of the POSITIVE class for each sample
    return {"sentiments": [next(s["score"] for s in out if s["label"] == "POSITIVE") for out in outputs]}

# 4. Launch offline training
trainer = trlx.train(
    samples=samples,
    rewards=rewards,
    eval_prompts=["I don't know much about"] * 64,
    metric_fn=metric_fn,
    config=config,
)
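At evaluation time the ILQL model steers generation with its learned Q-values; the strength comes from the beta entry in the method's generation kwargs (field names as in trlx's ILQLConfig; exact defaults vary by version):
config.method.gen_kwargs["max_new_tokens"] = 56  # length of eval completions
config.method.gen_kwargs["beta"] = 4             # how strongly Q-values reweight sampling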
ILQL with Dialogue Pairs
import trlx
from trlx.data.default_configs import default_ilql_config
from datasets import load_dataset

config = default_ilql_config()

# Prompt-completion format for dialogue
dataset = load_dataset("Dahoas/full-hh-rlhf")
samples = [[row["prompt"], row["chosen"]] for row in dataset["train"]]
rewards = [1.0] * len(samples)  # all chosen responses get reward 1

trainer = trlx.train(
    samples=samples,
    rewards=rewards,
    config=config,
)
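After either example returns, the trainer can persist its weights with save_pretrained(); a sketch with an illustrative directory name, reloading through trlx's ILQL model wrapper:
# Persist the trained model and tokenizer (directory name is illustrative)
trainer.save_pretrained("ckpts/ilql-hh")

# Reload later, keeping the learned Q/V heads
from trlx.models.modeling_ilql import AutoModelForCausalLMWithILQLHeads
model = AutoModelForCausalLMWithILQLHeads.from_pretrained("ckpts/ilql-hh")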