Implementation:Huggingface Peft SFTTrainer Usage

From Leeroopedia


Overview

This implementation documents the usage of TRL's SFTTrainer as demonstrated in the PEFT examples. SFTTrainer is an external component from the TRL library that provides a managed training loop for supervised fine-tuning with built-in PEFT integration. When passed a peft_config, SFTTrainer internally calls get_peft_model to wrap the base model -- users do not need to call it manually.

Imports

from trl import SFTTrainer, SFTConfig
from transformers import HfArgumentParser, set_seed, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig

Core API

SFTTrainer Constructor

trainer = SFTTrainer(
    model=model,                  # AutoModelForCausalLM (optionally quantized)
    processing_class=tokenizer,   # AutoTokenizer instance
    args=training_args,           # SFTConfig (extends TrainingArguments)
    train_dataset=train_dataset,  # HF Dataset for training
    eval_dataset=eval_dataset,    # HF Dataset for evaluation
    peft_config=peft_config,      # LoraConfig -- SFTTrainer calls get_peft_model internally
)

Training Execution

trainer.train(resume_from_checkpoint=checkpoint)
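Here `checkpoint` is either `None` or the path of the latest `checkpoint-<step>` directory inside `output_dir`. A self-contained sketch of resolving that path, mirroring the behavior of `transformers.trainer_utils.get_last_checkpoint` (the helper name is ours, not part of the example):

```python
import os
import re

def find_last_checkpoint(output_dir):
    """Return the highest-numbered checkpoint-<step> subdirectory, or None."""
    if not os.path.isdir(output_dir):
        return None
    candidates = [
        d for d in os.listdir(output_dir)
        if re.fullmatch(r"checkpoint-\d+", d)
        and os.path.isdir(os.path.join(output_dir, d))
    ]
    if not candidates:
        return None
    # Pick the checkpoint with the largest global step suffix.
    latest = max(candidates, key=lambda d: int(d.rsplit("-", 1)[1]))
    return os.path.join(output_dir, latest)
```

Passing the result to `trainer.train(resume_from_checkpoint=...)` restores optimizer and scheduler state as well as model weights.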

Model Saving

# For FSDP: set full state dict type before saving
if trainer.is_fsdp_enabled:
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
trainer.save_model()
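For a PEFT-wrapped model, `save_model()` writes only the adapter artifacts (`adapter_config.json` plus an `adapter_model` weight file), not the full base model. A small sanity-check helper (the helper name and layout assumptions are ours; the file names are the standard PEFT artifacts):

```python
from pathlib import Path

def has_adapter_artifacts(output_dir):
    """True if output_dir contains the standard PEFT adapter files."""
    d = Path(output_dir)
    has_config = (d / "adapter_config.json").is_file()
    has_weights = any(d.glob("adapter_model.*"))  # .safetensors or .bin
    return has_config and has_weights
```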

Full Usage Pattern

The following pattern is extracted from examples/sft/train.py:

from trl import SFTTrainer, SFTConfig
from transformers import (
    HfArgumentParser,
    set_seed,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from peft import LoraConfig

# 1. Parse arguments (ModelArguments, DataTrainingArguments, SFTConfig)
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, SFTConfig))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

# 2. Set seed for reproducibility
set_seed(training_args.seed)

# 3. Build the optional quantization config (minimal sketch; the full example
#    also sets the 4-bit quant type and compute dtype), then load the model
bnb_config = None
if model_args.use_4bit_quantization:
    bnb_config = BitsAndBytesConfig(load_in_4bit=True)
elif model_args.use_8bit_quantization:
    bnb_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    quantization_config=bnb_config,  # None -> full precision
    trust_remote_code=True,
)

# 4. Configure LoRA
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "down_proj", "up_proj", "gate_proj"],
)

# 5. Load tokenizer and prepare datasets
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
train_dataset, eval_dataset = create_datasets(tokenizer, data_args, training_args)

# 6. Configure gradient checkpointing
model.config.use_cache = not training_args.gradient_checkpointing
if training_args.gradient_checkpointing:
    training_args.gradient_checkpointing_kwargs = {"use_reentrant": model_args.use_reentrant}

# 7. Create trainer -- peft_config triggers internal get_peft_model call
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
)

# 8. Print trainable parameters for verification
if hasattr(trainer.model, "print_trainable_parameters"):
    trainer.model.print_trainable_parameters()

# 9. Train with optional checkpoint resumption
checkpoint = training_args.resume_from_checkpoint
trainer.train(resume_from_checkpoint=checkpoint)

# 10. Save final model (PEFT-aware: saves only adapter weights)
if trainer.is_fsdp_enabled:
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
trainer.save_model()
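For inference or deployment, the saved adapter can be reloaded and optionally merged into the base weights. A sketch using PEFT's `AutoPeftModelForCausalLM`, which reads the base model path from `adapter_config.json` (the function name is ours):

```python
def load_merged_model(adapter_dir):
    # Deferred import so the sketch does not require peft at definition time.
    from peft import AutoPeftModelForCausalLM

    model = AutoPeftModelForCausalLM.from_pretrained(adapter_dir)
    # merge_and_unload() folds the LoRA deltas into the base weights and
    # returns a plain transformers model with no PEFT wrapper.
    return model.merge_and_unload()
```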

Key Parameters

ModelArguments (from examples/sft/train.py)

Parameter | Type | Default | Description
model_name_or_path | str | (required) | Pretrained model path or HF Hub identifier
chat_template_format | str | "none" | Chat template: "chatml", "zephyr", or "none"
lora_alpha | int | 16 | LoRA scaling factor
lora_dropout | float | 0.1 | Dropout for LoRA layers
lora_r | int | 64 | LoRA rank
lora_target_modules | str | "q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj" | Comma-separated target module names
use_peft_lora | bool | False | Enable PEFT LoRA training
use_4bit_quantization | bool | False | Load model in 4-bit precision
use_8bit_quantization | bool | False | Load model in 8-bit precision
use_flash_attn | bool | False | Enable Flash Attention 2
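Note that lora_target_modules is a single comma-separated string, while LoraConfig expects a list of module names. A minimal sketch of that conversion (the split logic is ours, matching the default shown above):

```python
# Default value of lora_target_modules, as in the table above.
lora_target_modules = "q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj"

# LoraConfig takes target_modules as a list (or set) of module names.
target_modules = [m.strip() for m in lora_target_modules.split(",")]
```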

SFTConfig (extends TrainingArguments)

SFTConfig inherits all standard TrainingArguments parameters and adds SFT-specific options including max_length for sequence truncation and dataset formatting options.
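A hedged sketch of an SFTConfig that mixes inherited TrainingArguments fields with SFT-specific ones (values are illustrative, not defaults):

```python
from trl import SFTConfig

training_args = SFTConfig(
    output_dir="sft-lora-out",       # inherited from TrainingArguments
    per_device_train_batch_size=4,   # inherited
    learning_rate=2e-4,              # inherited
    gradient_checkpointing=True,     # inherited
    max_length=2048,                 # SFT-specific: sequence truncation length
    packing=False,                   # SFT-specific: sequence packing toggle
)
```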

Design Decisions

  • SFTTrainer manages PEFT lifecycle: Unlike manual PEFT usage, users pass a raw (non-wrapped) model and a peft_config to SFTTrainer, which handles calling get_peft_model internally. This avoids double-wrapping issues.
  • Chat template support: The example supports ChatML and Zephyr templates, with special tokens added to the tokenizer and embeddings resized accordingly. When using "none", the dataset is expected to be pre-formatted.
  • FSDP compatibility: When using Fully Sharded Data Parallel, the state dict type must be set to FULL_STATE_DICT before saving to ensure all adapter weights are gathered correctly.
  • Disable cache during training: model.config.use_cache is set to False when gradient checkpointing is enabled, as KV caching is incompatible with gradient checkpointing.
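The FSDP save step described above can be wrapped in a small helper so the state-dict switch is never forgotten (the helper is our sketch, not part of TRL):

```python
def save_final_model(trainer):
    """FSDP-aware save: gather a full state dict before writing the adapter."""
    if getattr(trainer, "is_fsdp_enabled", False):
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
    trainer.save_model()
```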
