Implementation:Huggingface Trl DPOTrainer Init Train
| Knowledge Sources | |
|---|---|
| Domains | NLP, RLHF |
| Last Updated | 2026-02-06 17:00 GMT |
Overview
Concrete tool for initializing the DPO trainer and executing the preference optimization training loop, provided by the TRL library.
Description
DPOTrainer is the main trainer class for Direct Preference Optimization, inheriting from BaseTrainer (which extends Hugging Face's Trainer). Its __init__ method orchestrates:
- Argument resolution: If no DPOConfig is provided, creates a default one.
- Model preparation: Handles string-based model loading via create_model_from_path, validates that model and ref_model are not the same object, and resolves the processing class (tokenizer or processor).
- PEFT wrapping: If a peft_config is provided, wraps the model with get_peft_model, handling quantized-model preparation and bf16 casting.
- Reference model resolution: Uses the explicit ref_model, sets it to None for PEFT/precompute scenarios, or creates a frozen deep copy via create_reference_model (see the sketch after this list).
- Dropout disabling: Disables dropout in both the policy and reference models for stable DPO training.
- Data collator setup: Defaults to DataCollatorForPreference if none is provided.
- Dataset preparation: Applies prompt extraction, chat templating, and tokenization to the train and eval datasets.
- Distributed training preparation: Prepares the reference model for DeepSpeed, FSDP, or standard Accelerate.
- TR-DPO callback: Registers SyncRefModelCallback if reference model synchronization is enabled.
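The reference-model branch is the most consequential of these steps. Below is a simplified paraphrase of that control flow, not the verbatim source (the helper name resolve_ref_model and its argument list are invented for illustration; create_reference_model is the real TRL utility):
from trl import create_reference_model

def resolve_ref_model(ref_model, model, is_peft_model, precompute_ref_log_probs):
    # Explicit reference model passed by the user: use it as-is
    if ref_model is not None:
        return ref_model
    # PEFT or precomputed reference log probs: keep no separate reference model;
    # with PEFT, the reference forward pass runs with adapters disabled instead
    if is_peft_model or precompute_ref_log_probs:
        return None
    # Otherwise: create a frozen deep copy of the policy model
    return create_reference_model(model)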
The compute_loss method is the training loop entry point called by the parent Trainer:
- Calls get_batch_loss_metrics, which runs the concatenated forward pass, computes reference log probs, and calculates the DPO loss for all configured loss types (a minimal sketch of the default "sigmoid" loss follows this list)
- Returns the scalar loss (and optionally metrics) for gradient computation
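For intuition, here is a minimal sketch of the default "sigmoid" objective that dpo_loss implements, written as a standalone function (the name sigmoid_dpo_loss is illustrative; the real method handles many additional loss types and edge cases):
import torch.nn.functional as F

def sigmoid_dpo_loss(chosen_logps, rejected_logps, ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # Implicit per-example rewards: beta-scaled log-prob ratios vs. the reference model
    chosen_rewards = beta * (chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (rejected_logps - ref_rejected_logps)
    # The DPO logits are the reward margin; the sigmoid loss is -log sigmoid(margin)
    logits = chosen_rewards - rejected_rewards
    losses = -F.logsigmoid(logits)
    return losses.mean(), chosen_rewards, rejected_rewards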
The train method is inherited from the Transformers Trainer and handles the full training loop with gradient accumulation, mixed precision, distributed training, checkpointing, and logging.
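Because train is the stock Transformers entry point, standard features such as checkpoint resumption work unchanged, e.g.:
trainer.train(resume_from_checkpoint=True)  # resumes from the latest checkpoint in output_dir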
Usage
Use DPOTrainer when:
- Training a model with DPO from a preference dataset
- You need the full Hugging Face Trainer ecosystem (logging, checkpointing, distributed training)
- Combining DPO with PEFT/LoRA for parameter-efficient alignment
- Running multi-loss (MPO) training with combined objectives
Code Reference
Source Location
- Repository: TRL
- File: trl/trainer/dpo_trainer.py
- __init__: lines 271-569
- compute_loss: lines 1830-1851
- get_batch_loss_metrics: lines 1742-1828
- dpo_loss: lines 1039-1260
- concatenated_forward: lines 1494-1740
Signature
class DPOTrainer(BaseTrainer):
_tag_names = ["trl", "dpo"]
_name = "DPO"
def __init__(
self,
model: str | nn.Module | PreTrainedModel,
ref_model: PreTrainedModel | nn.Module | str | None = None,
args: DPOConfig | None = None,
data_collator: DataCollator | None = None,
train_dataset: Dataset | IterableDataset | None = None,
eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None,
processing_class: PreTrainedTokenizerBase | BaseImageProcessor | FeatureExtractionMixin | ProcessorMixin | None = None,
compute_metrics: Callable[[EvalLoopOutput], dict] | None = None,
callbacks: list[TrainerCallback] | None = None,
optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None),
optimizer_cls_and_kwargs: tuple[type[torch.optim.Optimizer], dict[str, Any]] | None = None,
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
peft_config: PeftConfig | None = None,
):
def compute_loss(
self,
model: PreTrainedModel | nn.Module,
inputs: dict[str, torch.Tensor | Any],
return_outputs=False,
num_items_in_batch=None,
) -> torch.Tensor | tuple[torch.Tensor, dict[str, float]]:
def dpo_loss(
self,
chosen_logps: torch.FloatTensor,
rejected_logps: torch.FloatTensor,
ref_chosen_logps: torch.FloatTensor,
ref_rejected_logps: torch.FloatTensor,
loss_type: str = "sigmoid",
model_output: dict[str, torch.FloatTensor] = None,
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
def train(self) -> TrainOutput: # inherited from Trainer
Import
from trl import DPOTrainer, DPOConfig
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | str or PreTrainedModel | Yes | Policy model: a model ID string or a pretrained model instance |
| ref_model | PreTrainedModel or None | No | Reference model; None when using PEFT or precomputed log probs |
| args | DPOConfig | No | Training configuration with DPO-specific hyperparameters |
| train_dataset | Dataset or IterableDataset | Yes | Preference dataset with prompt, chosen, rejected columns |
| eval_dataset | Dataset or dict or None | No | Evaluation dataset(s) with the same format as train_dataset |
| processing_class | PreTrainedTokenizerBase or ProcessorMixin or None | No | Tokenizer or processor; auto-loaded from model if None |
| peft_config | PeftConfig or None | No | PEFT configuration for parameter-efficient training |
| callbacks | list[TrainerCallback] or None | No | Custom training callbacks |
| optimizers | tuple | No (default: (None, None)) | Custom optimizer and scheduler; defaults to AdamW with linear warmup |
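For illustration, one record of the standard (non-conversational) preference format expected by train_dataset could look like the following (the strings are invented):
example_record = {
    "prompt": "What color is the sky?",
    "chosen": " It is blue.",       # preferred completion
    "rejected": " It is green.",    # dispreferred completion
}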
Outputs
| Name | Type | Description |
|---|---|---|
| DPOTrainer instance | DPOTrainer | Initialized trainer ready for .train() and .evaluate() |
| TrainOutput | TrainOutput | Training results including global_step, training_loss, and metrics (returned by .train()) |
| Training metrics | dict[str, float] | Per-step metrics: loss, rewards/chosen, rewards/rejected, rewards/margins, rewards/accuracies, logps/chosen, logps/rejected, logits/chosen, logits/rejected |
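As an illustrative sketch (assuming a trainer constructed as in the examples below), the TrainOutput can be inspected directly; note that the per-step reward metrics listed above are emitted through the Trainer's logging callbacks rather than returned:
result = trainer.train()
print(result.global_step, result.training_loss)
print(result.metrics)  # aggregate run statistics; rewards/* appear in the training logs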
Usage Examples
# Example 1: Basic DPO training (full fine-tuning)
from datasets import load_dataset
from transformers import AutoModelForCausalLM
from trl import DPOConfig, DPOTrainer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
ref_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized")
training_args = DPOConfig(
output_dir="Qwen2-0.5B-DPO",
beta=0.1,
loss_type=["sigmoid"],
learning_rate=5e-7,
num_train_epochs=1,
per_device_train_batch_size=2,
gradient_accumulation_steps=8,
)
trainer = DPOTrainer(
model=model,
ref_model=ref_model,
args=training_args,
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
)
trainer.train()
# Example 2: DPO with LoRA (parameter-efficient)
# Reuses model and dataset from Example 1; no separate ref_model is needed.
from peft import LoraConfig
from trl import DPOConfig, DPOTrainer
peft_config = LoraConfig(
r=32,
lora_alpha=16,
target_modules="all-linear",
task_type="CAUSAL_LM",
)
training_args = DPOConfig(
output_dir="Qwen2-0.5B-DPO-LoRA",
beta=0.1,
learning_rate=5e-6,
num_train_epochs=1,
)
trainer = DPOTrainer(
model=model,
ref_model=None, # implicit reference via adapter disabling
args=training_args,
train_dataset=dataset["train"],
peft_config=peft_config,
)
trainer.train()
# Example 3: Multi-loss (MPO-style) training
# Reuses model, ref_model, and dataset from Example 1.
training_args = DPOConfig(
output_dir="Qwen2-0.5B-MPO",
beta=0.1,
loss_type=["sigmoid", "sft"],
loss_weights=[0.8, 1.0],
learning_rate=5e-7,
)
trainer = DPOTrainer(
model=model,
ref_model=ref_model,
args=training_args,
train_dataset=dataset["train"],
)
trainer.train()
Related Pages
Implements Principle
- Principle:Huggingface_Trl_DPO_Training
- Environment:Huggingface_Trl_Python_Core_Dependencies
- Environment:Huggingface_Trl_PEFT_LoRA_Environment
- Environment:Huggingface_Trl_DeepSpeed_Environment
- Heuristic:Huggingface_Trl_Disable_Dropout_For_RL_Training
- Heuristic:Huggingface_Trl_Distributed_Device_Map_Override