Implementation:NVIDIA NeMo Aligner MegatronGPT KTO Model
| Knowledge Sources | |
|---|---|
| Domains | NLP, Alignment |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
MegatronGPTKTOModel is the model class for KTO (Kahneman-Tversky Optimization) training. It extends MegatronGPTModel with a prospect-theory-inspired loss function that learns from binary (desirable/undesirable) feedback. The loss uses a KL-divergence estimate, computed on mismatched prompt-response pairs, as its reference point.
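Written out per sample, the loss can be sketched as follows. The notation is introduced here for illustration. β corresponds to `kto.ref_policy_kl_penalty`, and λ_D and λ_U correspond to `kto.desirable_loss_weight` and `kto.undesirable_loss_weight`. log π(y | x) is the summed log-probability of the response tokens, or the averaged one when `preference_average_log_probs` is set. 𝓑_KL is the set of mismatched prompt-response pairs in the batch. Averaging the KL term over the batch is an assumption, and the exact reduction in the source may differ.

```latex
r_\theta(x, y) = \log \pi_\theta(y \mid x) - \log \pi_{\mathrm{ref}}(y \mid x)

z_{\mathrm{ref}} = \max\!\Big(0,\; \frac{1}{|\mathcal{B}_{\mathrm{KL}}|} \sum_{(x, y') \in \mathcal{B}_{\mathrm{KL}}} r_\theta(x, y')\Big)

\mathcal{L}(x, y) =
\begin{cases}
  \lambda_D \big(1 - \sigma(\beta\,(r_\theta(x, y) - z_{\mathrm{ref}}))\big), & y \text{ desirable} \\
  \lambda_U \big(1 - \sigma(\beta\,(z_{\mathrm{ref}} - r_\theta(x, y)))\big), & y \text{ undesirable}
\end{cases}
```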
Description
The MegatronGPTKTOModel class combines three base classes:
- NLPAdapterModelMixin -- enables PEFT (Parameter-Efficient Fine-Tuning) adapter support.
- MegatronGPTModel -- provides the core Megatron-LM GPT model with tensor/pipeline parallelism.
- SupervisedInterface -- defines the interface contract for supervised alignment training.
Key capabilities include:
- KTO loss function: The `loss_func()` method implements the KTO loss. It splits rewards into chosen, rejected, and KL components based on preference labels. The KL divergence estimate (clamped to non-negative) serves as the reference point. Desirable losses use `1 - sigmoid(beta * (chosen_reward - KL_ref))` and undesirable losses use `1 - sigmoid(beta * (KL_ref - reject_reward))`. Asymmetric weighting is applied via `desirable_loss_weight` and `undesirable_loss_weight`.
- Forward pass: The `get_forward_output_and_loss_func()` method concatenates the original samples with the KL estimation samples, which doubles the batch. It then runs a single forward pass, computes per-token log-probabilities, and applies the KTO loss. It also supports a `logprobs_only` mode for computing reference policy log-probabilities.
- Reference policy log-probabilities: The `get_ref_policy_logprobs()` method computes log-probabilities under the reference policy. For full fine-tuning it swaps in the reference weights (`cpu_weight_swap`). For PEFT it disables the adapters (`adapter_control`).
- Configurable loss components: A preference loss can be combined with an optional SFT loss via `preference_loss_weight` and `sft_loss_weight`. Log-probabilities can be averaged over tokens via `preference_average_log_probs` and `sft_average_log_probs`.
- Split output tensor: The `split_output_tensor()` method splits the outputs three ways (chosen, rejected, KL) based on preference labels. DPO, by contrast, always splits the batch in half.
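The batch doubling and three-way split can be sketched on precomputed per-sample rewards. This is an illustration only, and the function names are hypothetical. The sketch assumes that the KL-estimation samples occupy the second half of the doubled batch, that label 1 marks desirable responses, and that label 0 marks undesirable ones.

```python
from typing import List, Sequence, Tuple


def concat_with_kl_samples(samples: List[str], kl_samples: List[str]) -> List[str]:
    """Double the batch: original samples first, KL-estimation samples second (assumed order)."""
    if len(samples) != len(kl_samples):
        raise ValueError("expected one KL sample per original sample")
    return samples + kl_samples


def split_three_way(
    rewards: Sequence[float], labels: Sequence[int]
) -> Tuple[List[float], List[float], List[float]]:
    """Split per-sample outputs of the doubled batch into (chosen, rejected, kl).

    The first len(labels) entries belong to the original samples and are routed by
    their binary label (1 = desirable, 0 = undesirable); the rest are KL samples.
    """
    n = len(labels)
    if len(rewards) != 2 * n:
        raise ValueError("rewards must cover the doubled batch")
    original, kl = list(rewards[:n]), list(rewards[n:])
    chosen = [r for r, lab in zip(original, labels) if lab == 1]
    rejected = [r for r, lab in zip(original, labels) if lab == 0]
    return chosen, rejected, kl
```

For example, `split_three_way([0.5, -0.2, 0.1, 0.05, -0.1, 0.0], [1, 0, 1])` routes two rewards to the chosen group, one to the rejected group, and the last three to the KL group. A half-and-half split always yields equal halves, but here the chosen and rejected groups can have any sizes, including zero.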
Usage
Use MegatronGPTKTOModel as the model class when loading a pretrained model for KTO training. The training entry script loads it from a NeMo checkpoint. For full fine-tuning, the reference policy state dict is then initialized from the initial model weights. For PEFT, the reference policy is obtained through adapter control instead.
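The two ways of obtaining the reference policy can be illustrated with context managers that temporarily change the model and always restore it afterwards. Everything in this sketch is hypothetical: `ToyPolicy`, `swapped_weights`, `adapters_off`, and `ref_policy_logprob`. It does not call NeMo Aligner's `cpu_weight_swap` or `adapter_control`, whose signatures are not shown on this page. The toy "log-probability" is just a sum of weights.

```python
from contextlib import contextmanager
from typing import Dict, Iterator, Optional


class ToyPolicy:
    """Hypothetical stand-in for a policy: base weights plus an optional adapter term."""

    def __init__(self, weights: Dict[str, float], adapter_delta: float = 0.0) -> None:
        self.weights = dict(weights)
        self.adapter_delta = adapter_delta
        self.adapters_enabled = True

    def logprob(self) -> float:
        # Toy score standing in for a log-probability.
        base = sum(self.weights.values())
        return base + (self.adapter_delta if self.adapters_enabled else 0.0)


@contextmanager
def swapped_weights(model: ToyPolicy, ref_weights: Dict[str, float]) -> Iterator[None]:
    """Full fine-tuning: load reference weights, then restore the trained ones."""
    saved = model.weights
    model.weights = dict(ref_weights)
    try:
        yield
    finally:
        model.weights = saved


@contextmanager
def adapters_off(model: ToyPolicy) -> Iterator[None]:
    """PEFT: disable adapters so the frozen base model acts as the reference."""
    previous = model.adapters_enabled
    model.adapters_enabled = False
    try:
        yield
    finally:
        model.adapters_enabled = previous


def ref_policy_logprob(model: ToyPolicy, ref_weights: Optional[Dict[str, float]] = None) -> float:
    # A stored reference state dict implies full fine-tuning; otherwise assume PEFT.
    if ref_weights is not None:
        with swapped_weights(model, ref_weights):
            return model.logprob()
    with adapters_off(model):
        return model.logprob()
```

The `try`/`finally` in each context manager is the key design choice. It ensures that the trained weights or adapters come back even if the reference forward pass raises.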
Code Reference
Source Location
- Repository: NVIDIA_NeMo_Aligner
- File: nemo_aligner/models/nlp/gpt/megatron_gpt_kto_model.py
- Lines: 46-420
Signature
```python
class MegatronGPTKTOModel(NLPAdapterModelMixin, MegatronGPTModel, SupervisedInterface):
    """
    Megatron GPT KTO Model Training.
    """

    def __init__(self, cfg: DictConfig, trainer: Trainer):
        ...
```
Import
```python
from nemo_aligner.models.nlp.gpt.megatron_gpt_kto_model import MegatronGPTKTOModel
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| cfg | DictConfig | Yes | Model configuration including kto.ref_policy_kl_penalty (beta), kto.preference_average_log_probs, kto.sft_average_log_probs, kto.preference_loss_weight, kto.sft_loss_weight, kto.desirable_loss_weight, kto.undesirable_loss_weight, kto.log_prob_forward_micro_batch_size |
| trainer | pytorch_lightning.Trainer | Yes | PyTorch Lightning trainer instance |
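Here is a minimal sketch of the `kto` section of the model config, written as a plain Python dict. Only the key names come from the table above. Every value is a hypothetical placeholder, not a recommended default. In a real run these keys would be read from the `DictConfig` passed as `cfg`.

```python
# Hypothetical placeholder values; only the key names come from the cfg row above.
kto_cfg = {
    "ref_policy_kl_penalty": 0.1,            # beta in the KTO loss
    "preference_average_log_probs": False,   # sum (False) or average (True) token log-probs
    "sft_average_log_probs": False,
    "preference_loss_weight": 1.0,
    "sft_loss_weight": 0.0,                  # weight on the optional SFT term
    "desirable_loss_weight": 1.0,            # lambda_D
    "undesirable_loss_weight": 1.0,          # lambda_U
    "log_prob_forward_micro_batch_size": 1,
}

# The keys are nested under `kto` in the model configuration.
model_cfg = {"kto": kto_cfg}
```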
Key Method: get_loss_and_metrics()
| Name | Type | Description |
|---|---|---|
| loss_mean | float | Mean KTO loss value |
| metrics | dict | Dictionary containing loss, kl (KL divergence), rewards_chosen_mean, rewards_rejected_mean, rewards_all_mean, rewards_all_std, rewards_margin |
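The metric names above can be computed from the split rewards as sketched below. Two definitions in this sketch are assumptions, not taken from the source. `rewards_margin` is taken as the mean chosen reward minus the mean rejected reward. `rewards_all_std` uses the population standard deviation. The helper names are hypothetical.

```python
import math
import statistics
from typing import Dict, List


def _mean(values: List[float]) -> float:
    # A batch may contain no desirable (or no undesirable) samples.
    return statistics.fmean(values) if values else math.nan


def kto_metrics(loss: float, kl: float, chosen: List[float], rejected: List[float]) -> Dict[str, float]:
    rewards_all = chosen + rejected
    chosen_mean = _mean(chosen)
    rejected_mean = _mean(rejected)
    return {
        "loss": loss,
        "kl": kl,
        "rewards_chosen_mean": chosen_mean,
        "rewards_rejected_mean": rejected_mean,
        "rewards_all_mean": _mean(rewards_all),
        # Population std (assumption); the real implementation may use another estimator.
        "rewards_all_std": statistics.pstdev(rewards_all) if rewards_all else math.nan,
        # Assumed definition: gap between the mean chosen and mean rejected rewards.
        "rewards_margin": chosen_mean - rejected_mean,
    }
```

Because KTO batches need not contain both label types, the sketch returns NaN for an empty group rather than raising an error.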
Key Method: loss_func()
| Name | Type | Description |
|---|---|---|
| loss | Tensor | Scalar KTO loss |
| kl_divergence | Tensor | Estimated KL divergence from mismatched samples (clamped to non-negative) |
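Here is a plain-Python sketch of the loss computation described on this page. It works on lists of implicit rewards (r = log pi_theta - log pi_ref per sample) rather than tensors. The KL rewards come from prompts paired with responses from other examples. The function name is hypothetical. Two further choices are also assumptions for illustration: the plain batch mean used for the KL estimate, and the final mean over all weighted per-sample losses.

```python
import math
from typing import List, Tuple


def sigmoid(x: float) -> float:
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def kto_loss(
    chosen: List[float],
    rejected: List[float],
    kl_rewards: List[float],
    beta: float,
    desirable_weight: float = 1.0,
    undesirable_weight: float = 1.0,
) -> Tuple[float, float]:
    """Return (loss, kl) from implicit rewards r = log pi_theta - log pi_ref."""
    # Reference point: mean reward on mismatched pairs, clamped to be non-negative.
    kl = max(0.0, sum(kl_rewards) / len(kl_rewards))
    chosen_losses = [desirable_weight * (1.0 - sigmoid(beta * (r - kl))) for r in chosen]
    rejected_losses = [undesirable_weight * (1.0 - sigmoid(beta * (kl - r))) for r in rejected]
    losses = chosen_losses + rejected_losses
    # Assumed reduction: plain mean over all weighted per-sample losses.
    return sum(losses) / len(losses), kl
```

Desirable samples are pushed above the reference point and undesirable ones below it. A negative KL estimate is clamped to zero, so the reference point never rewards the policy for moving below the reference model on average.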
Outputs
| Name | Type | Description |
|---|---|---|
| loss_mean | float | Scalar loss value for the training step |
| metrics | dict | Training metrics dictionary |
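The scalar training loss has two configurable pieces, sketched below. The first is the reduction of per-token log-probabilities over response tokens: summed, or averaged when `preference_average_log_probs` or `sft_average_log_probs` is set. The second is the combination of the preference and SFT losses. Both helpers are hypothetical. Treating the combination as a weighted sum is an assumption, though it is consistent with the `preference_loss_weight` and `sft_loss_weight` parameters.

```python
from typing import List


def reduce_logprobs(token_logprobs: List[float], mask: List[int], average: bool) -> float:
    """Sum (or average) per-token log-probs over the response tokens selected by `mask`."""
    selected = [lp for lp, m in zip(token_logprobs, mask) if m]
    if not selected:
        raise ValueError("mask selects no tokens")
    total = sum(selected)
    return total / len(selected) if average else total


def combined_loss(
    preference_loss: float,
    sft_loss: float,
    preference_loss_weight: float,
    sft_loss_weight: float,
) -> float:
    # Assumed weighted sum of the preference (KTO) loss and the optional SFT loss.
    return preference_loss_weight * preference_loss + sft_loss_weight * sft_loss
```

Averaging makes a sample's reward independent of response length, while summing weights longer responses more heavily.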
Usage Examples
```python
# Excerpt from a KTO training entry script; `cfg` (the Hydra config) and `trainer`
# are assumed to have been created earlier in the script.
from nemo_aligner.models.nlp.gpt.megatron_gpt_kto_model import MegatronGPTKTOModel
from nemo_aligner.utils.train_script_utils import init_peft
from nemo_aligner.utils.utils import load_from_nemo, retrieve_model_state_dict_in_cpu

ptl_model = load_from_nemo(
    MegatronGPTKTOModel,
    cfg.model,
    trainer,
    strict=True,
    load_base_model_only=False,
    restore_path=cfg.pretrained_checkpoint.restore_from_path,
)

# Initialize PEFT adapters (if a PEFT scheme is configured)
init_peft(ptl_model, cfg.model)

# For full fine-tuning, initialize the reference policy from the current weights.
# With PEFT, the reference policy is obtained by disabling the adapters instead.
if cfg.model.peft.peft_scheme == "none":
    ref_policy_state_dict = retrieve_model_state_dict_in_cpu(
        ptl_model, megatron_amp_O2=cfg.model.get("megatron_amp_O2", False)
    )
    ptl_model.ref_policy_state_dict = ref_policy_state_dict
```