Implementation:NVIDIA NeMo Aligner Train GPT KTO
| Knowledge Sources | |
|---|---|
| Domains | NLP, Alignment |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
train_gpt_kto.py is the entry point script for launching KTO (Kahneman-Tversky Optimization) training of a GPT model using NeMo Aligner.
Description
This script wires together all components required for KTO training:
- Configuration loading: Uses Hydra (@hydra_runner) with config path conf and config name gpt_kto. Loads and overrides the model config from the pretrained checkpoint.
- Trainer and experiment setup: Creates a PyTorch Lightning trainer via resolve_and_create_trainer(cfg, "kto") and initializes experiment management.
- Model loading: Loads a pretrained MegatronGPTKTOModel from a NeMo checkpoint with load_base_model_only=False. Initializes PEFT adapters via init_peft(). For full fine-tuning (when peft_scheme == "none"), the reference policy state dict is initialized from the current model weights.
- Data preparation: Builds KTO train/validation/test datasets using build_train_valid_test_kto_datasets(). Uses kto_custom_collate (with functools.partial) as the collate function, which constructs both the original samples and the KL estimation samples from mismatched prompt-response pairs.
- Optimizer and scheduler: Extracts the optimizer and scheduler from the PTL model.
- KTOTrainer instantiation: Creates the KTOTrainer with all dependencies, optionally restores trainer state from a checkpoint, and calls kto_trainer.fit().
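The data preparation step can be illustrated with a small, library-free sketch. It is not the actual kto_custom_collate: the function name collate_with_kl_pairs, the sample keys (prompt, response, preference), the pad_id parameter, and the rotate-by-one pairing are all assumptions chosen only to show the idea of pairing each prompt with another sample's response and binding collate options with functools.partial.

```python
from functools import partial


def collate_with_kl_pairs(batch, pad_id=0, shift=1):
    """Hypothetical collate: returns matched samples plus mismatched KL samples."""
    prompts = [sample["prompt"] for sample in batch]
    responses = [sample["response"] for sample in batch]
    n = len(batch)

    # Assumption: rotate responses so prompt i is paired with the response
    # of sample (i + shift) % n, giving mismatched prompt-response pairs.
    mismatched = [responses[(i + shift) % n] for i in range(n)]

    def pad(sequences):
        # Right-pad every token list to the longest one in the batch.
        width = max(len(seq) for seq in sequences)
        return [seq + [pad_id] * (width - len(seq)) for seq in sequences]

    return {
        "samples": pad([p + r for p, r in zip(prompts, responses)]),
        "kl_samples": pad([p + r for p, r in zip(prompts, mismatched)]),
        "preference": [sample["preference"] for sample in batch],
    }


# Bind the collate options once, as a data loader expects a one-argument callable.
collate_fn = partial(collate_with_kl_pairs, pad_id=0)

batch = [
    {"prompt": [1, 2], "response": [3], "preference": 1},
    {"prompt": [4], "response": [5, 6], "preference": 0},
]
out = collate_fn(batch)
print(out["samples"])
print(out["kl_samples"])
```

The partial-bound callable takes only the batch, which is the shape a PyTorch DataLoader's collate_fn argument expects.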
The script registers custom OmegaConf resolvers for multiply and int_div operations.
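A minimal sketch of what such resolvers could look like, assuming they carry the obvious arithmetic meaning (product and floor division). The config keys micro_batch, accumulation, and global_batch are hypothetical, and the import is guarded only so the sketch also runs where omegaconf is not installed.

```python
def multiply(x, y):
    # Assumed semantics: plain product of the two resolved arguments.
    return x * y


def int_div(x, y):
    # Assumed semantics: floor division, so the result stays an integer.
    return x // y


try:
    from omegaconf import OmegaConf
except ImportError:  # keep the sketch runnable without omegaconf
    OmegaConf = None

if OmegaConf is not None:
    # replace=True makes re-registration safe if the module is imported twice.
    OmegaConf.register_new_resolver("multiply", multiply, replace=True)
    OmegaConf.register_new_resolver("int_div", int_div, replace=True)

    # Hypothetical keys, used only to show the interpolation syntax.
    cfg = OmegaConf.create(
        {
            "micro_batch": 4,
            "accumulation": 8,
            "global_batch": "${multiply:${micro_batch},${accumulation}}",
        }
    )
    print(cfg.global_batch)
```

Resolvers of this kind let one config value be derived from others at access time instead of being duplicated by hand.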
Usage
Run this script via the command line with Hydra configuration overrides to launch KTO training. It requires a pretrained NeMo GPT checkpoint and KTO-format data with binary preference labels.
Code Reference
Source Location
- Repository: NVIDIA_NeMo_Aligner
- File: examples/nlp/gpt/train_gpt_kto.py
- Lines: 1-164
Signature
@hydra_runner(config_path="conf", config_name="gpt_kto")
def main(cfg) -> None:
Import
from nemo_aligner.algorithms.kto import KTOTrainer, kto_custom_collate
from nemo_aligner.models.nlp.gpt.megatron_gpt_kto_model import MegatronGPTKTOModel
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| cfg | DictConfig | Yes | Hydra configuration object containing pretrained_checkpoint.restore_from_path, model, trainer, exp_manager, and data configuration |
| cfg.pretrained_checkpoint.restore_from_path | str | Yes | Path to the pretrained NeMo GPT model checkpoint |
| cfg.model.peft.peft_scheme | str | Yes | PEFT scheme to use; set to "none" for full fine-tuning |
| cfg.model.kto.ref_policy_kl_penalty | float | Yes | Beta parameter controlling the strength of the KL penalty |
| cfg.model.kto.desirable_loss_weight | float | No | Weight for desirable (chosen) samples in the loss (default: 1.0) |
| cfg.model.kto.undesirable_loss_weight | float | No | Weight for undesirable (rejected) samples in the loss (default: 1.0) |
| cfg.model.data.data_prefix | str | Yes | Path prefix for the KTO training data |
| cfg.model.data.splits_string | str | Yes | Train/val/test split specification |
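To make the roles of ref_policy_kl_penalty and the two loss weights concrete, here is a per-sample sketch that follows the published KTO formulation. It is an illustration under that assumption, not NeMo Aligner's implementation; the function name kto_sample_loss and its arguments are hypothetical.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def kto_sample_loss(
    policy_logp,
    ref_logp,
    kl_estimate,
    desirable,
    beta=0.1,
    desirable_weight=1.0,
    undesirable_weight=1.0,
):
    """Per-sample KTO-style loss (sketch based on the published formulation)."""
    # Implied reward: log-probability ratio between policy and reference.
    reward = policy_logp - ref_logp
    if desirable:
        # Desirable samples are pushed to have reward above the KL baseline.
        return desirable_weight * (1.0 - sigmoid(beta * (reward - kl_estimate)))
    # Undesirable samples are pushed to have reward below the KL baseline.
    return undesirable_weight * (1.0 - sigmoid(beta * (kl_estimate - reward)))


# A desirable sample whose reward exceeds the baseline gets a lower loss
# than one sitting exactly on the baseline.
on_baseline = kto_sample_loss(-1.0, -1.0, 0.0, desirable=True)
above_baseline = kto_sample_loss(-1.0, -3.0, 0.0, desirable=True)
print(on_baseline, above_baseline)
```

In this form, beta scales how sharply the loss reacts to the gap between the reward and the KL baseline, while the two weights rebalance desirable against undesirable samples.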
Outputs
| Name | Type | Description |
|---|---|---|
| None (side effects) | N/A | Trains the model in-place, saves checkpoints, and logs metrics. No return value. |
Usage Examples
# Command-line invocation:
# python examples/nlp/gpt/train_gpt_kto.py \
# pretrained_checkpoint.restore_from_path=/path/to/model.nemo \
# model.kto.ref_policy_kl_penalty=0.1 \
# model.kto.desirable_loss_weight=1.0 \
# model.kto.undesirable_loss_weight=1.0 \
# model.peft.peft_scheme=none \
# model.data.data_prefix=/path/to/kto_data
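When the run is launched from another Python program, the same overrides can be assembled programmatically. The helper below is a hypothetical convenience, not part of NeMo Aligner; it only formats key=value arguments in the way Hydra accepts them on the command line.

```python
import shlex


def build_kto_command(overrides, script="examples/nlp/gpt/train_gpt_kto.py"):
    """Hypothetical helper: turn a dict of overrides into an argv list."""
    args = ["python", script]
    # Hydra reads each override as a single key=value argument.
    args += [f"{key}={value}" for key, value in overrides.items()]
    return args


overrides = {
    "pretrained_checkpoint.restore_from_path": "/path/to/model.nemo",
    "model.kto.ref_policy_kl_penalty": 0.1,
    "model.peft.peft_scheme": "none",
    "model.data.data_prefix": "/path/to/kto_data",
}
command = build_kto_command(overrides)

# Print a shell-safe version; pass `command` to subprocess.run to launch it.
print(shlex.join(command))
```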