Implementation: OpenRLHF SFTTrainer
| Knowledge Sources | |
|---|---|
| Domains | NLP, Training |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
Concrete tool for supervised fine-tuning of language models provided by OpenRLHF.
Description
The SFTTrainer class implements the complete SFT training loop with DeepSpeed integration. It wraps the model forward pass (computing per-token log-probabilities via Actor.forward), applies the SFTLoss with loss masking, handles gradient accumulation and clipping, runs periodic evaluation, saves checkpoints in both DeepSpeed and HuggingFace formats, and logs metrics to W&B or TensorBoard.
Key features: checkpoint recovery via consumed_samples tracking, MoE auxiliary loss support, ring attention integration, and configurable evaluation/save frequencies.
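The loss-masking behavior described above can be sketched as follows. This is a minimal, stdlib-only illustration (the function name `masked_sft_loss` is hypothetical, not OpenRLHF's actual SFTLoss): prompt tokens are excluded from the mean negative log-likelihood unless pretrain_mode is set, in which case every token contributes.

```python
# Hedged sketch of SFT loss masking -- NOT OpenRLHF's actual implementation.
# token_nll: per-token negative log-likelihoods from the model forward pass.
# loss_mask: 1 for response tokens, 0 for prompt tokens.

def masked_sft_loss(token_nll, loss_mask, pretrain_mode=False):
    if pretrain_mode:
        # pretrain_mode=True: compute the loss on all tokens
        loss_mask = [1] * len(token_nll)
    total = sum(nll * m for nll, m in zip(token_nll, loss_mask))
    count = sum(loss_mask)
    return total / max(count, 1)  # mean NLL over unmasked tokens
```

For example, with NLLs `[2.0, 4.0, 6.0]` and mask `[0, 1, 1]`, only the two response tokens are averaged; with `pretrain_mode=True` all three tokens contribute.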
Usage
Instantiate after creating the model, optimizer, scheduler, and dataloaders. Call fit() to run the full training loop. Used in SFT Training and Rejection Sampling workflows.
Code Reference
Source Location
- Repository: OpenRLHF
- File: openrlhf/trainer/sft_trainer.py
- Lines: L12-255 (class), L32-47 (__init__), L103-184 (fit)
Signature
```python
class SFTTrainer(ABC):
    def __init__(
        self,
        model,                          # Actor: the model to train
        strategy,                       # DeepspeedStrategy
        optim: Optimizer,               # optimizer
        train_dataloader,               # training DataLoader
        eval_dataloader,                # evaluation DataLoader
        scheduler,                      # learning rate scheduler
        max_norm: float = 1,            # gradient clipping norm
        pretrain_mode: bool = False,    # if True, loss on all tokens
        batch_size: int = 1,            # batch size
        max_epochs: int = 2,            # number of training epochs
        tokenizer=None,                 # tokenizer for checkpointing
        save_hf_ckpt: bool = False,     # save HF format checkpoint
        disable_ds_ckpt: bool = False,  # disable DeepSpeed checkpoint
    ) -> None: ...

    def fit(self, args, consumed_samples=0, num_update_steps_per_epoch=None):
        """Run the full training loop."""
```
Import
```python
from openrlhf.trainer import SFTTrainer
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | Actor | Yes | The policy model to train |
| strategy | DeepspeedStrategy | Yes | Training strategy |
| optim | Optimizer | Yes | Optimizer (FusedAdam or DeepSpeedCPUAdam) |
| train_dataloader | DataLoader | Yes | Training data (from SFTDataset) |
| scheduler | LRScheduler | Yes | Learning rate scheduler |
| max_norm | float | No | Gradient clipping norm (default 1.0) |
Outputs
| Name | Type | Description |
|---|---|---|
| (side effect) | None | Model weights updated in-place |
| logs | Dict | Training metrics logged to W&B/TensorBoard |
| checkpoints | Files | Model checkpoints saved to disk |
Usage Examples
```python
from openrlhf.trainer import SFTTrainer

trainer = SFTTrainer(
    model=actor_model,
    strategy=strategy,
    optim=optimizer,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    scheduler=scheduler,
    max_norm=args.max_norm,
    pretrain_mode=args.pretrain_mode,
    max_epochs=args.max_epochs,
    tokenizer=tokenizer,
)
trainer.fit(args, num_update_steps_per_epoch=num_update_steps_per_epoch)
```
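The `num_update_steps_per_epoch` and `consumed_samples` arguments to fit() can be derived from plain arithmetic. The helpers below are hypothetical (not part of OpenRLHF's API) and sketch one way to compute them, assuming one optimizer update consumes micro_batch_size × grad_accum_steps × world_size samples:

```python
import math

# Hypothetical helpers -- a sketch under the stated assumptions,
# not OpenRLHF's actual bookkeeping.

def updates_per_epoch(num_samples, micro_batch_size, grad_accum_steps, world_size=1):
    # Samples consumed per optimizer update across all ranks.
    samples_per_update = micro_batch_size * grad_accum_steps * world_size
    return math.ceil(num_samples / samples_per_update)

def resume_epoch_and_step(consumed_samples, samples_per_update, steps_per_epoch):
    # Checkpoint recovery: map a consumed_samples counter back to
    # (epoch, step-within-epoch) so training can resume mid-epoch.
    total_steps = consumed_samples // samples_per_update
    return total_steps // steps_per_epoch, total_steps % steps_per_epoch
```

For example, 1000 samples with micro-batch 4, 8 accumulation steps, and 4 ranks gives 128 samples per update and 8 updates per epoch; a checkpoint recording 256 consumed samples resumes at epoch 0, step 2.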