Implementation: Hugging Face TRL SFTTrainer Train
| Knowledge Sources | |
|---|---|
| Domains | NLP, Training |
| Last Updated | 2026-02-06 17:00 GMT |
Overview
The concrete training-execution methods on SFTTrainer that run the optimization loop with custom loss computation, token-accuracy tracking, and activation offloading. SFTTrainer is provided by the TRL library and wraps transformers.Trainer.
Description
The SFTTrainer inherits .train() from transformers.Trainer and overrides two methods:
compute_loss() -- Delegates to the parent Trainer for the actual loss value, then computes additional metrics: token accuracy (via argmax comparison), Shannon entropy (via chunked softmax), total training tokens, and an optional auxiliary loss for MoE models. When the Liger kernel is used, token accuracy is obtained directly from the kernel output.
training_step() -- Wraps the parent's training step in an activation offloading context manager (when enabled), which moves activations to CPU during the forward pass and retrieves them during the backward pass.
The .log() method is also overridden to merge the custom metrics (mean_token_accuracy, entropy, num_tokens, aux_loss) into the standard Trainer log output.
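As an illustration of the first metric above, token accuracy compares the argmax prediction at each position against the label, skipping masked positions. The sketch below is a minimal pure-Python version of that arithmetic under the usual -100 ignore-index convention; it is illustrative only, not TRL's actual tensor implementation:

```python
IGNORE_INDEX = -100  # conventional label value for masked (prompt/padding) tokens

def mean_token_accuracy(logits, labels):
    """logits: per-token lists of vocab scores; labels: per-token target ids."""
    correct = total = 0
    for scores, label in zip(logits, labels):
        if label == IGNORE_INDEX:
            continue  # masked positions are excluded from the metric
        pred = max(range(len(scores)), key=scores.__getitem__)  # argmax over vocab
        correct += int(pred == label)
        total += 1
    return correct / total if total else 0.0

logits = [[0.1, 2.0, 0.3], [1.5, 0.2, 0.1], [0.0, 0.1, 3.0]]
labels = [1, 0, IGNORE_INDEX]  # last token is masked out
print(mean_token_accuracy(logits, labels))  # → 1.0
```

The real implementation performs the same comparison on GPU tensors in one vectorized step.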
Usage
Call trainer.train() after initializing the SFTTrainer. The method handles the full training loop including data loading, gradient computation, optimization, logging, checkpointing, and evaluation.
Code Reference
Source Location
- Repository: TRL
- File: trl/trainer/sft_trainer.py (lines 1171-1290, compute_loss and training_step)
- File: trl/trainer/sft_trainer.py (lines 1292-1303, log)
Signature
class SFTTrainer(BaseTrainer):
    # Inherited from transformers.Trainer
    def train(
        self,
        resume_from_checkpoint: str | bool | None = None,
    ) -> TrainOutput:
        ...

    # SFT-specific override
    def compute_loss(
        self,
        model,
        inputs,
        return_outputs=False,
        num_items_in_batch=None,
    ):
        """
        Computes the training loss plus metrics:
        - mean_token_accuracy
        - entropy (Shannon entropy of output distribution)
        - num_tokens (total training tokens seen)
        - aux_loss (for MoE models)
        """
        ...

    # SFT-specific override
    def training_step(self, *args, **kwargs):
        """Wraps parent training_step with activation offloading context."""
        with self.maybe_activation_offload_context:
            return super().training_step(*args, **kwargs)

    # SFT-specific override
    def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
        """Merges custom SFT metrics into the standard log output."""
        ...
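The maybe_activation_offload_context used in training_step follows a common "maybe" pattern: it is a real offloading context when the feature is enabled and a no-op otherwise, so the same `with` statement works in both cases. A minimal sketch of that pattern follows; the class and function names are hypothetical stand-ins, not TRL's code:

```python
from contextlib import nullcontext

class OffloadActivations:
    """Hypothetical stand-in for an activation-offloading context."""
    def __enter__(self):
        self.offloaded = True   # stand-in for moving activations to CPU on forward
        return self
    def __exit__(self, *exc):
        self.offloaded = False  # stand-in for restoring activations for backward
        return False

def maybe_offload_context(enabled: bool):
    # No-op context when offloading is disabled, so callers need no branching
    return OffloadActivations() if enabled else nullcontext()

with maybe_offload_context(enabled=False):
    pass  # the training step runs unchanged when offloading is off
```

This design keeps training_step free of conditionals: the decision is made once when the context attribute is constructed.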
Import
# Methods are called on an SFTTrainer instance; no separate import needed
from trl import SFTTrainer
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| resume_from_checkpoint | str \| bool \| None | No | Path to a checkpoint directory, True to auto-detect the latest checkpoint, or None to train from scratch |
| inputs (to compute_loss) | dict | Yes (internal) | Batch dictionary with input_ids, labels, and attention_mask (or position_ids for padding-free) |
Outputs
| Name | Type | Description |
|---|---|---|
| train_output | TrainOutput | Named tuple with global_step, training_loss, and a metrics dict |
| metrics (logged) | dict | Includes mean_token_accuracy, entropy, num_tokens, and optionally aux_loss |
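The entropy metric above is the Shannon entropy of the model's per-token output distribution; TRL computes the softmax in chunks over the sequence to bound peak memory. The pure-Python sketch below shows only the arithmetic, with a toy chunk loop standing in for the chunked tensor computation; it is not TRL's implementation:

```python
import math

def token_entropy(scores):
    """Shannon entropy of softmax(scores) for one token, in nats."""
    m = max(scores)                          # subtract max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    probs = [e / z for e in exps]
    return -sum(p * math.log(p) for p in probs if p > 0)

def mean_entropy(logits, chunk_size=2):
    """Mean per-token entropy, processing the sequence chunk by chunk."""
    ents = []
    for start in range(0, len(logits), chunk_size):  # chunking bounds peak memory
        for scores in logits[start:start + chunk_size]:
            ents.append(token_entropy(scores))
    return sum(ents) / len(ents)

uniform = [[0.0, 0.0, 0.0, 0.0]]  # uniform over 4 tokens → entropy = ln 4
print(round(mean_entropy(uniform), 4))  # → 1.3863
```

High entropy indicates the model is uncertain across the vocabulary; a falling entropy curve during SFT typically means the model is becoming more confident on the target tokens.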
Usage Examples
Basic Training
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset
dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]")
trainer = SFTTrainer(
model="Qwen/Qwen2.5-0.5B-Instruct",
args=SFTConfig(
output_dir="./output",
num_train_epochs=1,
per_device_train_batch_size=4,
logging_steps=10,
),
train_dataset=dataset,
)
result = trainer.train()
print(f"Training loss: {result.training_loss:.4f}")
print(f"Global steps: {result.global_step}")
Resume from Checkpoint
result = trainer.train(resume_from_checkpoint="./output/checkpoint-500")
Training with Activation Offloading
trainer = SFTTrainer(
model="meta-llama/Llama-3.1-8B",
args=SFTConfig(
output_dir="./output",
activation_offloading=True,
gradient_checkpointing=True,
per_device_train_batch_size=1,
gradient_accumulation_steps=16,
),
train_dataset=dataset,
)
trainer.train()
Training with DFT Loss
trainer = SFTTrainer(
model="Qwen/Qwen2.5-0.5B-Instruct",
args=SFTConfig(
output_dir="./output",
loss_type="dft",
),
train_dataset=dataset,
)
trainer.train()