Implementation:Huggingface Trl SFTTrainer Init

From Leeroopedia


Knowledge Sources
Domains: NLP, Training
Last Updated: 2026-02-06 17:00 GMT

Overview

Concrete API for initializing the SFT training pipeline by composing model, config, datasets, PEFT adapter, and data collator into a single SFTTrainer instance, provided by the TRL library.

Description

SFTTrainer is a subclass of BaseTrainer (itself a subclass of transformers.Trainer) that adds SFT-specific initialization logic. Its __init__ method handles model loading (from a model ID string or an already-instantiated object), processing class setup, PEFT wrapping (including QLoRA adapter dtype casting and DeepSpeed ZeRO-3 compatibility), data collator construction, dataset preparation (tokenization, chat template application, packing, truncation), loss function selection, and activation offloading setup.
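Of the dataset preparation steps listed above, truncation and packing are the least self-explanatory. The following is a simplified sketch of the two ideas on plain lists of token IDs. It is not TRL's implementation: the function names pack_wrapped and truncate are hypothetical, the token IDs are made up, and the real trainer may use a different packing strategy.

```python
def truncate(sequences, max_length):
    """Cut each tokenized example down to at most max_length tokens."""
    return [seq[:max_length] for seq in sequences]


def pack_wrapped(sequences, max_length):
    """Concatenate all examples into one token stream, then slice the
    stream into chunks of max_length (the last chunk may be shorter)."""
    stream = [tok for seq in sequences for tok in seq]
    return [stream[i:i + max_length] for i in range(0, len(stream), max_length)]


# Three toy "tokenized" examples of different lengths (made-up IDs).
examples = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]

truncated = truncate(examples, max_length=3)
packed = pack_wrapped(examples, max_length=4)

print(truncated)  # [[1, 2, 3], [4, 5], [6, 7, 8]]
print(packed)     # [[1, 2, 3, 4], [5, 6, 7, 8], [9]]
```

Truncation discards tokens beyond the limit, while packing keeps every token but lets a chunk span several original examples, which reduces padding.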

Key internal behaviors:

  • Auto-detection of completion-only loss: If completion_only_loss is None in the config, the trainer checks whether the dataset has prompt/completion keys and enables completion-only loss automatically.
  • Vision-language model detection: If the dataset contains "image" or "images" keys, the trainer uses DataCollatorForVisionLanguageModeling for on-the-fly image processing.
  • PEFT + gradient checkpointing: The trainer enables input gradients on PEFT models to work around a Transformers bug, and forces reentrant checkpointing for PEFT + DeepSpeed ZeRO-3.
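The first two auto-detection behaviors above can be approximated by inspecting the keys of a single dataset example. The sketch below is illustrative only: detect_dataset_traits is a hypothetical helper, not a TRL function, and it assumes the checks reduce to simple key lookups on the first example.

```python
def detect_dataset_traits(first_example, completion_only_loss=None):
    """Hypothetical helper mirroring the key-based checks described above."""
    # Prompt-completion format: both keys are present in the example.
    is_prompt_completion = "prompt" in first_example and "completion" in first_example

    # Only auto-detect when the config leaves the option unset (None);
    # an explicit True/False from the user is respected as-is.
    if completion_only_loss is None:
        completion_only_loss = is_prompt_completion

    # Vision-language data: an "image" or "images" column is present.
    is_vision_dataset = "image" in first_example or "images" in first_example

    return {
        "completion_only_loss": completion_only_loss,
        "is_vision_dataset": is_vision_dataset,
    }


text_only = detect_dataset_traits({"text": "Once upon a time"})
prompt_completion = detect_dataset_traits({"prompt": "Hi", "completion": " there"})
overridden = detect_dataset_traits({"prompt": "Hi", "completion": " there"}, completion_only_loss=False)
vision = detect_dataset_traits({"images": [], "text": "A photo of a cat"})
```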

Usage

Use SFTTrainer as the central entry point for supervised fine-tuning. Pass the model (as string or object), training configuration, datasets, and optional PEFT config to get a fully configured trainer.

Code Reference

Source Location

  • Repository: TRL
  • File: trl/trainer/sft_trainer.py (lines 486-937, SFTTrainer.__init__)

Signature

class SFTTrainer(BaseTrainer):
    _tag_names = ["trl", "sft"]
    _name = "SFT"

    def __init__(
        self,
        model: "str | PreTrainedModel | PeftModel",
        args: SFTConfig | TrainingArguments | None = None,
        data_collator: DataCollator | None = None,
        train_dataset: Dataset | IterableDataset | None = None,
        eval_dataset: (
            Dataset
            | IterableDataset
            | dict[str, Dataset | IterableDataset]
            | None
        ) = None,
        processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None,
        compute_loss_func: Callable | None = None,
        compute_metrics: Callable[[EvalPrediction], dict] | None = None,
        callbacks: list[TrainerCallback] | None = None,
        optimizers: tuple = (None, None),
        optimizer_cls_and_kwargs: tuple | None = None,
        preprocess_logits_for_metrics: Callable | None = None,
        peft_config: "PeftConfig | None" = None,
        formatting_func: Callable | None = None,
    ):
        ...

Import

from trl import SFTTrainer, SFTConfig

I/O Contract

Inputs

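  • model (str | PreTrainedModel | PeftModel; required): Model ID string, pretrained model instance, or PEFT-wrapped model.
  • args (SFTConfig | TrainingArguments | None; optional): Training configuration; if None, a default SFTConfig is created with an output directory derived from the model name.
  • data_collator (DataCollator | None; optional): Custom collator; if None, a language modeling or vision-language collator is selected automatically.
  • train_dataset (Dataset | IterableDataset | None; optional): Training dataset in language modeling or prompt-completion format.
  • eval_dataset (Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None; optional): Evaluation dataset, or a dict of named evaluation datasets.
  • processing_class (PreTrainedTokenizerBase | ProcessorMixin | None; optional): Tokenizer or processor; loaded from the model if None.
  • compute_loss_func (Callable | None; optional): Custom loss function; if None, the loss is selected based on loss_type.
  • compute_metrics (Callable[[EvalPrediction], dict] | None; optional): Function that computes evaluation metrics.
  • callbacks (list[TrainerCallback] | None; optional): Additional training callbacks.
  • optimizers (tuple; optional): (optimizer, scheduler) tuple; defaults to AdamW with a linear schedule.
  • optimizer_cls_and_kwargs (tuple | None; optional): Optimizer class and keyword arguments, as an alternative to passing an instantiated optimizer.
  • preprocess_logits_for_metrics (Callable | None; optional): Function applied to the logits before they are cached for metric computation.
  • peft_config (PeftConfig | None; optional): PEFT adapter configuration (e.g., LoraConfig); None for full fine-tuning.
  • formatting_func (Callable | None; optional): Custom function that converts examples to text before tokenization.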
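As an illustration of the formatting_func input, the sketch below turns a structured example into a single training string. The column names "question" and "answer" and the prompt template are assumptions made up for this example, and it assumes the function is called once per example and returns a string; check the TRL version in use, since the calling convention may differ.

```python
def format_example(example):
    """Hypothetical formatting function: build one text string from a
    dict with made-up "question" and "answer" columns."""
    return f"### Question: {example['question']}\n### Answer: {example['answer']}"


sample = {"question": "What is 2 + 2?", "answer": "4"}
formatted = format_example(sample)
print(formatted)
```

Such a function would be passed as formatting_func=format_example when constructing the trainer, so that datasets without a ready-made text column can still be used.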

Outputs

  • trainer (SFTTrainer): Fully initialized trainer ready for .train(), .evaluate(), and .save_model().

Usage Examples

Minimal Usage (Model as String)

from trl import SFTTrainer
from datasets import load_dataset

dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]")

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    train_dataset=dataset,
)
trainer.train()

Full SFT with PEFT

from datasets import load_dataset
from transformers import AutoModelForCausalLM
from trl import ModelConfig, SFTConfig, SFTTrainer, get_peft_config

# Load the base model in bfloat16 to reduce memory use.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B", dtype="bfloat16")
dataset = load_dataset("trl-lib/Capybara")

# ModelConfig carries the LoRA hyperparameters; get_peft_config turns it into a PEFT config.
model_args = ModelConfig(use_peft=True, lora_r=32, lora_alpha=16)

trainer = SFTTrainer(
    model=model,
    args=SFTConfig(
        output_dir="./sft-output",
        max_length=2048,
        packing=True,
        num_train_epochs=1,
        per_device_train_batch_size=4,
    ),
    train_dataset=dataset["train"],
    eval_dataset=dataset.get("test"),  # None if the dataset has no "test" split
    peft_config=get_peft_config(model_args),
)
trainer.train()

Prompt-Completion with Completion-Only Loss

from datasets import Dataset
from trl import SFTTrainer, SFTConfig

dataset = Dataset.from_dict({
    "prompt": ["Translate to French: Hello", "Translate to French: Goodbye"],
    "completion": [" Bonjour", " Au revoir"],
})

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    args=SFTConfig(
        output_dir="./output",
        completion_only_loss=True,
        max_length=512,
    ),
    train_dataset=dataset,
)
trainer.train()
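Conceptually, completion-only loss means that prompt tokens are excluded from the loss and only completion tokens contribute. A minimal sketch of that masking follows, assuming the common convention that label -100 is ignored by the cross-entropy loss. The helper build_labels and the token IDs are made up for illustration and are not taken from TRL.

```python
IGNORE_INDEX = -100  # label value conventionally ignored by cross-entropy loss


def build_labels(input_ids, completion_mask):
    """Keep the token ID as label where the mask is 1 (completion),
    and replace it with IGNORE_INDEX where the mask is 0 (prompt)."""
    return [tok if keep else IGNORE_INDEX for tok, keep in zip(input_ids, completion_mask)]


# Made-up token IDs for a prompt and its completion.
prompt_ids = [101, 102, 103]
completion_ids = [201, 202]

input_ids = prompt_ids + completion_ids
completion_mask = [0] * len(prompt_ids) + [1] * len(completion_ids)

labels = build_labels(input_ids, completion_mask)
print(labels)  # [-100, -100, -100, 201, 202]
```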

Related Pages

Implements Principle
