Implementation: Hugging Face TRL SFTTrainer Init
| Knowledge Sources | |
|---|---|
| Domains | NLP, Training |
| Last Updated | 2026-02-06 17:00 GMT |
Overview
Concrete API for initializing the SFT training pipeline by composing model, config, datasets, PEFT adapter, and data collator into a single SFTTrainer instance, provided by the TRL library.
Description
SFTTrainer is a subclass of BaseTrainer (which itself wraps transformers.Trainer) that adds SFT-specific initialization logic. Its __init__ method handles model loading (from string or object), processing class setup, PEFT wrapping (including QLoRA adapter dtype casting and DeepSpeed ZeRO-3 compatibility), data collator construction, dataset preparation (tokenization, chat template application, packing, truncation), loss function selection, and activation offloading setup.
Key internal behaviors:
- Auto-detection of completion-only loss: If `completion_only_loss` is None in the config, the trainer checks whether the dataset has `prompt`/`completion` keys and enables completion-only loss automatically (see the sketch after this list).
- Vision-language model detection: If the dataset contains `"image"` or `"images"` keys, the trainer uses `DataCollatorForVisionLanguageModeling` for on-the-fly image processing.
- PEFT + gradient checkpointing: The trainer enables input gradients on PEFT models to work around a Transformers bug, and forces reentrant checkpointing for PEFT + DeepSpeed ZeRO-3.
Usage
Use SFTTrainer as the central entry point for supervised fine-tuning. Pass the model (as string or object), training configuration, datasets, and optional PEFT config to get a fully configured trainer.
Code Reference
Source Location
- Repository: TRL
- File: `trl/trainer/sft_trainer.py` (lines 486-937, `SFTTrainer.__init__`)
Signature
```python
class SFTTrainer(BaseTrainer):
    _tag_names = ["trl", "sft"]
    _name = "SFT"

    def __init__(
        self,
        model: "str | PreTrainedModel | PeftModel",
        args: SFTConfig | TrainingArguments | None = None,
        data_collator: DataCollator | None = None,
        train_dataset: Dataset | IterableDataset | None = None,
        eval_dataset: (
            Dataset
            | IterableDataset
            | dict[str, Dataset | IterableDataset]
            | None
        ) = None,
        processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None,
        compute_loss_func: Callable | None = None,
        compute_metrics: Callable[[EvalPrediction], dict] | None = None,
        callbacks: list[TrainerCallback] | None = None,
        optimizers: tuple = (None, None),
        optimizer_cls_and_kwargs: tuple | None = None,
        preprocess_logits_for_metrics: Callable | None = None,
        peft_config: "PeftConfig | None" = None,
        formatting_func: Callable | None = None,
    ):
        ...
```
Import
```python
from trl import SFTTrainer, SFTConfig
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | `str \| PreTrainedModel \| PeftModel` | Yes | Model ID string, pretrained model instance, or PEFT-wrapped model |
| args | `SFTConfig \| TrainingArguments \| None` | No | Training configuration; defaults to `SFTConfig` with the model name as output dir (see the sketch after this table) |
| data_collator | `DataCollator \| None` | No | Custom collator; auto-selected if None (language modeling or vision collator) |
| train_dataset | `Dataset \| IterableDataset \| None` | No | Training dataset in language modeling or prompt-completion format |
| eval_dataset | `Dataset \| IterableDataset \| dict \| None` | No | Evaluation dataset(s) |
| processing_class | `PreTrainedTokenizerBase \| ProcessorMixin \| None` | No | Tokenizer or processor; auto-loaded from the model if None |
| compute_loss_func | `Callable \| None` | No | Custom loss function; auto-selected based on `loss_type` |
| compute_metrics | `Callable \| None` | No | Evaluation metrics function |
| callbacks | `list[TrainerCallback] \| None` | No | Additional training callbacks |
| optimizers | `tuple` | No | (optimizer, scheduler) tuple; defaults to AdamW with a linear schedule |
| peft_config | `PeftConfig \| None` | No | PEFT adapter configuration (e.g., `LoraConfig`); None for full fine-tuning |
| formatting_func | `Callable \| None` | No | Custom formatting function to convert examples to text before tokenization |
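To illustrate the `args` default from the table, the fallback can be sketched as follows. The helper `default_sft_args` is hypothetical and only approximates the behavior described above; TRL's exact naming may differ.

```python
# Sketch of the `args=None` fallback noted in the table: an SFTConfig whose
# output_dir is derived from the model name. Hypothetical helper, not TRL code.
from trl import SFTConfig

def default_sft_args(model_name: str) -> SFTConfig:
    short_name = model_name.split("/")[-1]  # e.g. "Qwen2.5-0.5B-Instruct"
    return SFTConfig(output_dir=f"{short_name}-SFT")
```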
Outputs
| Name | Type | Description |
|---|---|---|
| trainer | `SFTTrainer` | Fully initialized trainer ready for `.train()`, `.evaluate()`, and `.save_model()` (see the snippet below) |
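The returned trainer is used like any `transformers.Trainer`; a typical lifecycle looks like this (the checkpoint path is a placeholder):

```python
# Typical lifecycle of the returned trainer; these methods are inherited
# from transformers.Trainer.
trainer.train()
metrics = trainer.evaluate()  # requires an eval_dataset to have been passed
trainer.save_model("./final-checkpoint")
```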
Usage Examples
Minimal Usage (Model as String)
```python
from trl import SFTTrainer
from datasets import load_dataset

# Small text dataset; the 1% slice keeps the demo fast.
dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]")

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # model loaded from the Hub by name
    train_dataset=dataset,
)
trainer.train()
```
Full SFT with PEFT
```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM
from trl import ModelConfig, SFTConfig, SFTTrainer, get_peft_config

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B", dtype="bfloat16")
dataset = load_dataset("trl-lib/Capybara")

# Build a LoRA adapter configuration from CLI-style model arguments.
model_args = ModelConfig(use_peft=True, lora_r=32, lora_alpha=16)

trainer = SFTTrainer(
    model=model,
    args=SFTConfig(
        output_dir="./sft-output",
        max_length=2048,
        packing=True,  # pack multiple short sequences into each example
        num_train_epochs=1,
        per_device_train_batch_size=4,
    ),
    train_dataset=dataset["train"],
    eval_dataset=dataset.get("test"),
    peft_config=get_peft_config(model_args),
)
trainer.train()
```
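Alternatively, a `peft.LoraConfig` can be passed directly as `peft_config` instead of going through `ModelConfig`; the sketch below mirrors the adapter settings from the example above.

```python
# Equivalent sketch: pass a LoraConfig directly rather than deriving it
# from ModelConfig via get_peft_config.
from peft import LoraConfig

peft_config = LoraConfig(r=32, lora_alpha=16, task_type="CAUSAL_LM")
```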
Prompt-Completion with Completion-Only Loss
```python
from datasets import Dataset
from trl import SFTTrainer, SFTConfig

# Prompt-completion format: the trainer can mask prompt tokens so that
# loss is computed only on the completions.
dataset = Dataset.from_dict({
    "prompt": ["Translate to French: Hello", "Translate to French: Goodbye"],
    "completion": [" Bonjour", " Au revoir"],
})

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    args=SFTConfig(
        output_dir="./output",
        completion_only_loss=True,  # explicit here; also auto-detected from prompt/completion keys
        max_length=512,
    ),
    train_dataset=dataset,
)
trainer.train()
```
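At the token level, completion-only loss amounts to masking prompt labels. The following is an illustrative sketch of that idea, not TRL's collator implementation:

```python
# Illustrative sketch of completion-only loss: prompt labels are set to
# -100 (PyTorch cross-entropy's ignore_index), so only completion tokens
# contribute to the loss.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
prompt_ids = tokenizer("Translate to French: Hello")["input_ids"]
completion_ids = tokenizer(" Bonjour")["input_ids"]

input_ids = prompt_ids + completion_ids
labels = [-100] * len(prompt_ids) + completion_ids
```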