Implementation:Huggingface Transformers Trainer Train For PEFT

From Leeroopedia
Domains: Parameter_Efficient_Fine_Tuning, NLP, Model_Training
Last Updated: 2026-02-13 00:00 GMT

Overview

Concrete tool for training PEFT adapter parameters using the Transformers Trainer, which wraps the standard training loop with PEFT-aware checkpoint management, FSDP support, and loss computation.

Description

The Trainer.train() method is the primary entry point for training models in Transformers, and it works transparently with PEFT adapter models. When a model with injected adapters is passed to the Trainer, the training loop updates only the adapter parameters. Injecting the adapters sets requires_grad=False on the base model weights, and the Trainer's default optimizer is built only from parameters with requires_grad=True, so the frozen base weights are never updated.
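The selection rule can be illustrated without any deep-learning dependency. The sketch below is a simplified model, not Trainer code: it assumes each parameter is represented by a hypothetical NamedParam record, and the parameter names only follow PEFT's LoRA naming pattern (base_layer, lora_A, lora_B) for illustration. In real code the equivalent test is p.requires_grad on each torch.nn.Parameter.

```python
from dataclasses import dataclass


@dataclass
class NamedParam:
    """Hypothetical stand-in for a (name, torch.nn.Parameter) pair."""
    name: str
    numel: int
    requires_grad: bool


def select_trainable(params):
    """Keep only the parameters an optimizer built from requires_grad=True would update."""
    return [p for p in params if p.requires_grad]


# Illustrative names for one attention projection after LoRA injection (r=16).
params = [
    NamedParam("layers.0.self_attn.q_proj.base_layer.weight", 4096 * 4096, False),  # frozen base
    NamedParam("layers.0.self_attn.q_proj.lora_A.default.weight", 16 * 4096, True),  # adapter A
    NamedParam("layers.0.self_attn.q_proj.lora_B.default.weight", 4096 * 16, True),  # adapter B
]

trainable = select_trainable(params)
print(sum(p.numel for p in trainable), "of", sum(p.numel for p in params))  # 131072 of 16908288
```

Only the two adapter matrices survive the filter, which is why optimizer state and gradients scale with the adapter size rather than the base model size.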

The Trainer provides PEFT-specific behaviors:

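  • PEFT model detection: Uses the internal _is_peft_model() helper, an isinstance check against PEFT's PeftModel wrapper classes, to determine whether the model is a PEFT model. This affects checkpoint loading strictness (non-strict loading for PEFT models with DeepSpeed) and label smoothing logic. Models that receive adapters through the built-in add_adapter() method are not PeftModel instances; the Trainer recognizes them through their _hf_peft_config_loaded attribute, for example to allow fine-tuning of quantized models that carry adapters.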
  • FSDP integration: When using FSDP with PEFT models, the Trainer calls update_fsdp_plugin_peft() to configure the FSDP wrapping policy correctly for adapter layers.
  • Checkpoint saving: During training, intermediate checkpoints save only adapter weights via the PEFT-aware save_pretrained mechanism, keeping checkpoint sizes small.
  • Loss computation: When label smoothing is enabled, compute_loss unwraps the PEFT model (via unwrapped_model.base_model.model._get_name()) to find the underlying model class, so that causal language models get label smoothing applied with shifted labels.
  • Gradient accumulation: Standard gradient accumulation needs no special handling. Because only adapter parameters have requires_grad=True, only their gradients accumulate across micro-batches.
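The claim that adapter checkpoints stay small can be checked with simple arithmetic. The sketch below assumes Llama-2-7B-like shapes (hidden size 4096, 32 decoder layers) and the r=16 q_proj/v_proj configuration used in the basic usage example; lora_param_count and size_mb are hypothetical helper names, and sizes ignore file-format overhead.

```python
def lora_param_count(d_in, d_out, r, modules_per_layer, num_layers):
    """Extra parameters LoRA adds: A is (r x d_in) and B is (d_out x r) per adapted linear layer."""
    return r * (d_in + d_out) * modules_per_layer * num_layers


def size_mb(num_params, bytes_per_param):
    """Raw tensor size in megabytes (10**6 bytes), ignoring file-format overhead."""
    return num_params * bytes_per_param / 1e6


# Assumed Llama-2-7B-like shapes: q_proj and v_proj (both 4096 -> 4096) in
# each of 32 decoder layers, adapted with rank r=16.
adapter_params = lora_param_count(4096, 4096, r=16, modules_per_layer=2, num_layers=32)
print(adapter_params)                        # 8388608
print(round(size_mb(adapter_params, 4), 1))  # 33.6 (fp32 adapter weights, in MB)

# For comparison: fp16 weights of an assumed ~6.7B-parameter base model
print(round(size_mb(6.7e9, 2)))              # 13400 (MB)
```

A Trainer checkpoint also stores optimizer state, but with AdamW that is two extra tensors per trainable parameter, so it too stays proportional to the adapter rather than the base model.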

The training loop follows the standard Trainer pattern: iterate over epochs and batches, compute forward/backward passes, clip gradients, step the optimizer and scheduler, and handle logging, evaluation, and checkpointing at the intervals configured in TrainingArguments (for example logging_steps, eval_strategy, and save_strategy).
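The loop structure can be illustrated with a dependency-free toy. This is not the Trainer's implementation: a single float w stands in for the trainable adapter parameters, each "batch" is just a target value with a squared-error loss, the linear learning-rate decay is an assumed schedule, and clipping is applied to the absolute value of the scalar gradient. It shows how gradient accumulation turns several micro-batches into one optimizer step, and why global_step counts optimizer steps rather than batches.

```python
import math


def toy_training_loop(batches, num_epochs, grad_accum_steps, base_lr, max_grad_norm=1.0):
    """Toy epoch/step loop with gradient accumulation, clipping, and an LR schedule.

    `w` stands in for the trainable (adapter) parameters; each batch is a
    target value and the per-batch loss is (w - target) ** 2.
    """
    steps_per_epoch = math.ceil(len(batches) / grad_accum_steps)
    total_steps = steps_per_epoch * num_epochs
    w, grad, global_step, losses = 0.0, 0.0, 0, []

    for _ in range(num_epochs):
        for i, target in enumerate(batches):
            # Forward pass
            loss = (w - target) ** 2
            losses.append(loss)
            # Backward pass: accumulate d(loss)/dw, scaled by the window size
            grad += 2 * (w - target) / grad_accum_steps

            is_last_batch = i == len(batches) - 1
            if (i + 1) % grad_accum_steps == 0 or is_last_batch:
                # Gradient clipping (for a scalar, the norm is its absolute value)
                if abs(grad) > max_grad_norm:
                    grad = math.copysign(max_grad_norm, grad)
                # Linearly decaying learning-rate schedule (an assumption)
                lr = base_lr * (1 - global_step / total_steps)
                w -= lr * grad  # optimizer step
                grad = 0.0      # zero the accumulated gradient
                global_step += 1

    return w, global_step, sum(losses) / len(losses)


w, global_step, avg_loss = toy_training_loop([1.0] * 10, num_epochs=2, grad_accum_steps=4, base_lr=0.5)
print(global_step)  # 6: ceil(10 / 4) = 3 optimizer steps per epoch, over 2 epochs
print(round(w, 3), round(avg_loss, 4))
```

Here ten micro-batches per epoch yield only three optimizer steps, the last one closing a partial accumulation window.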

Usage

Use the Trainer with PEFT models when you want to:

  • Train LoRA or other PEFT adapters with standard Transformers training infrastructure
  • Leverage built-in support for mixed precision, gradient accumulation, distributed training, and logging
  • Use callbacks (e.g., early stopping, TensorBoard) during adapter training
  • Resume adapter training from a checkpoint

Code Reference

Source Location

  • Repository: transformers
  • File: src/transformers/trainer.py (lines 1302-1901)

Signature

def train(
    self,
    resume_from_checkpoint: str | bool | None = None,
    trial: "optuna.Trial | dict[str, Any] | None" = None,
    ignore_keys_for_eval: list[str] | None = None,
) -> TrainOutput

Import

from transformers import Trainer, TrainingArguments

I/O Contract

Inputs

  • model (PreTrainedModel; required at Trainer init): The model with injected PEFT adapters. It is passed to the Trainer constructor, not to train() directly.
  • args (TrainingArguments; required at Trainer init): Training hyperparameters such as learning rate, batch size, number of epochs, and output directory.
  • train_dataset (Dataset; required, passed at Trainer init): The training dataset. It should yield tokenized examples with input_ids, attention_mask, and labels.
  • eval_dataset (Dataset; optional, passed at Trainer init): Evaluation dataset for periodic evaluation during training.
  • resume_from_checkpoint (str, bool, or None; optional): Path to a checkpoint directory to resume from, or True to resume from the latest checkpoint in args.output_dir. Default: None.
  • trial (optuna.Trial or dict; optional): Trial object or hyperparameter dictionary used for hyperparameter search integration. Default: None.
  • ignore_keys_for_eval (list[str]; optional): Keys in the model output to ignore when gathering predictions for evaluations run during training. Default: None.
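When resume_from_checkpoint=True, the Trainer looks up the most recent checkpoint folder under args.output_dir; the library helper for this lookup is transformers.trainer_utils.get_last_checkpoint. The pure-Python sketch below re-creates that lookup under the assumption that checkpoints follow the checkpoint-<global_step> naming convention. find_last_checkpoint is a hypothetical name, not a Transformers function.

```python
import os
import re
import tempfile

# Checkpoint folders are assumed to be named "checkpoint-<global_step>".
_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def find_last_checkpoint(output_dir):
    """Return the path of the checkpoint folder with the highest step, or None."""
    candidates = [
        name
        for name in os.listdir(output_dir)
        if _CHECKPOINT_RE.match(name) and os.path.isdir(os.path.join(output_dir, name))
    ]
    if not candidates:
        return None
    # Compare steps numerically: as strings, "checkpoint-200" > "checkpoint-1000".
    latest = max(candidates, key=lambda name: int(_CHECKPOINT_RE.match(name).group(1)))
    return os.path.join(output_dir, latest)


with tempfile.TemporaryDirectory() as out:
    for step in (100, 200, 1000):
        os.makedirs(os.path.join(out, f"checkpoint-{step}"))
    os.makedirs(os.path.join(out, "runs"))  # unrelated folder, e.g. TensorBoard logs
    print(os.path.basename(find_last_checkpoint(out)))  # checkpoint-1000
```

The numeric comparison is the important design choice; sorting folder names as strings would pick the wrong checkpoint once step counts differ in digit length.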

Outputs

  • result (TrainOutput): A named tuple containing global_step (total optimization steps), training_loss (average training loss), and metrics (a dictionary with training speed, loss, and other metrics).

Usage Examples

Basic Usage: Train LoRA Adapter

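from transformers import AutoModelForCausalLM, Trainer, TrainingArguments
from peft import LoraConfig

# Load model and inject adapter
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="auto")
lora_config = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM")
# add_adapter injects LoRA layers into q_proj/v_proj; only the adapter weights stay trainable
model.add_adapter(lora_config)

# train_dataset and eval_dataset are assumed to be tokenized datasets with
# input_ids, attention_mask, and labels, prepared beforehand (not shown here)

# Define training arguments
training_args = TrainingArguments(
    output_dir="./lora-output",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # 4 x 4 = 16 examples per optimizer step per device
    learning_rate=2e-4,
    fp16=True,
    eval_strategy="epoch",  # without an eval strategy, eval_dataset is never used
    save_strategy="epoch",  # saved model files hold only the adapter weights
    logging_steps=10,
)

# Create Trainer and train
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

result = trainer.train()
print(result.global_step, result.training_loss)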
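The batch settings above combine multiplicatively: each optimizer step sees per_device_train_batch_size x gradient_accumulation_steps x (number of devices) examples. The helper below is a rough, hypothetical estimator (estimate_optimizer_steps is not a Transformers API), and the 10,000-example dataset is an assumed size. Exact step counts depend on the dataloader's drop_last setting and on how a given Transformers version rounds a trailing partial accumulation window.

```python
import math


def estimate_optimizer_steps(num_examples, per_device_batch_size, grad_accum_steps,
                             num_epochs, num_devices=1):
    """Rough estimate of total optimizer steps for a run.

    Assumes each device sees about num_examples / num_devices examples per epoch
    and that a trailing partial accumulation window still triggers a step.
    """
    batches_per_device = math.ceil(num_examples / (per_device_batch_size * num_devices))
    steps_per_epoch = math.ceil(batches_per_device / grad_accum_steps)
    return steps_per_epoch * num_epochs


# Settings from the example: per_device_train_batch_size=4,
# gradient_accumulation_steps=4, num_train_epochs=3, one device.
effective_batch = 4 * 4 * 1
print(effective_batch)                             # 16
print(estimate_optimizer_steps(10_000, 4, 4, 3))  # 1875 (625 steps per epoch)
```

With logging_steps=10, such a run would log roughly every 160 examples per device.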

QLoRA Training with Gradient Checkpointing

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, Trainer, TrainingArguments
from peft import LoraConfig

# Load quantized base model: 4-bit NF4 weights, bf16 compute
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=bnb_config,
    device_map="auto",
)

# Inject LoRA adapter. target_modules=None lets PEFT use its default target
# modules for the detected architecture (q_proj and v_proj for Llama).
lora_config = LoraConfig(r=64, lora_alpha=16, target_modules=None, task_type="CAUSAL_LM")
model.add_adapter(lora_config)

# train_dataset is assumed to be a tokenized dataset prepared beforehand (not shown)
training_args = TrainingArguments(
    output_dir="./qlora-output",
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    gradient_checkpointing=True,  # recompute activations during backward to save memory
    learning_rate=2e-4,
    bf16=True,  # matches bnb_4bit_compute_dtype; requires bf16-capable hardware
    save_strategy="steps",
    save_steps=100,
)

trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
result = trainer.train()
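The memory motivation for loading the base model in 4-bit can be checked with back-of-the-envelope arithmetic. The sketch assumes a Llama-2-7B-class model of about 6.7B parameters; weight_memory_gb is a hypothetical helper, and the estimate covers the frozen weights only. Activation memory, which gradient_checkpointing reduces, is not modeled here.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Memory for the weights alone, in GB (10**9 bytes).

    Ignores activations, optimizer state, LoRA weights, and the small
    per-block quantization constants that 4-bit formats also store.
    """
    return num_params * bits_per_param / 8 / 1e9


base_params = 6.7e9  # assumed parameter count of a Llama-2-7B-class model
print(round(weight_memory_gb(base_params, 16), 2))  # 13.4 (bf16/fp16 weights)
print(round(weight_memory_gb(base_params, 4), 2))   # 3.35 (NF4 weights)
```

Because only the LoRA parameters are trained, optimizer state is sized by the adapter, so the quantized base weights dominate the static memory budget.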
