Implementation:OpenRLHF OpenRLHF KTOTrainer

From Leeroopedia


Knowledge Sources
Domains: Reinforcement_Learning, Alignment, Training
Last Updated: 2026-02-07 10:40 GMT

Overview

Concrete tool for training language models using Kahneman-Tversky Optimization (KTO) with unpaired preference data.

Description

The KTOTrainer class implements the training loop for Kahneman-Tversky Optimization, an alignment method that works with unpaired preference data (individual samples labeled as desirable or undesirable, rather than chosen/rejected pairs). It computes policy and reference model log-probabilities, applies the KTO loss function with a KL divergence regularization term, and supports gradient accumulation, distributed training via DeepSpeed, and WandB/TensorBoard logging.
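The loss described above can be illustrated with a minimal pure-Python sketch. It follows the published KTO formulation (a sigmoid of the policy/reference log-ratio, offset by a non-negative KL estimate) and is not taken from the OpenRLHF source. The function name `kto_loss`, its argument names, and the per-class weights are assumptions made for illustration only.

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def kto_loss(policy_logps, ref_logps, desirable, kl_estimate,
             beta=0.01, desirable_weight=1.0, undesirable_weight=1.0):
    """Illustrative KTO loss over sequence-level log-probabilities.

    policy_logps / ref_logps: one summed log-probability per sample.
    desirable: one bool per sample (True = desirable, False = undesirable).
    kl_estimate: batch-level estimate of KL(policy || reference).
    All names here are hypothetical, not OpenRLHF's API.
    """
    kl = max(kl_estimate, 0.0)  # the KL reference point is clamped at zero
    losses = []
    for pol, ref, is_good in zip(policy_logps, ref_logps, desirable):
        log_ratio = pol - ref  # implicit reward, up to the factor beta
        if is_good:
            # Desirable: push the log-ratio above the KL reference point.
            losses.append(desirable_weight * (1.0 - sigmoid(beta * (log_ratio - kl))))
        else:
            # Undesirable: push the log-ratio below the KL reference point.
            losses.append(undesirable_weight * (1.0 - sigmoid(beta * (kl - log_ratio))))
    return sum(losses) / len(losses)
```

When the policy equals the reference and the KL estimate is zero, every sample contributes a loss of 0.5. Raising the policy's log-probability on a desirable sample lowers its loss, and raising it on an undesirable sample increases it.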

Usage

Use this trainer when aligning a language model with human preferences using unpaired feedback data. Unlike DPO, which requires paired (chosen, rejected) examples, KTO can learn from independently labeled samples, which makes it suitable when paired preference data is unavailable.
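To make the paired/unpaired distinction concrete, the sketch below splits DPO-style pairs into independently labeled samples. The field names (`prompt`, `chosen`, `rejected`, `response`, `label`) and the helper `unpair` are hypothetical and do not necessarily match the schema that UnpairedPreferenceDataset expects.

```python
from typing import Dict, List

def unpair(pairs: List[Dict[str, str]]) -> List[Dict[str, object]]:
    """Split each (chosen, rejected) pair into two independently labeled samples."""
    samples = []
    for pair in pairs:
        # label 1 = desirable, label 0 = undesirable (an assumed convention)
        samples.append({"prompt": pair["prompt"], "response": pair["chosen"], "label": 1})
        samples.append({"prompt": pair["prompt"], "response": pair["rejected"], "label": 0})
    return samples

paired = [{"prompt": "2+2=?", "chosen": "4", "rejected": "5"}]
unpaired = unpair(paired)
```

Unpaired data does not have to originate from pairs. A dataset with only thumbs-up or thumbs-down feedback on single responses already has this shape.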

Code Reference

Source Location

Signature

from abc import ABC

from torch.optim import Optimizer


class KTOTrainer(ABC):
    def __init__(
        self,
        model,
        ref_model,
        strategy,
        tokenizer,
        optim: Optimizer,
        train_dataloader,
        eval_dataloader,
        scheduler,
        max_norm=0.5,
        beta=0.01,
        max_epochs: int = 2,
        save_hf_ckpt: bool = False,
        disable_ds_ckpt: bool = False,
    ) -> None: ...

    def fit(self, args, consumed_samples=0, num_update_steps_per_epoch=None): ...
    def evaluate(self, eval_dataloader, steps=0): ...
    def compute_model_logps_with_KL(self, model, input_ids, attention_mask, labels, prompt_id_lens): ...
    def compute_model_logps(self, model, input_ids, attention_mask, labels, prompt_id_lens): ...
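The parameter names `labels` and `prompt_id_lens` suggest that log-probabilities are summed over response tokens only and then split by desirability label. The following pure-Python sketch illustrates that idea under those assumptions. The helpers `response_logp` and `split_by_label` are hypothetical and are not methods of KTOTrainer.

```python
def response_logp(token_logps, attention_mask, prompt_len):
    """Sum per-token log-probabilities over non-padded response tokens.

    token_logps: log-probability of each token in one sequence.
    attention_mask: 1 for real tokens, 0 for padding.
    prompt_len: number of leading prompt tokens to exclude.
    """
    total = 0.0
    for pos, (logp, keep) in enumerate(zip(token_logps, attention_mask)):
        if keep and pos >= prompt_len:
            total += logp
    return total

def split_by_label(seq_logps, labels):
    """Separate sequence log-probs into desirable (1) and undesirable (0) groups."""
    chosen = [lp for lp, y in zip(seq_logps, labels) if y == 1]
    rejected = [lp for lp, y in zip(seq_logps, labels) if y == 0]
    return chosen, rejected
```

Excluding the prompt tokens keeps the loss focused on the response the model is being rewarded or penalized for, rather than on text it was given.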

Import

from openrlhf.trainer.kto_trainer import KTOTrainer

I/O Contract

Inputs

| Name | Type | Required | Description |
| model | nn.Module | Yes | Policy model to train (Actor) |
| ref_model | nn.Module | Yes | Frozen reference model for KL estimation |
| strategy | DeepspeedStrategy | Yes | Distributed training strategy |
| tokenizer | PreTrainedTokenizer | Yes | Tokenizer for text processing |
| optim | Optimizer | Yes | PyTorch optimizer |
| train_dataloader | DataLoader | Yes | Training data (UnpairedPreferenceDataset) |
| eval_dataloader | DataLoader | No | Evaluation data |
| scheduler | LRScheduler | Yes | Learning rate scheduler |
| beta | float | No | KTO regularization coefficient (default: 0.01) |
| max_norm | float | No | Gradient clipping max norm (default: 0.5) |
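As a rough illustration of what the `max_norm` input controls, the sketch below applies global-norm gradient clipping to a flat list of gradient values. It mirrors the usual clipping rule (scale all gradients by `max_norm / total_norm` when the norm exceeds the limit). The helper name `clip_by_global_norm` is hypothetical, and how the trainer and its strategy actually apply clipping is not shown on this page.

```python
import math

def clip_by_global_norm(grads, max_norm=0.5, eps=1e-6):
    """Rescale gradients so that their global L2 norm does not exceed max_norm.

    Returns the (possibly rescaled) gradients and the norm before clipping.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    # The coefficient is capped at 1.0 so small gradients are left untouched.
    coef = min(1.0, max_norm / (total_norm + eps))
    return [g * coef for g in grads], total_norm

clipped, norm_before = clip_by_global_norm([3.0, 4.0], max_norm=0.5)
```

A smaller `max_norm` limits the size of any single update more aggressively, which can help stability at the cost of slower progress when gradients are legitimately large.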

Outputs

| Name | Type | Description |
| fit() | None | Trains model in-place, saves checkpoints and logs metrics |
| evaluate() | None | Computes eval loss, chosen/rejected rewards; logs to WandB/TensorBoard |

Usage Examples

Creating and Running KTO Training

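from openrlhf.trainer.kto_trainer import KTOTrainer

# model, ref_model, tokenizer, strategy, optimizer, both dataloaders, scheduler,
# args, and num_steps are assumed to have been created earlier in the script.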
trainer = KTOTrainer(
    model=model,
    ref_model=ref_model,
    tokenizer=tokenizer,
    strategy=strategy,
    optim=optimizer,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    scheduler=scheduler,
    max_norm=1.0,  # overrides the signature default of 0.5
    beta=0.01,
    max_epochs=1,
    save_hf_ckpt=False,
    disable_ds_ckpt=False,
)

# Run training
trainer.fit(args, consumed_samples=0, num_update_steps_per_epoch=num_steps)
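The example above passes `num_steps` without showing where it comes from. One plausible way to derive it, sketched here as an assumption rather than as the project's documented procedure, is to divide the number of training samples by the global batch size. The helper `plan_steps` and its argument names are hypothetical.

```python
import math

def plan_steps(num_samples: int, train_batch_size: int, max_epochs: int):
    """Estimate optimizer updates per epoch and the total number of updates."""
    # Assumption: one optimizer update consumes one global batch of samples.
    num_update_steps_per_epoch = num_samples // train_batch_size
    max_steps = math.ceil(max_epochs * num_update_steps_per_epoch)
    return num_update_steps_per_epoch, max_steps

num_steps, max_steps = plan_steps(num_samples=10_000, train_batch_size=128, max_epochs=1)
```

The total step count is also what a learning rate scheduler typically needs when it is constructed, so computing both values in one place keeps them consistent.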
