Implementation:CarperAI Trlx Trlx Train SFT
| Knowledge Sources | |
|---|---|
| Domains | Supervised_Learning, NLP, Training |
| Last Updated | 2026-02-07 16:00 GMT |
Overview
Concrete tool for launching supervised fine-tuning (SFT) of language models via the trlx.train() API.
Description
When trlx.train() is called with samples but no rewards and no reward_fn, it enters the SFT path. It creates an AccelerateSFTTrainer, which handles two data formats: plain text strings (using PromptPipeline) or prompt-completion pairs (using DialogStore with masked loss). The trainer computes cross-entropy loss with label shifting and trains via the standard learn() loop with periodic evaluation.
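The dispatch described above can be sketched as follows. This is a minimal illustration of the routing logic, not the actual trlx source; the helper name `select_training_path` is hypothetical.

```python
# Hypothetical sketch of how trlx.train() chooses its branch:
# samples with no reward signal -> SFT; any reward signal -> RL.
def select_training_path(samples=None, rewards=None, reward_fn=None):
    if reward_fn is not None or rewards is not None:
        return "rl"   # online (reward_fn) or offline (rewards) RL branch
    if samples is not None:
        return "sft"  # AccelerateSFTTrainer branch
    raise ValueError("Either samples or a reward signal must be provided")

print(select_training_path(samples=["some text"]))  # → sft
```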
Usage
Call trlx.train() with samples (and no rewards) for SFT training. Provide eval_prompts and metric_fn to monitor generation quality during training.
Code Reference
Source Location
- Repository: trlx
- File: trlx/trlx.py
- Lines: L15-143 (train function, offline/SFT branch at L119-131)
- File: trlx/trainer/accelerate_sft_trainer.py
- Lines: L30-97 (AccelerateSFTTrainer)
Signature
def train(
model_path: Optional[str] = None,
samples: Optional[List[str]] = None, # Required for SFT
eval_prompts: Optional[List[str]] = None,
metric_fn: Optional[Callable] = None,
config: Optional[TRLConfig] = None,
stop_sequences: Optional[List[str]] = [],
) -> AccelerateSFTTrainer:
"""
Runs supervised fine-tuning when samples are provided without rewards.
"""
Import
import trlx
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| samples | List[str] or List[List[str]] | Yes | Plain text or [prompt, completion] pairs |
| eval_prompts | List[str] | No | Prompts for periodic generation evaluation |
| metric_fn | Callable | No | Evaluation metrics function |
| config | TRLConfig | Yes | TRLConfig whose `method` section is an `SFTConfig` |
Outputs
| Name | Type | Description |
|---|---|---|
| return | AccelerateSFTTrainer | Trained SFT trainer instance |
| checkpoints | Files | Saved to config.train.checkpoint_dir |
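A metric_fn for periodic evaluation can look like the sketch below. The exact keyword arguments trlx passes (`samples`, `prompts`, `outputs`) should be checked against your trlx version; the metric itself is a toy example returning one score per generated sample.

```python
# Toy metric_fn sketch: scores each generated sample by the fraction of
# "positive" words it contains. Assumes trlx calls it with keyword
# arguments samples/prompts/outputs (verify against your trlx version).
def metric_fn(samples, prompts=None, outputs=None, **kwargs):
    positive = {"good", "great", "excellent", "wonderful"}
    scores = []
    for text in samples:
        words = text.lower().split()
        scores.append(sum(w in positive for w in words) / max(len(words), 1))
    # trlx expects a dict mapping metric name -> list of per-sample values
    return {"positivity": scores}

print(metric_fn(samples=["a great and good movie", "terrible film"]))
```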
Usage Examples
SFT on Positive Reviews
import trlx
from trlx.data.default_configs import default_sft_config
from datasets import load_dataset
# 1. Configure
config = default_sft_config()
config.model.model_path = "gpt2"
config.train.batch_size = 4
config.train.total_steps = 1000
# 2. Load positive reviews as training data
imdb = load_dataset("imdb", split="train")
samples = [review for review, label in zip(imdb["text"], imdb["label"]) if label == 1]
# 3. Launch SFT
trainer = trlx.train(
samples=samples,
eval_prompts=["I don't know much about"] * 64,
config=config,
)
trainer.save_pretrained("sft_model")
SFT on Instruction Pairs (Alpaca Format)
import trlx
from trlx.data.default_configs import default_sft_config
config = default_sft_config()
config = config.evolve(
train=dict(seq_length=512, batch_size=4, total_steps=2000),
model=dict(model_path="EleutherAI/gpt-j-6B"),
)
# Prompt-completion pairs → masked loss on prompts
samples = [
["Below is an instruction.\n### Instruction:\nWhat is Python?\n### Response:\n",
"Python is a programming language."],
["Below is an instruction.\n### Instruction:\nExplain gravity.\n### Response:\n",
"Gravity is a fundamental force of nature."],
]
trainer = trlx.train(samples=samples, config=config)
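The masked loss mentioned above can be illustrated with a small sketch. This is not trlx's internal code; it only shows the conventional pattern of masking prompt positions in the labels with the ignore index -100 so that cross-entropy is computed on completion tokens only.

```python
# Illustration (not trlx internals) of prompt-masked labels for
# [prompt, completion] pairs: prompt positions are masked with the
# conventional ignore index -100, so loss applies only to the completion.
IGNORE_INDEX = -100

def build_labels(prompt_ids, completion_ids):
    """Concatenate prompt and completion; mask prompt positions in labels."""
    input_ids = prompt_ids + completion_ids
    labels = [IGNORE_INDEX] * len(prompt_ids) + completion_ids
    return input_ids, labels

ids, labels = build_labels([10, 11, 12], [20, 21])
print(ids)     # [10, 11, 12, 20, 21]
print(labels)  # [-100, -100, -100, 20, 21]
```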