Implementation:NVIDIA NeMo Aligner MegatronGPT KTO Model

From Leeroopedia


Knowledge Sources
Domains: NLP, Alignment
Last Updated: 2026-02-08 00:00 GMT

Overview

MegatronGPTKTOModel is the model class for KTO (Kahneman-Tversky Optimization) training. It extends MegatronGPTModel with a prospect-theory-inspired loss function that learns from binary (desirable/undesirable) feedback and estimates KL divergence from mismatched prompt-response pairs.

Description

The MegatronGPTKTOModel class inherits from three base classes:

  • NLPAdapterModelMixin -- enables PEFT (Parameter-Efficient Fine-Tuning) adapter support.
  • MegatronGPTModel -- provides the core Megatron-LM GPT model with tensor/pipeline parallelism.
  • SupervisedInterface -- defines the interface contract for supervised alignment training.

Key capabilities include:

  • KTO loss function: The loss_func() method implements the KTO loss. It splits rewards into chosen, rejected, and KL components based on preference labels. The KL estimate, clamped to be non-negative, serves as the reference point KL_ref. The desirable loss is 1 - sigmoid(beta * (chosen_reward - KL_ref)) and the undesirable loss is 1 - sigmoid(beta * (KL_ref - reject_reward)); the two are weighted asymmetrically by desirable_loss_weight and undesirable_loss_weight.
  • Forward pass: The get_forward_output_and_loss_func() method concatenates original samples with KL estimation samples (doubling the batch), runs a single forward pass, computes per-token log-probabilities, and applies the KTO loss. It also supports a logprobs_only mode for computing reference policy log-probabilities.
  • Reference policy log-probabilities: The get_ref_policy_logprobs() method computes log-probabilities under the reference policy by either swapping in reference weights (cpu_weight_swap) for full fine-tuning, or disabling adapters (adapter_control) for PEFT.
  • Configurable loss components: Combines the preference loss with an optional SFT loss, weighted by preference_loss_weight and sft_loss_weight. Log-probabilities can be averaged over tokens via preference_average_log_probs and sft_average_log_probs.
  • Split output tensor: The split_output_tensor() method handles the three-way split of outputs (chosen, rejected, KL) based on preference labels, unlike DPO, which always splits the batch in half.
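The KTO loss described in the list above can be sketched in plain Python over per-sample scalars. This is a minimal illustration, not the code in megatron_gpt_kto_model.py: the function name kto_loss, its arguments, and the choice to estimate the KL reference point as the batch mean of the mismatched-pair rewards are assumptions made for this example.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def kto_loss(rewards, kl_rewards, preferences, beta=0.1,
             desirable_weight=1.0, undesirable_weight=1.0):
    """Toy KTO loss over per-sample scalars (hypothetical helper).

    rewards:     policy log-prob minus reference log-prob for each real sample
    kl_rewards:  the same quantity for mismatched prompt-response pairs
    preferences: True for desirable samples, False for undesirable ones
    """
    # Assumption: the reference point is the mean reward on mismatched pairs,
    # clamped so that it is never negative.
    kl_ref = max(sum(kl_rewards) / len(kl_rewards), 0.0)

    losses = []
    for reward, desirable in zip(rewards, preferences):
        if desirable:
            # Desirable: 1 - sigmoid(beta * (chosen_reward - KL_ref))
            losses.append(desirable_weight * (1.0 - sigmoid(beta * (reward - kl_ref))))
        else:
            # Undesirable: 1 - sigmoid(beta * (KL_ref - reject_reward))
            losses.append(undesirable_weight * (1.0 - sigmoid(beta * (kl_ref - reward))))
    return sum(losses) / len(losses), kl_ref
```

When every reward is zero, each sample contributes 1 - sigmoid(0) = 0.5 before weighting, which makes a convenient sanity check for the two branches.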

Usage

Import MegatronGPTKTOModel when loading the pretrained model for KTO training. The training entry script loads it from a NeMo checkpoint. The reference policy is then either initialized as a state dict copied from the initial model weights (for full fine-tuning) or handled via adapter control (for PEFT).
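For the full fine-tuning case, the idea of keeping a copy of the initial weights and temporarily swapping it in to score the reference policy can be illustrated with a toy context manager. This is a conceptual sketch only: swap_in_reference and the plain dictionaries standing in for model weights are hypothetical and are not part of NeMo Aligner.

```python
from contextlib import contextmanager


@contextmanager
def swap_in_reference(weights: dict, reference: dict):
    """Temporarily replace `weights` with `reference` (hypothetical toy helper)."""
    saved = dict(weights)          # remember the current policy weights
    weights.clear()
    weights.update(reference)      # the model now behaves like the reference policy
    try:
        yield weights
    finally:
        weights.clear()
        weights.update(saved)      # always restore the policy weights


policy = {"w": 1.5}                # stand-in for the weights being trained
reference = {"w": 1.0}             # stand-in for the saved initial weights

with swap_in_reference(policy, reference):
    inside = policy["w"]           # reference log-probs would be computed here
```

Restoring in a finally block means the policy weights come back even if the reference computation raises.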

Code Reference

Source Location

  • Repository: NVIDIA_NeMo_Aligner
  • File: nemo_aligner/models/nlp/gpt/megatron_gpt_kto_model.py
  • Lines: 46-420

Signature

class MegatronGPTKTOModel(NLPAdapterModelMixin, MegatronGPTModel, SupervisedInterface):
    """
    Megatron GPT KTO Model Training.
    """

    def __init__(self, cfg: DictConfig, trainer: Trainer):

Import

from nemo_aligner.models.nlp.gpt.megatron_gpt_kto_model import MegatronGPTKTOModel

I/O Contract

Inputs

Name | Type | Required | Description
cfg | DictConfig | Yes | Model configuration including kto.ref_policy_kl_penalty (beta), kto.preference_average_log_probs, kto.sft_average_log_probs, kto.preference_loss_weight, kto.sft_loss_weight, kto.desirable_loss_weight, kto.undesirable_loss_weight, kto.log_prob_forward_micro_batch_size
trainer | pytorch_lightning.Trainer | Yes | PyTorch Lightning trainer instance
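As a rough illustration of the kto configuration keys listed above, the sketch below collects them in a plain dictionary and shows one way the two loss weights could combine a preference loss with an SFT loss. All values are placeholders chosen for the example, not recommended or default settings, and combined_loss is a hypothetical helper that assumes a simple weighted sum.

```python
# Placeholder values only; the real defaults live in the training config.
kto_cfg = {
    "ref_policy_kl_penalty": 0.1,          # beta in the KTO loss
    "preference_average_log_probs": False,
    "sft_average_log_probs": False,
    "preference_loss_weight": 1.0,
    "sft_loss_weight": 0.5,
    "desirable_loss_weight": 1.0,
    "undesirable_loss_weight": 1.0,
    "log_prob_forward_micro_batch_size": 4,
}


def combined_loss(preference_loss: float, sft_loss: float, cfg: dict) -> float:
    """Weighted sum of preference and SFT losses (assumed combination rule)."""
    return (cfg["preference_loss_weight"] * preference_loss
            + cfg["sft_loss_weight"] * sft_loss)


total = combined_loss(2.0, 4.0, kto_cfg)
```

Setting sft_loss_weight to 0 in this sketch reduces the total to the weighted preference loss alone.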

Key Method: get_loss_and_metrics()

Name | Type | Description
loss_mean | float | Mean KTO loss value
metrics | dict | Dictionary containing loss, kl (KL divergence), rewards_chosen_mean, rewards_rejected_mean, rewards_all_mean, rewards_all_std, rewards_margin

Key Method: loss_func()

Name | Type | Description
loss | Tensor | Scalar KTO loss
kl_divergence | Tensor | Estimated KL divergence from mismatched samples (clamped to non-negative)
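The KL estimate relies on mismatched samples, meaning prompts paired with responses that were written for a different prompt. One simple way to build such pairs is to rotate the responses by one position within a batch, as sketched below. The rotation scheme and the mismatched_pairs helper are assumptions for illustration; the pairing actually used by the data pipeline may differ.

```python
def mismatched_pairs(prompts, responses):
    """Pair each prompt with the next sample's response (hypothetical helper)."""
    if len(prompts) != len(responses):
        raise ValueError("prompts and responses must have the same length")
    if len(prompts) < 2:
        raise ValueError("need at least two samples to build a mismatch")
    rotated = responses[1:] + responses[:1]   # shift responses left by one
    return list(zip(prompts, rotated))


prompts = ["p0", "p1", "p2"]
responses = ["r0", "r1", "r2"]

originals = list(zip(prompts, responses))
kl_samples = mismatched_pairs(prompts, responses)

# Concatenating both halves doubles the batch for a single forward pass.
doubled_batch = originals + kl_samples
```

Because every prompt receives a response it was not written for, the policy-versus-reference log-ratio on these pairs acts as a baseline rather than a preference signal.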

Outputs

Name | Type | Description
loss_mean | float | Scalar loss value for the training step
metrics | dict | Training metrics dictionary

Usage Examples

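from nemo_aligner.models.nlp.gpt.megatron_gpt_kto_model import MegatronGPTKTOModel
from nemo_aligner.utils.train_script_utils import init_peft
from nemo_aligner.utils.utils import load_from_nemo, retrieve_model_state_dict_in_cpu

# `cfg` (the training config) and `trainer` are assumed to be created earlier in the training script.
# Load the pretrained model from a NeMo checkpoint as a KTO model.
ptl_model = load_from_nemo(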
    MegatronGPTKTOModel,
    cfg.model,
    trainer,
    strict=True,
    load_base_model_only=False,
    restore_path=cfg.pretrained_checkpoint.restore_from_path,
)

# Initialize PEFT adapters
init_peft(ptl_model, cfg.model)

# For full fine-tuning, initialize reference policy from current weights
if cfg.model.peft.peft_scheme == "none":
    ref_policy_state_dict = retrieve_model_state_dict_in_cpu(
        ptl_model, megatron_amp_O2=cfg.model.get("megatron_amp_O2", False)
    )
    ptl_model.ref_policy_state_dict = ref_policy_state_dict
