Implementation:Hiyouga LLaMA Factory KTO Trainer
| Knowledge Sources | |
|---|---|
| Domains | Machine Learning, RLHF, Preference Optimization |
| Last Updated | 2026-02-06 19:00 GMT |
Overview
Custom KTO trainer implementing Kahneman-Tversky preference optimization for unpaired preference data in LLaMA-Factory.
Description
CustomKTOTrainer extends TRL's KTOTrainer to implement the KTO alignment algorithm. KTO learns from unpaired preference data in which each sample is independently labeled as desirable or undesirable, unlike DPO, which requires paired chosen/rejected examples. The constructor calls Trainer.__init__ directly instead of going through KTOTrainer's initialization, and the class provides the following methods:
- forward: handles multimodal inputs and computes sequence log probabilities.
- concatenated_forward: splits the batch into chosen and rejected samples using the boolean kto_tags mask.
- compute_reference_log_probs: computes log probabilities under the reference model.
- get_batch_loss_metrics: computes the KTO loss, plus an optional SFT auxiliary loss controlled by pref_ftx.
- log: sums metrics across distributed processes and converts them to per-sample averages. This accounts for the asymmetric nature of KTO, where the number of chosen and rejected samples varies from batch to batch.
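The boolean-mask split that concatenated_forward performs can be illustrated without a deep-learning framework. The minimal sketch below uses plain Python lists in place of tensors (with PyTorch tensors the same split would be written logps[kto_tags] and logps[~kto_tags]); the helper name split_by_kto_tags is hypothetical and not part of LLaMA-Factory.

```python
from typing import List, Tuple


def split_by_kto_tags(
    logps: List[float], kto_tags: List[bool]
) -> Tuple[List[float], List[float]]:
    """Split per-sample log-probs into desirable (True) and undesirable (False) groups.

    Hypothetical helper: a list-based stand-in for boolean-mask tensor indexing.
    """
    if len(logps) != len(kto_tags):
        raise ValueError("logps and kto_tags must have the same length")
    chosen = [lp for lp, tag in zip(logps, kto_tags) if tag]
    rejected = [lp for lp, tag in zip(logps, kto_tags) if not tag]
    return chosen, rejected


# A batch of four unpaired samples: two desirable, two undesirable.
chosen, rejected = split_by_kto_tags([-1.5, -2.0, -0.5, -3.0], [True, False, True, False])
print(chosen, rejected)  # [-1.5, -0.5] [-2.0, -3.0]
```

Either group may be empty in a given batch; this variable composition is what the log method has to account for when averaging metrics.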
Usage
Instantiated by the KTO training workflow when stage="kto" is set in FinetuningArguments. Requires a dataset with a boolean kto_tags field marking each sample as desirable (True) or undesirable (False), plus kl_-prefixed fields used to estimate the KL term.
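The kl_-prefixed fields pair each prompt with a response that was not written for it, so that the policy/reference log-ratio on these mismatched pairs can serve as the KL reference point in the KTO loss. The sketch below assumes the mismatched pairs are formed by rotating responses within a batch; the kl_prompt and kl_response names are hypothetical text-level stand-ins for the tokenized kl_-prefixed fields the trainer actually consumes, and this is not LLaMA-Factory's preprocessing code.

```python
from typing import Dict, List, Union

Example = Dict[str, Union[str, bool]]


def build_kl_fields(examples: List[Example]) -> List[Example]:
    """Attach hypothetical kl_prompt/kl_response fields by rotating responses by one."""
    n = len(examples)
    if n < 2:
        raise ValueError("need at least two examples to form mismatched pairs")
    rows = []
    for i, ex in enumerate(examples):
        # Pair this prompt with the previous example's response (wrapping around).
        mismatched = examples[(i - 1) % n]["response"]
        rows.append({**ex, "kl_prompt": ex["prompt"], "kl_response": mismatched})
    return rows


data = [
    {"prompt": "2+2?", "response": "4", "kto_tags": True},
    {"prompt": "Capital of France?", "response": "Berlin", "kto_tags": False},
    {"prompt": "Say hi", "response": "Hi!", "kto_tags": True},
]
for row in build_kl_fields(data):
    print(row["prompt"], "->", row["kl_response"])
```

The kto_tags label travels with each original sample; the kl_ pair carries no label of its own because it is only used to estimate the reference point.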
Code Reference
Source Location
- Repository: Hiyouga_LLaMA_Factory
- File: src/llamafactory/train/kto/trainer.py
- Lines: 1-312
Signature
```python
class CustomKTOTrainer(KTOTrainer):
    def __init__(
        self,
        model: Union["PreTrainedModel", torch.nn.Module],
        ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]],
        finetuning_args: "FinetuningArguments",
        processor: Optional["ProcessorMixin"],
        disable_dropout: bool = True,
        **kwargs,
    ): ...

    def create_optimizer(self) -> "torch.optim.Optimizer": ...

    def create_scheduler(self, num_training_steps, optimizer=None) -> "torch.optim.lr_scheduler.LRScheduler": ...

    def forward(
        self, model, batch, prefix: Literal["", "kl_"] = "",
    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]: ...

    def concatenated_forward(
        self, model, batch,
    ) -> tuple[
        "torch.Tensor", "torch.Tensor", "torch.Tensor",
        "torch.Tensor", "torch.Tensor", "torch.Tensor",
    ]: ...

    def compute_reference_log_probs(
        self, model, batch,
    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]: ...

    def get_batch_loss_metrics(
        self, model, batch,
    ) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]: ...

    def log(self, logs: dict[str, float], *args, **kwargs) -> None: ...
```
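forward reduces the model's logits to per-sequence log probabilities, on the regular fields when prefix is "" and presumably on the kl_-prefixed copies when prefix is "kl_". The framework-free sketch below shows the usual causal-LM reduction under two assumptions: labels use -100 (named IGNORE_INDEX here) for prompt and padding positions, and the logits at position t score the token at position t + 1. It is an illustration, not the trainer's implementation.

```python
import math
from typing import List, Tuple

IGNORE_INDEX = -100  # assumed label value for prompt and padding tokens


def log_softmax(row: List[float]) -> List[float]:
    # Numerically stable log-softmax over one vocabulary row.
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def sequence_logps(logits: List[List[float]], labels: List[int]) -> Tuple[float, int]:
    """Sum token log-probs of one sequence, skipping IGNORE_INDEX labels.

    logits[t] scores the token at position t + 1, so labels are shifted left by one.
    Returns (summed log-prob, number of scored tokens).
    """
    total, count = 0.0, 0
    for t in range(len(labels) - 1):
        target = labels[t + 1]
        if target == IGNORE_INDEX:
            continue
        total += log_softmax(logits[t])[target]
        count += 1
    return total, count


# Uniform logits over a 2-token vocabulary: each scored token contributes log(0.5).
logits = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
labels = [IGNORE_INDEX, 1, 0]
total, count = sequence_logps(logits, labels)
print(round(total, 4), count, round(total / count, 4))  # -1.3863 2 -0.6931
```

The returned count makes it possible to derive a per-token average (total / count) alongside the summed log probability.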
Import
```python
from llamafactory.train.kto.trainer import CustomKTOTrainer
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | PreTrainedModel | Yes | The policy model to optimize |
| ref_model | PreTrainedModel or None | Yes (may be None) | Reference model; if None, reference log probabilities are computed with the policy model's LoRA adapter disabled |
| finetuning_args | FinetuningArguments | Yes | Contains pref_beta, kto_chosen_weight, kto_rejected_weight, pref_ftx |
| processor | ProcessorMixin or None | Yes (may be None) | Processor to save with checkpoints |
| batch | dict[str, Tensor] | Yes (forward) | Must contain input_ids, attention_mask, labels, kto_tags, and kl_-prefixed fields |
| batch["kto_tags"] | torch.BoolTensor | Yes | Boolean mask: True for desirable, False for undesirable samples |
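The finetuning_args fields listed above parameterize the loss. The minimal sketch below assumes the standard KTO formulation: a KL reference point estimated from the kl_ pairs and clamped at zero, a sigmoid value function scaled by pref_beta, and per-sample weights kto_chosen_weight and kto_rejected_weight, averaged over the batch. The SFT term in add_sft_term (pref_ftx times the negative mean per-token log probability of desirable samples) is an assumed form; the trainer's exact scaling may differ. Both function names are hypothetical.

```python
import math
from typing import List, Tuple


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def kto_loss(
    policy_chosen: List[float], ref_chosen: List[float],
    policy_rejected: List[float], ref_rejected: List[float],
    policy_kl: List[float], ref_kl: List[float],
    beta: float = 0.1,             # plays the role of pref_beta
    chosen_weight: float = 1.0,    # plays the role of kto_chosen_weight
    rejected_weight: float = 1.0,  # plays the role of kto_rejected_weight
) -> Tuple[float, List[float], List[float], float]:
    if not policy_kl:
        raise ValueError("need at least one kl_ pair to estimate the reference point")
    # KL reference point: mean policy/reference log-ratio on mismatched pairs, clamped at 0.
    kl = max(sum(p - r for p, r in zip(policy_kl, ref_kl)) / len(policy_kl), 0.0)
    losses: List[float] = []
    chosen_rewards: List[float] = []
    rejected_rewards: List[float] = []
    for p, r in zip(policy_chosen, ref_chosen):
        ratio = p - r  # log pi(y|x) - log pi_ref(y|x)
        losses.append(chosen_weight * (1.0 - sigmoid(beta * (ratio - kl))))
        chosen_rewards.append(beta * ratio)
    for p, r in zip(policy_rejected, ref_rejected):
        ratio = p - r
        losses.append(rejected_weight * (1.0 - sigmoid(beta * (kl - ratio))))
        rejected_rewards.append(beta * ratio)
    if not losses:
        raise ValueError("batch contains no samples")
    return sum(losses) / len(losses), chosen_rewards, rejected_rewards, kl


def add_sft_term(loss: float, chosen_avg_logps: List[float], pref_ftx: float) -> float:
    # Assumed form: pref_ftx * negative mean per-token log-prob of desirable samples.
    if pref_ftx <= 0 or not chosen_avg_logps:
        return loss
    return loss + pref_ftx * (-sum(chosen_avg_logps) / len(chosen_avg_logps))


loss, chosen_rewards, rejected_rewards, kl = kto_loss(
    policy_chosen=[-1.0], ref_chosen=[-1.0],
    policy_rejected=[-2.0], ref_rejected=[-2.0],
    policy_kl=[-3.0], ref_kl=[-3.0],
    chosen_weight=2.0,
)
print(loss, kl)  # 0.75 0.0
```

Desirable samples are penalized when their log-ratio falls below the KL reference point and undesirable samples when theirs rises above it, which is why no paired counterpart is needed for either kind.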
Outputs
| Name | Type | Description |
|---|---|---|
| loss | torch.Tensor | Scalar KTO training loss (nanmean over batch) |
| metrics | dict[str, float] | Sum-based metrics: rewards/chosen_sum, rewards/rejected_sum, logps/chosen_sum, logps/rejected_sum, count/chosen, count/rejected, kl |
| logged metrics | dict[str, float] | Averaged metrics after distributed reduction: rewards/chosen, rewards/rejected, rewards/margins |
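The log method's reduction from sum-based metrics to logged averages can be simulated in a single process. The sketch below assumes each process reports the sum and count keys from the metrics table; it sums them across processes (standing in for the distributed all-reduce), divides by the total counts, and derives rewards/margins. Other keys such as kl are left out, and the function name is hypothetical.

```python
from typing import Dict, List


def aggregate_kto_metrics(per_process: List[Dict[str, float]]) -> Dict[str, float]:
    """Turn per-process sum/count metrics into global per-sample averages."""
    # Stand-in for an all-reduce with a sum operation across processes.
    summed: Dict[str, float] = {}
    for metrics in per_process:
        for key, value in metrics.items():
            summed[key] = summed.get(key, 0.0) + value

    logs: Dict[str, float] = {}
    for split in ("chosen", "rejected"):
        count = summed.get(f"count/{split}", 0.0)
        if count <= 0:
            continue  # no samples of this kind on any process: nothing to average
        for key in ("rewards", "logps"):
            logs[f"{key}/{split}"] = summed.get(f"{key}/{split}_sum", 0.0) / count

    if "rewards/chosen" in logs and "rewards/rejected" in logs:
        logs["rewards/margins"] = logs["rewards/chosen"] - logs["rewards/rejected"]
    return logs


process_a = {
    "rewards/chosen_sum": 0.5, "logps/chosen_sum": -3.0, "count/chosen": 2.0,
    "rewards/rejected_sum": -0.25, "logps/rejected_sum": -2.0, "count/rejected": 1.0,
}
process_b = {  # this process received only undesirable samples
    "rewards/chosen_sum": 0.0, "logps/chosen_sum": 0.0, "count/chosen": 0.0,
    "rewards/rejected_sum": -0.75, "logps/rejected_sum": -6.0, "count/rejected": 3.0,
}
logs = aggregate_kto_metrics([process_a, process_b])
print(logs["rewards/chosen"], logs["rewards/rejected"], logs["rewards/margins"])  # 0.25 -0.25 0.5
```

Summing before dividing weights every sample equally; averaging per-process means instead would over-weight processes that happened to receive few samples of one kind.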
Usage Examples
```python
from llamafactory.train.kto.trainer import CustomKTOTrainer

# Instantiated by the KTO workflow (simplified example). policy_model,
# reference_model, finetuning_args, processor, tokenizer, training_args,
# train_dataset, eval_dataset, and data_collator stand for objects that the
# workflow prepares beforehand.
trainer = CustomKTOTrainer(
    model=policy_model,
    ref_model=reference_model,
    finetuning_args=finetuning_args,
    processor=processor,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)

# Train with the KTO loss
trainer.train()

# KTO uses kto_tags to split samples:
# - kto_tags=True: desirable samples (weighted by kto_chosen_weight)
# - kto_tags=False: undesirable samples (weighted by kto_rejected_weight)
# - kl_-prefixed inputs: used for KL divergence estimation
```