Implementation:OpenRLHF SFTTrainer

From Leeroopedia


Knowledge Sources
Domains NLP, Training
Last Updated 2026-02-07 00:00 GMT

Overview

A concrete tool provided by OpenRLHF for supervised fine-tuning (SFT) of language models.

Description

The SFTTrainer class implements the complete SFT training loop with DeepSpeed integration. It wraps the model forward pass (computing per-token log-probabilities via Actor.forward), applies the SFTLoss with loss masking, handles gradient accumulation and clipping, manages evaluation, checkpointing (both DeepSpeed and HuggingFace formats), and logging to W&B or TensorBoard.
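The effect of loss masking can be illustrated with a minimal sketch. This is not OpenRLHF's actual code; the function name and list-based shapes are assumptions for illustration. Per-token losses are averaged only over response tokens, so the model is not trained to reproduce the prompt; with pretrain_mode=True the mask would cover every token instead.

```python
# Illustrative sketch of SFT loss masking (not OpenRLHF's implementation).
# token_nll: per-token negative log-likelihoods from the forward pass.
# loss_mask: 1 for response tokens, 0 for prompt tokens.
def masked_sft_loss(token_nll, loss_mask):
    masked = [n * m for n, m in zip(token_nll, loss_mask)]
    denom = sum(loss_mask)
    return sum(masked) / denom if denom else 0.0

# Prompt tokens (mask 0) contribute nothing; only the two response
# tokens are averaged: (0.5 + 0.5) / 2 = 0.5
loss = masked_sft_loss([2.0, 1.0, 0.5, 0.5], [0, 0, 1, 1])  # → 0.5
```

With pretrain_mode, the same function would be called with an all-ones mask, reducing to a plain mean over the sequence.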

Key features: checkpoint recovery via consumed_samples tracking, MoE auxiliary loss support, ring attention integration, and configurable evaluation/save frequencies.
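The consumed_samples mechanism can be sketched as follows. This is an illustrative reconstruction, not OpenRLHF's code: the function name and the assumption that each update consumes batch_size samples are hypothetical, but the idea is that the sample counter alone determines where in the epoch/step grid training resumes.

```python
# Illustrative sketch of resume-from-checkpoint position recovery
# (hypothetical helper, not an OpenRLHF API).
def resume_position(consumed_samples, batch_size, num_update_steps_per_epoch):
    # Number of optimizer updates completed so far, assuming each
    # update consumed batch_size samples.
    steps = consumed_samples // batch_size
    start_epoch = steps // num_update_steps_per_epoch
    start_step = steps % num_update_steps_per_epoch
    return start_epoch, start_step

# 96 samples at batch size 8 = 12 updates; with 10 updates per epoch,
# training resumes at epoch 1, step 2.
epoch, step = resume_position(96, 8, 10)  # → (1, 2)
```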

Usage

Instantiate after creating the model, optimizer, scheduler, and dataloaders. Call fit() to run the full training loop. Used in SFT Training and Rejection Sampling workflows.

Code Reference

Source Location

  • Repository: OpenRLHF
  • File: openrlhf/trainer/sft_trainer.py
  • Lines: L12-255 (class), L32-47 (__init__), L103-184 (fit)

Signature

class SFTTrainer(ABC):
    def __init__(
        self,
        model,                          # Actor: the model to train
        strategy,                       # DeepspeedStrategy
        optim: Optimizer,               # optimizer
        train_dataloader,               # training DataLoader
        eval_dataloader,                # evaluation DataLoader
        scheduler,                      # learning rate scheduler
        max_norm: float = 1,            # gradient clipping norm
        pretrain_mode: bool = False,    # if True, loss on all tokens
        batch_size: int = 1,            # batch size
        max_epochs: int = 2,            # number of training epochs
        tokenizer=None,                 # tokenizer for checkpointing
        save_hf_ckpt: bool = False,     # save HF format checkpoint
        disable_ds_ckpt: bool = False,  # disable DeepSpeed checkpoint
    ) -> None:

    def fit(self, args, consumed_samples=0, num_update_steps_per_epoch=None):
        """Run the full training loop."""

Import

from openrlhf.trainer import SFTTrainer

I/O Contract

Inputs

Name              Type               Required  Description
model             Actor              Yes       The policy model to train
strategy          DeepspeedStrategy  Yes       Training strategy
optim             Optimizer          Yes       Optimizer (FusedAdam or DeepSpeedCPUAdam)
train_dataloader  DataLoader         Yes       Training data (from SFTDataset)
eval_dataloader   DataLoader         Yes       Evaluation data
scheduler         LRScheduler        Yes       Learning rate scheduler
max_norm          float              No        Gradient clipping norm (default 1.0)
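The max_norm parameter caps the global gradient norm before each optimizer step. A minimal sketch of norm clipping, assuming a flat list of gradient values (in practice this is handled by the strategy/DeepSpeed, not hand-rolled like this):

```python
# Illustrative sketch of gradient clipping by global norm
# (hypothetical helper; OpenRLHF delegates this to its strategy).
def clip_by_norm(grads, max_norm):
    norm = sum(g * g for g in grads) ** 0.5
    if norm > max_norm:
        # Scale all gradients uniformly so the global norm equals max_norm.
        scale = max_norm / norm
        grads = [g * scale for g in grads]
    return grads

# A gradient vector of norm 5.0 is rescaled to norm 1.0.
clipped = clip_by_norm([3.0, 4.0], 1.0)  # ≈ [0.6, 0.8]
```

Gradients whose norm is already below max_norm pass through unchanged, so the default of 1.0 only dampens unusually large updates.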

Outputs

Name           Type   Description
(side effect)  None   Model weights updated in-place
logs           Dict   Training metrics logged to W&B/TensorBoard
checkpoints    Files  Model checkpoints saved to disk

Usage Examples

from openrlhf.trainer import SFTTrainer

trainer = SFTTrainer(
    model=actor_model,
    strategy=strategy,
    optim=optimizer,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    scheduler=scheduler,
    max_norm=args.max_norm,
    pretrain_mode=args.pretrain_mode,
    max_epochs=args.max_epochs,
    tokenizer=tokenizer,
)

trainer.fit(args, num_update_steps_per_epoch=num_update_steps_per_epoch)

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
