# Implementation: Hugging Face PEFT SFTTrainer Usage
## Metadata

- Source: examples/sft/train.py:L123-139
- External Reference: TRL SFTTrainer Documentation
- Repository: huggingface/peft
- Type: Wrapper Doc
- Domains: NLP, Training
## Overview

This document describes the usage of TRL's `SFTTrainer` as demonstrated in the PEFT examples. `SFTTrainer` is an external component from the TRL library that provides a managed training loop for supervised fine-tuning with built-in PEFT integration. When passed a `peft_config`, `SFTTrainer` internally calls `get_peft_model` to wrap the base model -- users do not need to call it manually.
## Imports

```python
from trl import SFTTrainer, SFTConfig
from transformers import HfArgumentParser, set_seed, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig
```
## Core API

### SFTTrainer Constructor
```python
trainer = SFTTrainer(
    model=model,                  # AutoModelForCausalLM (optionally quantized)
    processing_class=tokenizer,   # AutoTokenizer instance
    args=training_args,           # SFTConfig (extends TrainingArguments)
    train_dataset=train_dataset,  # HF Dataset for training
    eval_dataset=eval_dataset,    # HF Dataset for evaluation
    peft_config=peft_config,      # LoraConfig -- SFTTrainer calls get_peft_model internally
)
```
### Training Execution

```python
trainer.train(resume_from_checkpoint=checkpoint)
```
### Model Saving

```python
# For FSDP: set full state dict type before saving
if trainer.is_fsdp_enabled:
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
trainer.save_model()
```
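Because `save_model()` on a PEFT-wrapped model writes only the adapter weights, the saved output is loaded differently from a full checkpoint. A minimal sketch of reloading the adapter for inference, using PEFT's `AutoPeftModelForCausalLM` (the output path is illustrative; it corresponds to whatever `output_dir` was used during training):

```python
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

# AutoPeftModelForCausalLM reads the adapter config, loads the referenced
# base model, and attaches the adapter weights in one call.
model = AutoPeftModelForCausalLM.from_pretrained("path/to/output_dir")
tokenizer = AutoTokenizer.from_pretrained("path/to/output_dir")

# Optionally fold the adapter into the base weights for deployment,
# removing the PEFT wrapper overhead at inference time.
merged = model.merge_and_unload()
```

Merging is optional; keeping the adapter separate allows swapping adapters on a shared base model.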
## Full Usage Pattern

The following pattern is extracted from examples/sft/train.py:
```python
from trl import SFTTrainer, SFTConfig
from transformers import HfArgumentParser, set_seed, AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig

# 1. Parse arguments (ModelArguments, DataTrainingArguments, SFTConfig)
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, SFTConfig))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

# 2. Set seed for reproducibility
set_seed(training_args.seed)

# 3. Load model (optionally quantized with BitsAndBytesConfig;
#    bnb_config is built from the model_args quantization flags, or None)
model = AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    quantization_config=bnb_config,  # optional 4-bit or 8-bit
    trust_remote_code=True,
)

# 4. Configure LoRA
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "down_proj", "up_proj", "gate_proj"],
)

# 5. Load tokenizer and prepare datasets
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
train_dataset, eval_dataset = create_datasets(tokenizer, data_args, training_args)

# 6. Configure gradient checkpointing (KV cache must be off when it is enabled)
model.config.use_cache = not training_args.gradient_checkpointing
if training_args.gradient_checkpointing:
    training_args.gradient_checkpointing_kwargs = {"use_reentrant": model_args.use_reentrant}

# 7. Create trainer -- peft_config triggers internal get_peft_model call
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
)

# 8. Print trainable parameters for verification
if hasattr(trainer.model, "print_trainable_parameters"):
    trainer.model.print_trainable_parameters()

# 9. Train with optional checkpoint resumption
checkpoint = training_args.resume_from_checkpoint
trainer.train(resume_from_checkpoint=checkpoint)

# 10. Save final model (PEFT-aware: saves only adapter weights)
if trainer.is_fsdp_enabled:
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
trainer.save_model()
```
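The `bnb_config` referenced in step 3 is not part of the excerpt. A sketch of how it might be derived from the `ModelArguments` quantization flags; the `build_bnb_config` helper and its selection logic are assumptions, while all `BitsAndBytesConfig` parameter names are real:

```python
import torch
from transformers import BitsAndBytesConfig

# Hypothetical helper (not from train.py): maps the use_4bit_quantization /
# use_8bit_quantization flags to a BitsAndBytesConfig, or None for full precision.
def build_bnb_config(use_4bit: bool, use_8bit: bool):
    if use_4bit:
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",           # normalized float 4 (QLoRA default)
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,      # quantize the quantization constants
        )
    if use_8bit:
        return BitsAndBytesConfig(load_in_8bit=True)
    return None

bnb_config = build_bnb_config(use_4bit=True, use_8bit=False)
```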
## Key Parameters

### ModelArguments (from examples/sft/train.py)
| Parameter | Type | Default | Description |
|---|---|---|---|
| `model_name_or_path` | str | (required) | Pretrained model path or HF Hub identifier |
| `chat_template_format` | str | `"none"` | Chat template: `"chatml"`, `"zephyr"`, or `"none"` |
| `lora_alpha` | int | 16 | LoRA scaling factor |
| `lora_dropout` | float | 0.1 | Dropout for LoRA layers |
| `lora_r` | int | 64 | LoRA rank |
| `lora_target_modules` | str | `"q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj"` | Comma-separated target module names |
| `use_peft_lora` | bool | False | Enable PEFT LoRA training |
| `use_4bit_quantization` | bool | False | Load model in 4-bit precision |
| `use_8bit_quantization` | bool | False | Load model in 8-bit precision |
| `use_flash_attn` | bool | False | Enable Flash Attention 2 |
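Note that `lora_target_modules` is stored as a comma-separated string, while `LoraConfig.target_modules` takes a list of module names, so the string must be split before building the config. A minimal sketch (the split expression is an illustration, not copied from train.py):

```python
# Comma-separated module names as stored in ModelArguments
lora_target_modules = "q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj"

# Split into the list form expected by LoraConfig.target_modules
target_modules = [m.strip() for m in lora_target_modules.split(",")]
print(target_modules)
# ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'down_proj', 'up_proj', 'gate_proj']
```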
### SFTConfig (extends TrainingArguments)

SFTConfig inherits all standard TrainingArguments parameters and adds SFT-specific options, including `max_length` for sequence truncation and dataset formatting options.
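In the example, the `SFTConfig` is populated from command-line arguments via `HfArgumentParser`; a sketch of constructing one directly, with illustrative values (not taken from train.py):

```python
from trl import SFTConfig

# Illustrative configuration -- mixes inherited TrainingArguments fields
# (batch size, learning rate, etc.) with the SFT-specific max_length.
training_args = SFTConfig(
    output_dir="./sft-output",
    max_length=2048,                 # SFT-specific: truncate sequences to this length
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    num_train_epochs=1,
    logging_steps=10,
    bf16=True,
    gradient_checkpointing=True,
)
```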
## Design Decisions

- SFTTrainer manages PEFT lifecycle: Unlike manual PEFT usage, users pass a raw (non-wrapped) model and a `peft_config` to SFTTrainer, which handles calling `get_peft_model` internally. This avoids double-wrapping issues.
- Chat template support: The example supports ChatML and Zephyr templates, with special tokens added to the tokenizer and embeddings resized accordingly. When using `"none"`, the dataset is expected to be pre-formatted.
- FSDP compatibility: When using Fully Sharded Data Parallel, the state dict type must be set to `FULL_STATE_DICT` before saving to ensure all adapter weights are gathered correctly.
- Disable cache during training: `model.config.use_cache` is set to `False` when gradient checkpointing is enabled, as KV caching is incompatible with gradient checkpointing.
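To make the chat template point concrete: with `chat_template_format="chatml"`, each conversation turn is wrapped in ChatML special tokens. A hand-rolled sketch of the format only; the example itself applies the template through the tokenizer (`tokenizer.apply_chat_template`), not a helper like this:

```python
# Hypothetical formatter illustrating the ChatML wire format.
def format_chatml(messages):
    return "".join(
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
    )

sample = format_chatml([
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there!"},
])
print(sample)
# <|im_start|>user
# Hello<|im_end|>
# <|im_start|>assistant
# Hi there!<|im_end|>
```

The `<|im_start|>` and `<|im_end|>` markers are the special tokens added to the tokenizer, which is why the embedding matrix must be resized when this template is selected.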