
Implementation:Huggingface Trl SFTTrainer Train

From Leeroopedia


Knowledge Sources
Domains: NLP, Training
Last Updated: 2026-02-06 17:00 GMT

Overview

Concrete training execution methods of the SFTTrainer, which run the optimization loop with custom loss computation, token-accuracy tracking, and optional activation offloading. SFTTrainer is provided by the TRL library and wraps transformers.Trainer.

Description

The SFTTrainer inherits .train() from transformers.Trainer and overrides two methods:

  • compute_loss() -- Delegates to the parent Trainer for the actual loss value, then computes additional metrics: token accuracy (via argmax comparison), Shannon entropy (via chunked softmax), total training tokens, and optional auxiliary loss for MoE models. When using Liger kernel, token accuracy is obtained directly from the kernel output.
  • training_step() -- Wraps the parent's training step in an activation offloading context manager (when enabled), which moves activations to CPU during forward and retrieves them during backward.
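The two extra metrics that compute_loss adds can be illustrated with a minimal pure-Python sketch. This is not TRL's actual implementation, which operates on batched GPU tensors and uses a chunked softmax for memory efficiency; the function names here are illustrative:

```python
import math

IGNORE_INDEX = -100  # label value excluded from loss and accuracy

def token_accuracy(logits, labels):
    """Fraction of non-ignored positions where argmax(logits) == label.
    logits: list of per-position score lists; labels: list of ints."""
    correct = total = 0
    for scores, label in zip(logits, labels):
        if label == IGNORE_INDEX:
            continue
        pred = max(range(len(scores)), key=scores.__getitem__)
        correct += int(pred == label)
        total += 1
    return correct / total if total else 0.0

def mean_entropy(logits):
    """Average Shannon entropy (in nats) of the softmax distribution
    at each position."""
    entropies = []
    for scores in logits:
        m = max(scores)                           # subtract max for stability
        exps = [math.exp(s - m) for s in scores]
        z = sum(exps)
        probs = [e / z for e in exps]
        entropies.append(-sum(p * math.log(p) for p in probs if p > 0))
    return sum(entropies) / len(entropies)
```

A uniform distribution over two tokens yields the maximum two-class entropy, log 2 ≈ 0.693 nats.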

The .log() method is also overridden to merge the custom metrics (mean_token_accuracy, entropy, num_tokens, aux_loss) into the standard Trainer log output.
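The merging pattern behind the .log() override can be sketched as follows. The accumulator attribute and method names here are illustrative, not TRL's exact internals:

```python
class MetricLoggingSketch:
    """Sketch of a Trainer-style log() override that folds per-step
    custom metric lists into the standard logs dict as means."""

    def __init__(self):
        self._metrics = {}  # e.g. {"mean_token_accuracy": [0.61, 0.63]}

    def record(self, name, value):
        self._metrics.setdefault(name, []).append(value)

    def log(self, logs):
        # Average each accumulated metric, merge into the standard logs,
        # then reset the buffers for the next logging interval.
        means = {k: sum(v) / len(v) for k, v in self._metrics.items() if v}
        logs = {**logs, **means}
        self._metrics.clear()
        return logs  # the real Trainer.log() forwards this to callbacks
```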

Usage

Call trainer.train() after initializing the SFTTrainer. The method handles the full training loop including data loading, gradient computation, optimization, logging, checkpointing, and evaluation.

Code Reference

Source Location

  • Repository: TRL
  • File: trl/trainer/sft_trainer.py (lines 1171-1290, compute_loss and training_step)
  • File: trl/trainer/sft_trainer.py (lines 1292-1303, log)

Signature

class SFTTrainer(BaseTrainer):
    # Inherited from transformers.Trainer
    def train(
        self,
        resume_from_checkpoint: str | bool | None = None,
    ) -> TrainOutput:
        ...

    # SFT-specific override
    def compute_loss(
        self,
        model,
        inputs,
        return_outputs=False,
        num_items_in_batch=None,
    ):
        """
        Computes the training loss plus metrics:
        - mean_token_accuracy
        - entropy (Shannon entropy of output distribution)
        - num_tokens (total training tokens seen)
        - aux_loss (for MoE models)
        """
        ...

    # SFT-specific override
    def training_step(self, *args, **kwargs):
        """Wraps parent training_step with activation offloading context."""
        with self.maybe_activation_offload_context:
            return super().training_step(*args, **kwargs)

    # SFT-specific override
    def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
        """Merges custom SFT metrics into the standard log output."""
        ...

Import

# Methods are called on an SFTTrainer instance; no separate import needed
from trl import SFTTrainer

I/O Contract

Inputs

Name Type Required Description
resume_from_checkpoint str | bool | None No Path to checkpoint directory, True to auto-detect latest, or None to train from scratch
inputs (to compute_loss) dict Yes (internal) Batch dictionary with input_ids, labels, attention_mask (or position_ids for padding-free)
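For illustration, a minimal batch of the shape compute_loss expects (real batches are torch tensors produced by the data collator; plain lists stand in here, and -100 labels mask prompt or padding tokens from the loss):

```python
# Hypothetical token ids for a single four-token sequence.
batch = {
    "input_ids": [[101, 2009, 2003, 102]],
    "attention_mask": [[1, 1, 1, 1]],
    "labels": [[-100, 2009, 2003, 102]],  # -100 = ignored in the loss
}
```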

Outputs

Name Type Description
train_output TrainOutput Named tuple with global_step, training_loss, and metrics dict
metrics (logged) dict Includes mean_token_accuracy, entropy, num_tokens, and optionally aux_loss

Usage Examples

Basic Training

from trl import SFTTrainer, SFTConfig
from datasets import load_dataset

dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]")

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    args=SFTConfig(
        output_dir="./output",
        num_train_epochs=1,
        per_device_train_batch_size=4,
        logging_steps=10,
    ),
    train_dataset=dataset,
)

result = trainer.train()
print(f"Training loss: {result.training_loss:.4f}")
print(f"Global steps: {result.global_step}")

Resume from Checkpoint

result = trainer.train(resume_from_checkpoint="./output/checkpoint-500")

Training with Activation Offloading

trainer = SFTTrainer(
    model="meta-llama/Llama-3.1-8B",
    args=SFTConfig(
        output_dir="./output",
        activation_offloading=True,
        gradient_checkpointing=True,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=16,
    ),
    train_dataset=dataset,
)
trainer.train()

Training with DFT Loss

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    args=SFTConfig(
        output_dir="./output",
        loss_type="dft",
    ),
    train_dataset=dataset,
)
trainer.train()

Related Pages

Implements Principle
