Implementation:NVIDIA NeMo Aligner KTOTrainer Fit
| Knowledge Sources | |
|---|---|
| Domains | NLP, Alignment |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
KTOTrainer is the trainer class that coordinates KTO (Kahneman-Tversky Optimization) training, extending the DPOTrainer with a custom dataloader augmentation method that computes reference policy log-probabilities for both the original samples and KL estimation samples.
Description
The KTOTrainer class inherits from DPOTrainer and overrides the augment_dataloader() method to handle KTO-specific data requirements. The parent DPOTrainer provides the full training loop via fit(), including epoch iteration, training steps, validation, checkpointing, and logging.
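The override relationship described above can be illustrated with a minimal sketch. BaseTrainer and ChildTrainer are hypothetical stand-ins for DPOTrainer and KTOTrainer. The real fit() does far more (validation, checkpointing, logging), and the exact way it invokes augment_dataloader() is an assumption here.

```python
class BaseTrainer:
    """Hypothetical stand-in for DPOTrainer; only the hook structure is modeled."""

    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.seen = []

    def augment_dataloader(self, dataloader):
        # Default hook: pass batches through unchanged.
        for batch in dataloader:
            yield batch

    def fit(self):
        # The loop lives in the parent and always goes through the hook,
        # so a subclass can change the batches without touching fit().
        for batch in self.augment_dataloader(self.dataloader):
            self.seen.append(batch)


class ChildTrainer(BaseTrainer):
    """Hypothetical stand-in for KTOTrainer: overrides only the hook."""

    def augment_dataloader(self, dataloader):
        for batch in dataloader:
            batch = dict(batch)  # copy so the input batch is not mutated
            batch["augmented"] = True
            yield batch


trainer = ChildTrainer([{"id": 0}, {"id": 1}])
trainer.fit()  # inherited loop, overridden augmentation
```

Keeping the loop in the parent means the subclass only has to describe what is different about its batches.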
The key KTO-specific behavior is in the dataloader augmentation:
- Reference policy log-probabilities: For each batch, augment_dataloader() calls model.get_ref_policy_logprobs(batch) to compute log-probabilities under the reference policy. The returned tensor is split into two halves: ref_policy_log_probs_samples (for the original prompt-response pairs) and ref_policy_log_probs_kl_samples (for the KL estimation pairs created by kto_custom_collate).
- Yielding augmented batches: The augmented batch is yielded to the training loop with the reference log-probs attached, allowing the model's get_loss_and_metrics() to compute the KTO loss.
The kto_custom_collate function (defined in the same module) handles the construction of KL estimation samples. It pairs each prompt with the response from the next sample in the batch (circular shift), creating mismatched prompt-response pairs used to estimate the KL divergence reference point.
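The circular-shift pairing can be sketched in a few lines. This is an illustration of the idea only, not the body of kto_custom_collate: make_kl_pairs is a hypothetical helper, and plain strings stand in for the tokenized tensors the real collate function handles.

```python
def make_kl_pairs(prompts, responses):
    """Pair prompt i with response (i + 1) % n (hypothetical helper)."""
    n = len(prompts)
    if n != len(responses):
        raise ValueError("prompts and responses must have the same length")
    # The modulo wraps the last prompt around to the first response.
    return [(prompts[i], responses[(i + 1) % n]) for i in range(n)]


prompts = ["p0", "p1", "p2"]
responses = ["r0", "r1", "r2"]
kl_pairs = make_kl_pairs(prompts, responses)
# kl_pairs == [("p0", "r1"), ("p1", "r2"), ("p2", "r0")]
```

In this sketch a batch of size 1 pairs the prompt with its own response, so the pairs are only mismatched when the batch holds at least two samples.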
Everything outside the augmentation step, including epoch management, gradient clipping, validation, and checkpointing, comes from the inherited fit() method.
Usage
Import and instantiate KTOTrainer when setting up a KTO training run. It is typically created in the train_gpt_kto.py entry point script.
Code Reference
Source Location
- Repository: NVIDIA_NeMo_Aligner
- File: nemo_aligner/algorithms/kto.py
- Lines: 92-137
Signature
class KTOTrainer(DPOTrainer):
    """Trainer to coordinate KTO training
    """
    def __init__(
        self,
        cfg: DictConfig,
        model,
        optimizer,
        scheduler,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        logger,
        ckpt_callback,
        run_timer,
    ):
Key Method: augment_dataloader()
def augment_dataloader(self, dataloader):
    """Augment dataloader with ref policy log prob"""
    iter_dataloader = iter(dataloader)
    while True:
        try:
            batch = next(iter_dataloader)
            logprobs = self.model.get_ref_policy_logprobs(batch).cpu()
            samples_logps, kl_samples_logps = torch.split(logprobs, len(logprobs) // 2, dim=0)
            batch["ref_policy_log_probs_samples"] = samples_logps
            batch["ref_policy_log_probs_kl_samples"] = kl_samples_logps
            yield batch
            del logprobs, samples_logps, kl_samples_logps
        except StopIteration:
            break
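To see the data flow without a GPU or a real model, the same generator can be mimicked with plain Python lists. Everything in this sketch is a stand-in: StubModel, split_in_half, and augment are hypothetical names, lists replace tensors, and the fake log-probs are just negated inputs. The ordering (sample log-probs first, KL-sample log-probs second) follows the unpacking order in the method above.

```python
def split_in_half(values):
    """Two-way split of a list (hypothetical stand-in for the torch.split call)."""
    half = len(values) // 2
    if len(values) != 2 * half:
        raise ValueError("expected an even number of log-probs")
    return values[:half], values[half:]


class StubModel:
    """Hypothetical model returning fake log-probs as one concatenated list."""

    def get_ref_policy_logprobs(self, batch):
        # Samples first, then KL samples, matching the unpacking order.
        return [-float(x) for x in batch["samples"] + batch["kl_samples"]]


def augment(dataloader, model):
    # A for loop is equivalent to the while/next/StopIteration pattern.
    for batch in dataloader:
        logprobs = model.get_ref_policy_logprobs(batch)
        samples_logps, kl_samples_logps = split_in_half(logprobs)
        batch["ref_policy_log_probs_samples"] = samples_logps
        batch["ref_policy_log_probs_kl_samples"] = kl_samples_logps
        yield batch


batches = [{"samples": [1, 2], "kl_samples": [3, 4]}]
out = list(augment(batches, StubModel()))
```

The explicit even-length check is a choice made for this sketch. With torch.split and a chunk size of len(logprobs) // 2, an odd-length tensor would produce three chunks, and the two-name unpacking would raise instead.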
Key Function: kto_custom_collate()
def kto_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_mask=False, eod_mask_loss=False):
Import
from nemo_aligner.algorithms.kto import KTOTrainer, kto_custom_collate
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| cfg | DictConfig | Yes | Trainer configuration (from cfg.trainer.kto), inherited from DPOTrainer. Includes max_epochs, max_steps, gradient_clip_val, val_check_interval, save_interval, limit_val_batches |
| model | MegatronGPTKTOModel | Yes | The KTO model implementing SupervisedInterface, with get_loss_and_metrics() and get_ref_policy_logprobs() |
| optimizer | Optimizer | Yes | PyTorch optimizer extracted from the PTL model |
| scheduler | LRScheduler | Yes | Learning rate scheduler extracted from the PTL model |
| train_dataloader | DataLoader | Yes | Training dataloader using kto_custom_collate, providing batches with samples, kl_samples, sample_labels, kl_sample_labels, preference, attention_mask, position_ids |
| val_dataloader | DataLoader | Yes | Validation dataloader |
| test_dataloader | DataLoader | No | Test dataloader (can be None) |
| logger | CustomLoggerWrapper | Yes | Logger for metrics and table logging |
| ckpt_callback | Callback | Yes | Custom checkpoint callback for saving model states |
| run_timer | Timer | Yes | Timer that enforces maximum training time limits |
Outputs
| Name | Type | Description |
|---|---|---|
| None (side effects) | N/A | The fit() method (inherited from DPOTrainer) trains the model in-place, logs metrics, and saves checkpoints. No return value. |
Usage Examples
from nemo_aligner.algorithms.kto import KTOTrainer
kto_trainer = KTOTrainer(
    cfg=cfg.trainer.kto,
    model=ptl_model,
    optimizer=optimizer,
    scheduler=scheduler,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    test_dataloader=None,
    logger=logger,
    ckpt_callback=ckpt_callback,
    run_timer=timer,
)
if custom_trainer_state_dict is not None:
    kto_trainer.load_state_dict(custom_trainer_state_dict)
kto_trainer.fit()