
Implementation:NVIDIA NeMo Aligner KTOTrainer Fit

From Leeroopedia


Knowledge Sources
Domains: NLP, Alignment
Last Updated: 2026-02-08 00:00 GMT

Overview

KTOTrainer is the trainer class that coordinates KTO (Kahneman-Tversky Optimization) training, extending the DPOTrainer with a custom dataloader augmentation method that computes reference policy log-probabilities for both the original samples and KL estimation samples.

Description

The KTOTrainer class inherits from DPOTrainer and overrides the augment_dataloader() method to handle KTO-specific data requirements. The parent DPOTrainer provides the full training loop via fit(), including epoch iteration, training steps, validation, checkpointing, and logging.

The key KTO-specific behavior is in the dataloader augmentation:

  • Reference policy log-probabilities: For each batch, augment_dataloader() calls self.model.get_ref_policy_logprobs(batch) to compute log-probabilities under the reference policy and moves the result to the CPU. The returned tensor is split into two equal halves along the batch dimension (dim=0): the first half, ref_policy_log_probs_samples, covers the original prompt-response pairs, and the second half, ref_policy_log_probs_kl_samples, covers the KL estimation pairs created by kto_custom_collate.
  • Yielding augmented batches: The augmented batch is yielded to the training loop with the reference log-probs attached, allowing the model's get_loss_and_metrics() to compute the KTO loss.
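To make the role of the two reference tensors concrete, here is a KTO-style loss in plain Python, following the formulation of the original KTO paper: the implied reward is the policy-to-reference log-ratio on the real pairs, and the reference point z0 is the batch-average log-ratio on the mismatched KL pairs, clamped at zero. This is a sketch, not NeMo-Aligner's get_loss_and_metrics(); the function name kto_loss, the beta default, and the desirable/undesirable weights are assumptions for illustration, and the library's reduction and weighting may differ.

```python
import math
from typing import List


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def kto_loss(
    policy_logps: List[float],     # policy log-probs of the original pairs
    ref_logps: List[float],        # reference log-probs of the original pairs
    policy_kl_logps: List[float],  # policy log-probs of the mismatched KL pairs
    ref_kl_logps: List[float],     # reference log-probs of the mismatched KL pairs
    preference: List[int],         # 1 = desirable response, 0 = undesirable
    beta: float = 0.1,             # hypothetical default
    desirable_weight: float = 1.0,
    undesirable_weight: float = 1.0,
) -> float:
    """Mean KTO loss over a batch (paper formulation; illustrative sketch only)."""
    # Implied reward: log-ratio of policy to reference on each real pair.
    rewards = [p - r for p, r in zip(policy_logps, ref_logps)]
    # Reference point z0: average log-ratio on the KL pairs, clamped at 0.
    # In real training this value is treated as a constant (no gradient).
    kl = sum(p - r for p, r in zip(policy_kl_logps, ref_kl_logps)) / len(policy_kl_logps)
    z0 = max(kl, 0.0)
    losses = []
    for reward, pref in zip(rewards, preference):
        if pref == 1:
            # Desirable: loss shrinks as the reward rises above z0.
            losses.append(desirable_weight * (1.0 - sigmoid(beta * (reward - z0))))
        else:
            # Undesirable: loss shrinks as the reward falls below z0.
            losses.append(undesirable_weight * (1.0 - sigmoid(beta * (z0 - reward))))
    return sum(losses) / len(losses)
```

When the policy matches the reference everywhere, every reward and z0 are zero and each per-sample loss is 1 - sigmoid(0) = 0.5; raising the reward of a desirable response lowers its loss, while raising the reward of an undesirable one increases it.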

The kto_custom_collate function (defined in the same module) handles the construction of KL estimation samples. It pairs each prompt with the response from the next sample in the batch (circular shift), creating mismatched prompt-response pairs used to estimate the KL divergence reference point.
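The circular-shift pairing can be illustrated with plain Python lists. build_kl_pairs is a hypothetical helper, not part of NeMo-Aligner; it only shows which prompt is matched with which response. The real kto_custom_collate works on token tensors and also builds fields such as attention_mask and position_ids, which this sketch omits.

```python
from typing import Dict, List


def build_kl_pairs(batch: List[Dict[str, List[int]]]) -> List[Dict[str, List[int]]]:
    """Pair the prompt of sample i with the response of sample (i + 1) % B."""
    n = len(batch)
    return [
        # The modulo wraps the last prompt around to the first response.
        {"prompt": batch[i]["prompt"], "response": batch[(i + 1) % n]["response"]}
        for i in range(n)
    ]


batch = [
    {"prompt": [1], "response": [10]},
    {"prompt": [2], "response": [20]},
    {"prompt": [3], "response": [30]},
]
kl_pairs = build_kl_pairs(batch)
# Prompt [1] now sits with response [20], [2] with [30], and [3] with [10].
```

With a batch of one, the shift maps the sample onto itself, so the resulting pair is not actually mismatched; the sketch does not guard against that case.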

Because KTOTrainer does not override fit(), the inherited DPOTrainer loop (epoch management, gradient clipping, validation, and checkpointing) runs unchanged; what differs is the content of the batches it receives from augment_dataloader().
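The override pattern can be sketched with two toy classes. ToyDPOTrainer and ToyKTOTrainer are hypothetical stand-ins, not NeMo-Aligner classes, and the sketch assumes, as the description implies, that the inherited fit() pulls training batches through self.augment_dataloader(); a list append stands in for an actual training step.

```python
from typing import Dict, Iterable, Iterator, List


class ToyDPOTrainer:
    """Hypothetical base trainer: fit() consumes self.augment_dataloader()."""

    def __init__(self, train_dataloader: Iterable[Dict], max_epochs: int = 1):
        self.train_dataloader = train_dataloader
        self.max_epochs = max_epochs
        self.seen: List[Dict] = []

    def augment_dataloader(self, dataloader: Iterable[Dict]) -> Iterator[Dict]:
        # Base behaviour: pass batches through unchanged.
        yield from dataloader

    def fit(self) -> None:
        for _ in range(self.max_epochs):
            # Method lookup on self picks up a subclass override here.
            for batch in self.augment_dataloader(self.train_dataloader):
                self.seen.append(batch)  # stands in for a training step


class ToyKTOTrainer(ToyDPOTrainer):
    """Overrides only the augmentation step; fit() is inherited as-is."""

    def augment_dataloader(self, dataloader: Iterable[Dict]) -> Iterator[Dict]:
        for batch in dataloader:
            batch = dict(batch)           # copy so the source batch is untouched
            batch["augmented_by"] = "kto"  # marker standing in for ref log-probs
            yield batch


trainer = ToyKTOTrainer([{"id": 0}, {"id": 1}], max_epochs=2)
trainer.fit()
```

Keeping the loop in the base class means gradient clipping, validation, and checkpointing logic live in one place, and a new algorithm only has to describe how its batches are prepared.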

Usage

Import and instantiate KTOTrainer when setting up a KTO training run. It is typically created in the train_gpt_kto.py entry point script.

Code Reference

Source Location

Signature

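from omegaconf import DictConfig

from nemo_aligner.algorithms.dpo import DPOTrainer


class KTOTrainer(DPOTrainer):
    """Trainer to coordinate KTO training
    """

    def __init__(
        self,
        cfg: DictConfig,
        model,
        optimizer,
        scheduler,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        logger,
        ckpt_callback,
        run_timer,
    ):
        ...  # constructor body not reproduced on this page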

Key Method: augment_dataloader()

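import torch  # module-level import needed for torch.split


def augment_dataloader(self, dataloader):
    """Augment dataloader with ref policy log prob"""
    iter_dataloader = iter(dataloader)
    while True:
        try:
            batch = next(iter_dataloader)
            # Reference log-probs for every row: original samples first, then KL samples.
            logprobs = self.model.get_ref_policy_logprobs(batch).cpu()
            # Split along the batch dimension into two equal halves.
            samples_logps, kl_samples_logps = torch.split(logprobs, len(logprobs) // 2, dim=0)
            batch["ref_policy_log_probs_samples"] = samples_logps
            batch["ref_policy_log_probs_kl_samples"] = kl_samples_logps
            yield batch
            # Drop the references once the training loop has consumed the batch.
            del logprobs, samples_logps, kl_samples_logps
        except StopIteration:
            break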
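The generator above can be exercised without a GPU by swapping tensors for Python lists. StubModel and the standalone augment_dataloader below are hypothetical; the stub assumes, consistent with the split into halves, that reference log-probs come back with all original samples first and all KL samples second. A for loop replaces the explicit iter()/next()/StopIteration pattern; both drain the iterable the same way.

```python
from typing import Dict, Iterable, Iterator, List


class StubModel:
    """Hypothetical model: one fake log-prob per row of samples + kl_samples."""

    def get_ref_policy_logprobs(self, batch: Dict) -> List[float]:
        rows = batch["samples"] + batch["kl_samples"]  # assumed order: samples first
        return [-float(sum(row)) for row in rows]


def augment_dataloader(model, dataloader: Iterable[Dict]) -> Iterator[Dict]:
    """List-based mirror of KTOTrainer.augment_dataloader()."""
    for batch in dataloader:
        logprobs = model.get_ref_policy_logprobs(batch)
        # Every sample needs a KL partner, so the row count must be even.
        assert len(logprobs) % 2 == 0, "expected equal numbers of samples and KL samples"
        half = len(logprobs) // 2
        batch["ref_policy_log_probs_samples"] = logprobs[:half]
        batch["ref_policy_log_probs_kl_samples"] = logprobs[half:]
        yield batch


batch = {"samples": [[1, 2], [3, 4]], "kl_samples": [[1, 4], [3, 2]]}
out = next(augment_dataloader(StubModel(), [batch]))
```

The even-length assertion mirrors a constraint in the original: torch.split(logprobs, len(logprobs) // 2, dim=0) yields exactly two chunks only when the batch dimension is even. With an odd length of three or more it yields three chunks, and unpacking them into two names raises a ValueError.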

Key Function: kto_custom_collate()

def kto_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_mask=False, eod_mask_loss=False):
    ...  # body not reproduced on this page

Import

from nemo_aligner.algorithms.kto import KTOTrainer, kto_custom_collate

I/O Contract

Inputs

  • cfg (DictConfig, required): Trainer configuration (from cfg.trainer.kto), inherited from DPOTrainer. Includes max_epochs, max_steps, gradient_clip_val, val_check_interval, save_interval, and limit_val_batches.
  • model (MegatronGPTKTOModel, required): The KTO model implementing SupervisedInterface, with get_loss_and_metrics() and get_ref_policy_logprobs().
  • optimizer (Optimizer, required): PyTorch optimizer extracted from the PyTorch Lightning (PTL) model.
  • scheduler (LRScheduler, required): Learning-rate scheduler extracted from the PTL model.
  • train_dataloader (DataLoader, required): Training dataloader using kto_custom_collate, providing batches with samples, kl_samples, sample_labels, kl_sample_labels, preference, attention_mask, and position_ids.
  • val_dataloader (DataLoader, required): Validation dataloader.
  • test_dataloader (DataLoader or None, optional): Test dataloader. The argument has no default in the signature, so it must still be passed, but it may be None.
  • logger (CustomLoggerWrapper, required): Logger for metrics and table logging.
  • ckpt_callback (Callback, required): Custom checkpoint callback for saving model states.
  • run_timer (Timer, required): Timer that enforces the maximum training-time limit.
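For orientation, the fields listed for cfg can be pictured as a plain mapping. The values below are placeholders chosen for illustration, not NeMo-Aligner defaults; per the signature, cfg is an OmegaConf DictConfig, and a plain dict is used here only to keep the sketch self-contained. missing_trainer_keys is a hypothetical helper.

```python
from typing import List, Mapping

# Illustrative shape of cfg.trainer.kto; all values are placeholders.
kto_trainer_cfg = {
    "max_epochs": 1,
    "max_steps": 1000,
    "gradient_clip_val": 1.0,
    "val_check_interval": 100,
    "save_interval": 100,
    "limit_val_batches": 1.0,
}

# Keys named in the cfg entry of the I/O contract.
REQUIRED_KEYS = {
    "max_epochs",
    "max_steps",
    "gradient_clip_val",
    "val_check_interval",
    "save_interval",
    "limit_val_batches",
}


def missing_trainer_keys(cfg: Mapping) -> List[str]:
    """Return the listed trainer keys absent from cfg, sorted alphabetically."""
    return sorted(REQUIRED_KEYS - set(cfg))
```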

Outputs

  • None (side effects only): The fit() method (inherited from DPOTrainer) trains the model in place, logs metrics, and saves checkpoints. It returns nothing.

Usage Examples

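from nemo_aligner.algorithms.kto import KTOTrainer

# cfg, ptl_model, optimizer, scheduler, the dataloaders, logger, ckpt_callback,
# timer, and custom_trainer_state_dict are assumed to be built earlier in the
# entry-point script (e.g. train_gpt_kto.py).
kto_trainer = KTOTrainer(
    cfg=cfg.trainer.kto,
    model=ptl_model,
    optimizer=optimizer,
    scheduler=scheduler,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    test_dataloader=None,
    logger=logger,
    ckpt_callback=ckpt_callback,
    run_timer=timer,
)

# Restore saved trainer state when resuming from a checkpoint.
if custom_trainer_state_dict is not None:
    kto_trainer.load_state_dict(custom_trainer_state_dict)

# Run the training loop inherited from DPOTrainer.
kto_trainer.fit()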
