Implementation:CarperAI Trlx Trlx Train SFT

Knowledge Sources
Domains: Supervised_Learning, NLP, Training
Last Updated: 2026-02-07 16:00 GMT

Overview

A concrete entry point, provided by the trlx.train() API, for launching supervised fine-tuning (SFT) of language models.

Description

When trlx.train() is called with samples but no rewards and no reward_fn, it takes the SFT path and creates an AccelerateSFTTrainer. The trainer accepts two data formats: plain text strings (handled by PromptPipeline) and [prompt, completion] pairs (handled by DialogStore, which masks the loss on prompt tokens). It computes cross-entropy loss with label shifting and trains through the standard learn() loop with periodic evaluation.
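
A minimal PyTorch sketch of that loss computation, written for illustration rather than taken from the trainer (the function and argument names are hypothetical):

import torch
import torch.nn.functional as F

def sft_loss(logits: torch.Tensor, input_ids: torch.Tensor, prompt_len: int) -> torch.Tensor:
    # logits: (seq_len, vocab); input_ids: (seq_len,); prompt_len: tokens to mask
    labels = input_ids.clone()
    labels[:prompt_len] = -100          # ignore_index: no loss on prompt tokens
    shift_logits = logits[:-1, :]       # position t predicts token t + 1
    shift_labels = labels[1:]
    return F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)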

Usage

Call trlx.train() with samples (and no rewards) for SFT training. Provide eval_prompts and metric_fn to monitor generation quality during training.
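
As a minimal sketch of the call shape (the sample strings are placeholders, not from the repository):

import trlx
from trlx.data.default_configs import default_sft_config

# samples without rewards or reward_fn selects the SFT path
trainer = trlx.train(
    samples=["first training document", "second training document"],
    config=default_sft_config(),
)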

Code Reference

Source Location

  • Repository: trlx
  • File: trlx/trlx.py
  • Lines: L15-143 (train function, offline/SFT branch at L119-131)
  • File: trlx/trainer/accelerate_sft_trainer.py
  • Lines: L30-97 (AccelerateSFTTrainer)

Signature

def train(
    model_path: Optional[str] = None,
    samples: Optional[List[str]] = None,   # Required for SFT
    eval_prompts: Optional[List[str]] = None,
    metric_fn: Optional[Callable] = None,
    config: Optional[TRLConfig] = None,
    stop_sequences: Optional[List[str]] = [],
) -> AccelerateSFTTrainer:
    """
    Runs supervised fine-tuning when samples are provided without rewards.
    """

Import

import trlx

I/O Contract

Inputs

Name         | Type                         | Required | Description
samples      | List[str] or List[List[str]] | Yes      | Plain text or [prompt, completion] pairs
eval_prompts | List[str]                    | No       | Prompts for periodic generation evaluation
metric_fn    | Callable                     | No       | Function computing evaluation metrics
config       | TRLConfig                    | No       | Configuration whose method is SFTConfig; defaults to default_sft_config() when omitted

Outputs

Name        | Type                 | Description
return      | AccelerateSFTTrainer | Trained SFT trainer instance
checkpoints | Files                | Saved to config.train.checkpoint_dir
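
Checkpoints written by the trainer follow the standard Hugging Face layout, so they can typically be reloaded with transformers. A hedged sketch; the directory name is illustrative, and the tokenizer may need to be loaded from the base model if it was not saved alongside the weights:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Directory is illustrative; use config.train.checkpoint_dir or a
# path passed to trainer.save_pretrained()
model = AutoModelForCausalLM.from_pretrained("sft_model")
tokenizer = AutoTokenizer.from_pretrained("sft_model")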

Usage Examples

SFT on Positive Reviews

import trlx
from trlx.data.default_configs import default_sft_config
from datasets import load_dataset

# 1. Configure
config = default_sft_config()
config.model.model_path = "gpt2"
config.train.batch_size = 4
config.train.total_steps = 1000

# 2. Load positive reviews as training data
imdb = load_dataset("imdb", split="train")
samples = [review for review, label in zip(imdb["text"], imdb["label"]) if label == 1]

# 3. Launch SFT
trainer = trlx.train(
    samples=samples,
    eval_prompts=["I don't know much about"] * 64,
    config=config,
)
trainer.save_pretrained("sft_model")
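
The example above passes eval_prompts but no metric_fn. A hedged sketch of one, assuming trlx invokes it with samples, prompts, and outputs keyword arguments and expects a dict mapping metric names to per-sample values (the metric itself is purely illustrative):

def metric_fn(samples, prompts, outputs, **kwargs):
    # Illustrative metric: length in words of each generated continuation
    return {"output_length": [len(o.split()) for o in outputs]}

trainer = trlx.train(
    samples=samples,
    eval_prompts=["I don't know much about"] * 64,
    metric_fn=metric_fn,
    config=config,
)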

SFT on Instruction Pairs (Alpaca Format)

import trlx
from trlx.data.default_configs import default_sft_config

config = default_sft_config()
config = config.evolve(
    train=dict(seq_length=512, batch_size=4, total_steps=2000),
    model=dict(model_path="EleutherAI/gpt-j-6B"),
)

# Prompt-completion pairs → masked loss on prompts
samples = [
    ["Below is an instruction.\n### Instruction:\nWhat is Python?\n### Response:\n",
     "Python is a programming language."],
    ["Below is an instruction.\n### Instruction:\nExplain gravity.\n### Response:\n",
     "Gravity is a fundamental force of nature."],
]

trainer = trlx.train(samples=samples, config=config)
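
Because these samples are [prompt, completion] pairs, DialogStore masks the prompt tokens, so the loss is computed only on the response text. To build such pairs from a full dataset, a hedged sketch using the tatsu-lab/alpaca dataset (the dataset name and field names are assumptions, not from this page):

from datasets import load_dataset

# Assumed dataset and fields; adapt to your own instruction data
alpaca = load_dataset("tatsu-lab/alpaca", split="train")
samples = [
    [f"Below is an instruction.\n### Instruction:\n{row['instruction']}\n### Response:\n",
     row["output"]]
    for row in alpaca
]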

Related Pages

Implements Principle

Requires Environment
