Implementation:Hiyouga LLaMA Factory DPO Trainer
| Knowledge Sources | |
|---|---|
| Domains | Machine Learning, RLHF, Preference Optimization |
| Last Updated | 2026-02-06 19:00 GMT |
Overview
Custom DPO trainer supporting multiple preference loss types including DPO, ORPO, SimPO, and BCO for LLaMA-Factory.
Description
CustomDPOTrainer extends TRL's DPOTrainer to provide a unified preference optimization trainer in LLaMA-Factory. It calls Trainer.__init__ directly instead of DPOTrainer.__init__, which bypasses TRL's default setup, and it implements the per-batch loss computation for Direct Preference Optimization and its variants. For each batch the trainer computes chosen and rejected log probabilities from the concatenated batch via concatenated_forward, obtains reference log probabilities via compute_reference_log_probs (using either a separate reference model or the policy model with its LoRA adapters disabled), and dispatches to the appropriate loss function (DPO sigmoid/hinge/IPO, ORPO, SimPO, or BCO) via compute_preference_loss. It supports DeepSpeed, FSDP, and FSDP2 for the reference model, custom optimizers (GaLore, BAdam, LoRA+), and the LD-DPO verbose-token weighting extension.
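The reference-based branch of the loss dispatch described above can be sketched on scalar log probabilities. This is a minimal illustration using the commonly published sigmoid, hinge, and IPO formulas, not the trainer's own code: the function names, the `loss_type` strings, and the default `beta` are assumptions, and label smoothing and tensor batching are omitted.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def dpo_family_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected,
                    beta=0.1, loss_type="sigmoid"):
    """Per-pair loss for reference-based variants (hypothetical helper)."""
    # How much more the policy prefers chosen over rejected than the reference does.
    logits = (policy_chosen - policy_rejected) - (ref_chosen - ref_rejected)
    if loss_type == "sigmoid":  # standard DPO
        return -log_sigmoid(beta * logits)
    if loss_type == "hinge":
        return max(0.0, 1.0 - beta * logits)
    if loss_type == "ipo":
        return (logits - 1.0 / (2.0 * beta)) ** 2
    raise ValueError(f"unknown loss_type: {loss_type}")


# A policy that agrees with the reference yields logits of zero.
neutral = dpo_family_loss(-1.0, -2.0, -1.0, -2.0)
# A policy that widens the chosen/rejected gap is penalized less.
improved = dpo_family_loss(-1.0, -3.0, -1.0, -2.0)
```

In all three variants the reference log probabilities enter only through the difference of log-ratios, which is why a frozen copy of the policy (or the policy with adapters disabled) is sufficient as a reference.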
Usage
Instantiated by the DPO training workflow when stage="dpo" is set in FinetuningArguments. The trainer handles the full training loop including loss computation, metric logging, and checkpoint saving.
Code Reference
Source Location
- Repository: Hiyouga_LLaMA_Factory
- File: src/llamafactory/train/dpo/trainer.py
- Lines: 1-349
Signature
class CustomDPOTrainer(DPOTrainer):
    def __init__(
        self,
        model: Union["PreTrainedModel", torch.nn.Module],
        ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]],
        finetuning_args: "FinetuningArguments",
        processor: Optional["ProcessorMixin"],
        disable_dropout: bool = True,
        **kwargs,
    ): ...
    def create_optimizer(self) -> "torch.optim.Optimizer": ...
    def create_scheduler(self, num_training_steps, optimizer=None) -> "torch.optim.lr_scheduler.LRScheduler": ...
    def odds_ratio_loss(self, chosen_logps, rejected_logps) -> "torch.Tensor": ...
    def simpo_loss(self, chosen_logps, rejected_logps) -> "torch.Tensor": ...
    def bco_loss(self, chosen_logps, rejected_logps, ref_chosen_logps, ref_rejected_logps) -> "torch.Tensor": ...
    def compute_preference_loss(
        self, policy_chosen_logps, policy_rejected_logps,
        reference_chosen_logps, reference_rejected_logps,
    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]: ...
    def concatenated_forward(
        self, model, batch, is_ref_model=False,
    ) -> dict[str, "torch.Tensor"]: ...
    def compute_reference_log_probs(
        self, model, batch,
    ) -> tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]: ...
    def get_batch_loss_metrics(
        self, model, batch, train_eval="train",
    ) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]: ...
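The two reference-free methods in the signature, odds_ratio_loss and simpo_loss, take only policy log probabilities. The sketch below shows scalar stand-ins based on the formulas published for ORPO and SimPO. It assumes length-averaged log probabilities that are strictly negative; the default `beta` and `gamma` values are illustrative, and the real methods operate on tensors and read their coefficients from the trainer.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def simpo_loss(chosen_logp: float, rejected_logp: float,
               beta: float = 2.0, gamma: float = 1.0) -> float:
    """SimPO: sigmoid loss on the log-prob gap, shifted by a target margin."""
    logits = (chosen_logp - rejected_logp) - gamma / beta
    return -log_sigmoid(beta * logits)


def odds_ratio_loss(chosen_logp: float, rejected_logp: float,
                    beta: float = 0.1) -> float:
    """ORPO: SFT loss on the chosen response plus a weighted odds-ratio term."""
    # log odds(p) = log p - log(1 - p), evaluated for chosen minus rejected.
    log_odds = (chosen_logp - rejected_logp) - (
        math.log1p(-math.exp(chosen_logp)) - math.log1p(-math.exp(rejected_logp))
    )
    sft_loss = -chosen_logp
    return sft_loss + beta * -log_sigmoid(log_odds)


good_pair = simpo_loss(-0.5, -1.5)  # chosen is more likely than rejected
bad_pair = simpo_loss(-1.5, -0.5)   # rejected is more likely than chosen
```

Because neither loss uses reference log probabilities, ref_model can be None for these variants, which matches the I/O contract.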
Import
from llamafactory.train.dpo.trainer import CustomDPOTrainer
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | PreTrainedModel | Yes | The policy model to optimize |
| ref_model | PreTrainedModel or None | No | Reference model for KL-constrained losses (None for ORPO/SimPO, or when reference log probabilities come from disabling the LoRA adapters) |
| finetuning_args | FinetuningArguments | Yes | Contains loss type, beta, gamma, ftx coefficient, etc. |
| processor | ProcessorMixin or None | No | Processor to save in checkpoints (for multimodal models) |
| batch | dict[str, Tensor] | Yes (forward) | Contains concatenated chosen+rejected input_ids, attention_mask, labels |
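How a concatenated batch turns into separate chosen and rejected log probabilities can be illustrated with plain lists. This is a sketch under two assumptions not stated in the table: masked label positions use the value -100 (the usual Hugging Face convention), and the chosen rows occupy the first half of the batch with the rejected rows in the second half. The helper names are hypothetical.

```python
IGNORE_INDEX = -100  # assumed label value for masked (prompt/padding) tokens


def sequence_logps(token_logps, labels):
    """Sum per-token log-probs over positions whose label is not masked."""
    return [
        sum(lp for lp, lab in zip(row_lp, row_lab) if lab != IGNORE_INDEX)
        for row_lp, row_lab in zip(token_logps, labels)
    ]


def split_chosen_rejected(values):
    """Split a concatenated batch: first half chosen, second half rejected."""
    if len(values) % 2 != 0:
        raise ValueError("concatenated batch must have an even number of rows")
    half = len(values) // 2
    return values[:half], values[half:]


# One preference pair: row 0 is the chosen response, row 1 the rejected one.
token_logps = [[-0.1, -0.2, -0.3], [-0.5, -0.5, -0.5]]
labels = [[IGNORE_INDEX, 7, 8], [IGNORE_INDEX, 9, 4]]

chosen, rejected = split_chosen_rejected(sequence_logps(token_logps, labels))
```

Running both halves through the model in one forward pass keeps the chosen and rejected responses under identical model state, at the cost of doubling the effective batch size.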
Outputs
| Name | Type | Description |
|---|---|---|
| loss | torch.Tensor | Scalar training loss averaged over the batch |
| metrics | dict[str, float] | Includes rewards/chosen, rewards/rejected, rewards/accuracies, rewards/margins, logps/chosen, logps/rejected, logits/chosen, logits/rejected |
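The reward metrics in the table can be sketched from per-example log probabilities. The example assumes the DPO implicit reward, beta times the policy-minus-reference log probability; reference-free variants may define rewards differently. The function name and the default `beta` are illustrative only.

```python
def reward_metrics(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """Batch-averaged reward metrics from per-example log-probabilities."""
    # Implicit reward: scaled log-ratio between policy and reference.
    chosen_rewards = [beta * (p - r) for p, r in zip(policy_chosen, ref_chosen)]
    rejected_rewards = [beta * (p - r) for p, r in zip(policy_rejected, ref_rejected)]
    pairs = list(zip(chosen_rewards, rejected_rewards))
    n = len(pairs)
    return {
        "rewards/chosen": sum(chosen_rewards) / n,
        "rewards/rejected": sum(rejected_rewards) / n,
        # Fraction of pairs where the chosen response earns the higher reward.
        "rewards/accuracies": sum(1.0 for c, r in pairs if c > r) / n,
        "rewards/margins": sum(c - r for c, r in pairs) / n,
    }


metrics = reward_metrics(
    policy_chosen=[-1.0, -2.0],
    policy_rejected=[-2.0, -1.0],
    ref_chosen=[-1.5, -1.5],
    ref_rejected=[-1.5, -1.5],
)
```

In this toy batch the policy ranks one pair correctly and one incorrectly, so the accuracy is 0.5 and the mean margin cancels to zero.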
Usage Examples
from llamafactory.train.dpo.trainer import CustomDPOTrainer

# Instantiated by the DPO workflow (simplified example)
trainer = CustomDPOTrainer(
    model=policy_model,
    ref_model=reference_model,
    finetuning_args=finetuning_args,
    processor=processor,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)

# Train the model
trainer.train()

# The trainer logs these metrics during training:
# - rewards/chosen: average reward for chosen responses
# - rewards/rejected: average reward for rejected responses
# - rewards/accuracies: fraction where chosen_reward > rejected_reward
# - rewards/margins: average reward margin (chosen - rejected)