Implementation:Hiyouga LLaMA Factory KTO Trainer


Knowledge Sources
Domains: Machine Learning, RLHF, Preference Optimization
Last Updated: 2026-02-06 19:00 GMT

Overview

Custom trainer implementing Kahneman-Tversky Optimization (KTO) for unpaired preference data in LLaMA-Factory.

Description

CustomKTOTrainer extends TRL's KTOTrainer to implement the KTO alignment algorithm. KTO uses unpaired preference data: each sample is independently labeled as desirable or undesirable, unlike DPO, which requires paired chosen/rejected examples.

The trainer initializes directly via Trainer.__init__ and overrides several methods. forward handles multimodal inputs and computes log probabilities. concatenated_forward splits the batch into chosen and rejected samples using the boolean kto_tags mask. compute_reference_log_probs computes the reference model's log probabilities. get_batch_loss_metrics computes the KTO loss, with an optional auxiliary SFT loss controlled by pref_ftx.

The log method aggregates sum-based metrics across distributed processes and converts them to per-sample averages. This accounts for the asymmetry of KTO, where the number of chosen and rejected samples varies from batch to batch.
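To make the loss concrete, here is a minimal pure-Python sketch of the KTO objective as it is commonly formulated in TRL. It is an illustration under stated assumptions, not the trainer's code: the function name kto_loss_sketch, the list-based inputs, and the default values are hypothetical, and the log-ratios (policy log probability minus reference log probability) and the KL estimate are assumed to be computed elsewhere.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic function."""
    return 1.0 / (1.0 + math.exp(-x))


def kto_loss_sketch(chosen_logratios, rejected_logratios, kl,
                    beta=0.1, chosen_weight=1.0, rejected_weight=1.0):
    """Illustrative KTO loss over per-sample log-ratios.

    chosen_logratios / rejected_logratios: policy_logp - reference_logp
    for desirable / undesirable samples (hypothetical inputs).
    kl: batch-level KL estimate, clamped at zero (assumption: mirrors
    the usual TRL formulation).
    """
    kl = max(kl, 0.0)
    # Desirable samples: loss shrinks as the log-ratio rises above the KL baseline.
    chosen_losses = [
        chosen_weight * (1.0 - sigmoid(beta * (r - kl))) for r in chosen_logratios
    ]
    # Undesirable samples: loss shrinks as the log-ratio falls below the KL baseline.
    rejected_losses = [
        rejected_weight * (1.0 - sigmoid(beta * (kl - r))) for r in rejected_logratios
    ]
    losses = chosen_losses + rejected_losses
    # An empty batch yields NaN, which a nanmean-style reduction would skip.
    return sum(losses) / len(losses) if losses else float("nan")
```

In this sketch beta plays the role of pref_beta, and chosen_weight and rejected_weight play the roles of kto_chosen_weight and kto_rejected_weight. Either list may be empty, which reflects that a KTO batch need not contain both kinds of sample.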

Usage

The KTO training workflow instantiates this trainer when stage="kto" is set in FinetuningArguments. The dataset must provide a boolean kto_tags field marking each sample as desirable (True) or undesirable (False), plus kl_-prefixed fields used for KL estimation.

Code Reference

Source Location

Signature

class CustomKTOTrainer(KTOTrainer):
    def __init__(
        self,
        model: Union["PreTrainedModel", torch.nn.Module],
        ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]],
        finetuning_args: "FinetuningArguments",
        processor: Optional["ProcessorMixin"],
        disable_dropout: bool = True,
        **kwargs,
    ): ...

    def create_optimizer(self) -> "torch.optim.Optimizer": ...
    def create_scheduler(self, num_training_steps, optimizer=None) -> "torch.optim.lr_scheduler.LRScheduler": ...

    def forward(
        self, model, batch, prefix: Literal["", "kl_"] = "",
    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]: ...

    def concatenated_forward(
        self, model, batch,
    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor",
               "torch.Tensor", "torch.Tensor", "torch.Tensor"]: ...

    def compute_reference_log_probs(
        self, model, batch,
    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]: ...

    def get_batch_loss_metrics(
        self, model, batch,
    ) -> tuple["torch.Tensor", dict[str, "torch.Tensor"]]: ...

    def log(self, logs: dict[str, float], *args, **kwargs) -> None: ...

Import

from llamafactory.train.kto.trainer import CustomKTOTrainer

I/O Contract

Inputs

Name | Type | Required | Description
model | PreTrainedModel | Yes | The policy model to optimize
ref_model | PreTrainedModel or None | No | Reference model (pass None to obtain reference outputs by disabling the LoRA adapter)
finetuning_args | FinetuningArguments | Yes | Contains pref_beta, kto_chosen_weight, kto_rejected_weight, pref_ftx
processor | ProcessorMixin or None | No | Processor to save in checkpoints
batch | dict[str, Tensor] | Yes (forward) | Must contain input_ids, attention_mask, labels, kto_tags, and kl_-prefixed fields
batch["kto_tags"] | torch.BoolTensor | Yes | Boolean mask: True for desirable, False for undesirable samples
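The role of kto_tags can be illustrated with a small pure-Python sketch. The helper split_by_kto_tags is hypothetical and works on plain lists; the trainer itself operates on tensors and presumably applies the mask through boolean indexing, which this sketch only imitates.

```python
def split_by_kto_tags(logps, kto_tags):
    """Split per-sample log probabilities into (chosen, rejected) lists.

    logps: one value per sample in the batch (hypothetical plain floats).
    kto_tags: one bool per sample, True for desirable, False for undesirable.
    """
    if len(logps) != len(kto_tags):
        raise ValueError("logps and kto_tags must have the same length")
    chosen = [lp for lp, tag in zip(logps, kto_tags) if tag]
    rejected = [lp for lp, tag in zip(logps, kto_tags) if not tag]
    return chosen, rejected


# Usage: a batch of three samples, two desirable and one undesirable.
chosen, rejected = split_by_kto_tags([-1.0, -2.0, -3.0], [True, False, True])
```

Because the split depends only on the tags, a batch can end up with all samples on one side, so downstream code has to tolerate an empty chosen or rejected group.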

Outputs

Name | Type | Description
loss | torch.Tensor | Scalar KTO training loss (nanmean over batch)
metrics | dict[str, float] | Sum-based metrics: rewards/chosen_sum, rewards/rejected_sum, logps/chosen_sum, logps/rejected_sum, count/chosen, count/rejected, kl
logged metrics | dict[str, float] | Averaged metrics after distributed reduction: rewards/chosen, rewards/rejected, rewards/margins
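The reason for collecting sums and counts rather than per-process means can be sketched as follows. This is a hedged illustration, not the trainer's log method: the function average_kto_metrics and the list-of-dicts input are hypothetical stand-ins for a distributed reduction, and only the reward metrics are shown.

```python
def average_kto_metrics(per_process):
    """Combine sum-based metrics from several processes into per-sample averages.

    per_process: list of dicts, one per process, using the sum/count keys
    from the metrics table (missing keys are treated as zero).
    """
    keys = ["rewards/chosen_sum", "rewards/rejected_sum",
            "count/chosen", "count/rejected"]
    # Stand-in for a distributed sum reduction across processes.
    totals = {k: sum(m.get(k, 0.0) for m in per_process) for k in keys}

    logs = {}
    if totals["count/chosen"] > 0:
        logs["rewards/chosen"] = totals["rewards/chosen_sum"] / totals["count/chosen"]
    if totals["count/rejected"] > 0:
        logs["rewards/rejected"] = totals["rewards/rejected_sum"] / totals["count/rejected"]
    # The margin is only defined when both groups were observed.
    if "rewards/chosen" in logs and "rewards/rejected" in logs:
        logs["rewards/margins"] = logs["rewards/chosen"] - logs["rewards/rejected"]
    return logs


# Usage: the second process saw no rejected samples at all.
logs = average_kto_metrics([
    {"rewards/chosen_sum": 2.0, "count/chosen": 2,
     "rewards/rejected_sum": -1.0, "count/rejected": 1},
    {"rewards/chosen_sum": 1.0, "count/chosen": 1,
     "rewards/rejected_sum": 0.0, "count/rejected": 0},
])
```

Averaging per-process means directly would weight a process with one chosen sample the same as a process with many, and would fail outright for a process with none. Dividing the global sum by the global count avoids both problems.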

Usage Examples

from llamafactory.train.kto.trainer import CustomKTOTrainer

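# Instantiated by the KTO workflow (simplified example).
# policy_model, reference_model, finetuning_args, processor, tokenizer,
# training_args, the datasets, and data_collator are assumed to be built beforehand.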
trainer = CustomKTOTrainer(
    model=policy_model,
    ref_model=reference_model,
    finetuning_args=finetuning_args,
    processor=processor,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)

# Train with KTO loss
trainer.train()

# KTO uses kto_tags to split samples:
# - kto_tags=True: desirable samples (weighted by kto_chosen_weight)
# - kto_tags=False: undesirable samples (weighted by kto_rejected_weight)
# - kl_ prefixed inputs: used for KL divergence estimation
