Implementation: unslothai/unsloth UnslothTrainer
| Field | Value |
|---|---|
| Domains | NLP, Training, Optimization |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
Wrapper around TRL's SFTTrainer with Unsloth-specific optimizations for memory-efficient supervised fine-tuning.
Description
UnslothTrainer subclasses TRL's SFTTrainer and adds:
- Separate embedding learning rate: Via UnslothTrainingArguments.embedding_learning_rate, allowing a different (usually lower) learning rate for embed_tokens and lm_head modules.
- Auto-packing: Automatic detection and application of padding-free sequence packing for efficient batching.
- TRL backward compatibility: Extensive patching (unsloth/trainer.py:L203-408) to maintain compatibility across TRL versions, handling API changes in SFTConfig, dataset_text_field, and data collator selection.
- VLM blocklist: Automatically disables packing for vision-language models where it would break image token alignment.
When Unsloth is imported, it also patches the global SFTTrainer class so that users can use either UnslothTrainer or SFTTrainer interchangeably with Unsloth optimizations applied.
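Conceptually, the separate embedding learning rate works by splitting trainable parameters into optimizer param groups. Below is a minimal sketch of that idea; the helper name split_param_groups is illustrative, not Unsloth's actual _create_unsloth_optimizer, which also handles weight decay and optimizer selection:

import torch

def split_param_groups(model, lr, embedding_lr):
    # Hypothetical helper: route embed_tokens/lm_head parameters into
    # their own param group with a (usually lower) learning rate.
    embed_params, other_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if "embed_tokens" in name or "lm_head" in name:
            embed_params.append(param)
        else:
            other_params.append(param)
    return [
        {"params": other_params, "lr": lr},
        {"params": embed_params, "lr": embedding_lr},
    ]

optimizer = torch.optim.AdamW(split_param_groups(model, lr=2e-4, embedding_lr=5e-5))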
Usage
Import and use as a drop-in replacement for TRL's SFTTrainer. The only additional parameter is embedding_learning_rate on UnslothTrainingArguments. All standard SFTTrainer/SFTConfig parameters are supported.
Code Reference
Source Location
- Repository: unsloth
- File: unsloth/trainer.py
- Lines: L133-137 (UnslothTrainingArguments), L182-200 (UnslothTrainer), L203-408 (TRL compatibility patching)
Signature
class UnslothTrainingArguments(TrainingArguments):
    def __init__(self, embedding_learning_rate: float = None, *args, **kwargs):
        """
        Args:
            embedding_learning_rate (float): Separate learning rate for
                embed_tokens and lm_head modules. Default None (uses main LR).
        """

class UnslothTrainer(SFTTrainer):
    def create_optimizer(self):
        """
        Overrides SFTTrainer optimizer creation to apply the separate
        embedding learning rate via _create_unsloth_optimizer.
        """
Import
from unsloth import UnslothTrainer, UnslothTrainingArguments
# Or equivalently after import unsloth:
from trl import SFTTrainer, SFTConfig # Patched by Unsloth
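Because Unsloth patches TRL's SFTTrainer at import time, the two imports are interchangeable. A sketch of the patched route, assuming model, dataset, and args are prepared as in the Usage Examples below:

import unsloth  # must be imported before trl so the patch is applied
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    args=args,  # a standard SFTConfig also works here
)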
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | PeftModel | Yes | LoRA model from get_peft_model |
| tokenizer | PreTrainedTokenizer | Yes | Configured tokenizer |
| train_dataset | Dataset | Yes | Training dataset (formatted with chat template) |
| args | UnslothTrainingArguments | Yes | Training hyperparameters |
| args.per_device_train_batch_size | int | No | Batch size per GPU |
| args.gradient_accumulation_steps | int | No | Gradient accumulation steps |
| args.learning_rate | float | No | Learning rate (default: 5e-5) |
| args.num_train_epochs | int | No | Number of training epochs |
| args.embedding_learning_rate | float | No | Separate LR for embeddings (Unsloth-specific) |
| data_collator | DataCollator | No | Data collator (auto-configured for VLMs) |
Outputs
| Name | Type | Description |
|---|---|---|
| trainer | UnslothTrainer | Configured trainer ready for .train() |
| trainer.train() returns | TrainOutput | Training metrics: global_step, training_loss, etc. |
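For reference, trainer.train() returns a transformers TrainOutput named tuple; a minimal illustration of consuming it:

stats = trainer.train()
print(stats.global_step, stats.training_loss)  # scalar summary fields
print(stats.metrics)  # dict with train_runtime, train_samples_per_second, etc.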
Usage Examples
Basic SFT Training
import torch
from datasets import load_dataset
from unsloth import FastLanguageModel, UnslothTrainer, UnslothTrainingArguments

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Llama-3.2-3B-Instruct",
    max_seq_length=2048,
    load_in_4bit=True,
)
model = FastLanguageModel.get_peft_model(model, r=16)

# Example dataset only: replace with your own Dataset, formatted
# with the model's chat template.
dataset = load_dataset("yahma/alpaca-cleaned", split="train")
args = UnslothTrainingArguments(
    output_dir="./outputs",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    num_train_epochs=1,
    learning_rate=2e-4,
    fp16=not torch.cuda.is_bf16_supported(),
    bf16=torch.cuda.is_bf16_supported(),
    logging_steps=10,
    save_strategy="epoch",
)
trainer = UnslothTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    args=args,
)
trainer.train()
With Separate Embedding Learning Rate
# When training embed_tokens/lm_head via modules_to_save
model = FastLanguageModel.get_peft_model(
    model, r=16, modules_to_save=["embed_tokens", "lm_head"],
)
args = UnslothTrainingArguments(
    output_dir="./outputs",
    learning_rate=2e-4,
    embedding_learning_rate=5e-5,  # Lower LR for embeddings
    num_train_epochs=3,
)