

Implementation:NVIDIA NeMo Aligner MegatronGPT SPIN Model

From Leeroopedia


Knowledge Sources
Domains: NLP, Alignment
Last Updated: 2026-02-08 00:00 GMT

Overview

MegatronGPTSPINModel is the model class for SPIN (Self-Play Fine-Tuning) training. It extends MegatronGPTModel with a DPO-style loss function, reference policy management, and self-play-specific forward passes.

Description

The MegatronGPTSPINModel class extends MegatronGPTModel and implements the SupervisedInterface. It is adapted from the paper "Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models" (Chen et al., 2024, https://arxiv.org/abs/2401.01335).

Key capabilities include:

  • DPO-style loss: The loss_func() method computes -log sigmoid(kl_penalty * (chosen_rewards - reject_rewards)). Each reward is the masked sum, over response tokens, of the per-token log-probability difference between the policy and the reference. "Chosen" corresponds to ground-truth responses and "rejected" corresponds to model-generated responses.
  • KL penalty scheduling: The ref_policy_kl_penalty parameter can be a scalar or a list (one value per iteration). The set_KL_penalty_by_iteration() method selects the appropriate penalty for the current iteration.
  • Reference policy management: Maintains ref_policy_state_dict in CPU memory. The get_ref_policy_logprobs() method temporarily swaps in reference weights using cpu_weight_swap() to compute log-probabilities under the reference policy for both actual and generated responses.
  • Forward pass: The get_forward_output_and_loss_func() method handles the concatenation of actual and generated responses into a single batch (doubling the effective micro-batch size), forwards through the model, and computes per-token log-probabilities and the SPIN loss.
  • Validation: The get_loss_and_metrics_vanilla_sft() method provides an efficient validation path using standard SFT loss rather than SPIN loss, avoiding the need for costly generation during validation.
  • Distributed checkpointing: Implements custom sharded_state_dict() and on_load_checkpoint() methods to save and restore both the main model weights and the reference policy weights using Megatron's distributed checkpointing.
  • Adam state offloading: Supports offloading distributed Adam optimizer states to CPU during inference via offload_adam_states() and onload_adam_states().
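
The DPO-style loss in the first bullet can be sketched in plain Python. This is a minimal illustration, not the class's actual code: plain lists stand in for tensors, the helper names (masked_sum, log_sigmoid, spin_loss) are hypothetical, and it assumes the first half of the batch holds the actual responses and the second half the generated ones.

```python
import math


def masked_sum(values, mask):
    """Sum the entries of `values` whose mask entry is 1 (response tokens)."""
    return sum(v * m for v, m in zip(values, mask))


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def spin_loss(pi_logprobs, ref_logprobs, masks, kl_penalty):
    """Per-pair SPIN loss for a batch of 2*B rows.

    Assumption: rows [0, B) are actual (chosen) responses and rows [B, 2*B)
    are generated (rejected) responses.
    """
    # Reward per row: masked sum of (policy - reference) per-token log-probs.
    rewards = [
        masked_sum([p - r for p, r in zip(pi_row, ref_row)], mask)
        for pi_row, ref_row, mask in zip(pi_logprobs, ref_logprobs, masks)
    ]
    half = len(rewards) // 2
    chosen, rejected = rewards[:half], rewards[half:]
    # -log sigmoid(kl_penalty * (chosen - rejected)) for each pair.
    losses = [-log_sigmoid(kl_penalty * (c - r)) for c, r in zip(chosen, rejected)]
    return losses, chosen, rejected
```

When the policy and the reference agree on every token, both rewards are zero and each pair's loss equals log 2, which is a convenient sanity check at the start of training.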

Usage

Use MegatronGPTSPINModel as the model class when loading a pretrained model for SPIN training. The training entry script restores it from a NeMo checkpoint and initializes the reference policy state dict from those initial model weights.

Code Reference

Source Location

  • Repository: NVIDIA_NeMo_Aligner
  • File: nemo_aligner/models/nlp/gpt/megatron_gpt_spin_model.py
  • Lines: 51-639

Signature

class MegatronGPTSPINModel(MegatronGPTModel, SupervisedInterface):
    """
    Megatron GPT SPIN Model Training
    Adapted from the paper Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models
    (Chen, et al, 2024)
    https://arxiv.org/abs/2401.01335
    """

    def __init__(self, cfg: DictConfig, trainer: Trainer):

Import

from nemo_aligner.models.nlp.gpt.megatron_gpt_spin_model import MegatronGPTSPINModel

I/O Contract

Inputs

Name | Type | Required | Description
cfg | DictConfig | Yes | Model configuration including spin.ref_policy_kl_penalty (scalar or list), spin.offload_adam_states, spin.length_params, spin.sampling_params, spin.rollout_micro_batch_size, spin.log_prob_forward_micro_batch_size
trainer | pytorch_lightning.Trainer | Yes | PyTorch Lightning trainer instance
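
Because spin.ref_policy_kl_penalty may be a scalar or a per-iteration list, the selection step can be sketched as follows. The function select_kl_penalty is a hypothetical stand-in for what set_KL_penalty_by_iteration() is described as doing; the bounds check and the error types are assumptions of this sketch, not documented behavior.

```python
from collections.abc import Sequence
from numbers import Real


def select_kl_penalty(ref_policy_kl_penalty, iteration):
    """Return the KL penalty to use for the given SPIN iteration.

    A scalar applies to every iteration; a list supplies one value per
    iteration and is indexed by `iteration`.
    """
    if isinstance(ref_policy_kl_penalty, Real):
        return float(ref_policy_kl_penalty)
    if isinstance(ref_policy_kl_penalty, Sequence) and not isinstance(
        ref_policy_kl_penalty, str
    ):
        # Assumption: fail loudly instead of wrapping around or clamping.
        if not 0 <= iteration < len(ref_policy_kl_penalty):
            raise IndexError(
                f"no KL penalty configured for iteration {iteration}"
            )
        return float(ref_policy_kl_penalty[iteration])
    raise TypeError("ref_policy_kl_penalty must be a number or a list of numbers")
```

For example, select_kl_penalty([0.1, 0.2], 1) picks the second entry, while select_kl_penalty(0.1, 1) returns the same scalar for any iteration.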

Key Method: get_loss_and_metrics()

Name | Type | Description
loss_mean | float | Mean SPIN loss value
metrics | dict | Dictionary containing loss, acc (accuracy of choosing actual over generated), rewards_actual_mean, rewards_generated_mean, rewards_all_mean, rewards_all_std
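
As a rough illustration of how the listed metrics relate to the per-example rewards, the sketch below builds a dictionary with the same keys from plain Python lists. The function spin_metrics is hypothetical, and the use of the sample standard deviation for rewards_all_std is an assumption; the actual method may reduce across data-parallel ranks and compute these values differently.

```python
import statistics


def spin_metrics(losses, rewards_actual, rewards_generated):
    """Summarize per-pair losses and rewards with the metric names above."""
    all_rewards = list(rewards_actual) + list(rewards_generated)
    # acc: fraction of pairs where the actual response out-scores the generated one.
    wins = [1.0 if a > g else 0.0 for a, g in zip(rewards_actual, rewards_generated)]
    return {
        "loss": statistics.fmean(losses),
        "acc": statistics.fmean(wins),
        "rewards_actual_mean": statistics.fmean(rewards_actual),
        "rewards_generated_mean": statistics.fmean(rewards_generated),
        "rewards_all_mean": statistics.fmean(all_rewards),
        # Assumption: sample (n - 1) standard deviation over all rewards.
        "rewards_all_std": statistics.stdev(all_rewards),
    }
```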

Key Method: get_ref_policy_logprobs()

Name | Type | Description
ref_log_probs | Tensor | Per-token log-probabilities under the reference policy for concatenated actual and generated responses, shape [2*B, seq_len-1]
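
The shape [2*B, seq_len-1] follows from two steps: the actual and generated batches are concatenated along the batch dimension, and each position's logits are scored against the next token, which drops one position. The sketch below illustrates this with nested lists in place of tensors; the helper names are hypothetical and the real method runs a model forward pass rather than taking logits as input.

```python
import math


def log_softmax(row):
    """Stable log-softmax over one vocabulary-sized row of logits."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def per_token_logprobs(logits, tokens):
    """Log-probability of tokens[t + 1] under logits[t], for t in [0, seq_len - 1).

    logits: seq_len rows of vocab-sized lists; tokens: seq_len token ids.
    Returns seq_len - 1 values.
    """
    return [log_softmax(logits[t])[tokens[t + 1]] for t in range(len(tokens) - 1)]


def concat_logprobs(actual, generated):
    """actual/generated: lists of B (logits, tokens) pairs.

    Returns 2*B rows: actual responses first, then generated ones.
    """
    return [per_token_logprobs(logits, tokens) for logits, tokens in actual + generated]
```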

Outputs

Name | Type | Description
loss_mean | float | Scalar loss value for the training step
metrics | dict | Training metrics dictionary

Usage Examples

from nemo_aligner.models.nlp.gpt.megatron_gpt_spin_model import MegatronGPTSPINModel
from nemo_aligner.utils.utils import load_from_nemo, retrieve_model_state_dict_in_cpu

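# Assumes `cfg` (the training config) and `trainer` (a PyTorch Lightning Trainer)
# were already built earlier in the entry script.
# Restore the SPIN model from the pretrained NeMo checkpoint
ptl_model = load_from_nemo(
    MegatronGPTSPINModel,
    cfg.model,
    trainer,
    strict=True,
    restore_path=cfg.pretrained_checkpoint.restore_from_path,
)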

# Initialize reference policy from the pretrained model weights
ref_policy_state_dict = retrieve_model_state_dict_in_cpu(
    ptl_model, megatron_amp_O2=cfg.model.get("megatron_amp_O2", False)
)
ptl_model.ref_policy_state_dict = ref_policy_state_dict

# Set KL penalty for iteration 0
ptl_model.set_KL_penalty_by_iteration(0)
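
The reference policy stored here is later swapped in temporarily when reference log-probabilities are computed. The general swap-and-restore pattern can be sketched with a context manager over plain dictionaries. The name swap_weights and the dict-based "state" are purely illustrative assumptions; this is not the implementation of cpu_weight_swap(), which operates on real model parameters.

```python
from contextlib import contextmanager


@contextmanager
def swap_weights(model_state, other_state):
    """Temporarily replace entries of `model_state` with `other_state`.

    The original entries are restored on exit, even if the body raises.
    """
    saved = dict(model_state)  # keep references to the current weights
    model_state.update(other_state)
    try:
        yield model_state
    finally:
        model_state.clear()
        model_state.update(saved)


# Usage: the policy weights are back in place after the block.
policy = {"w": 1.0}
reference = {"w": 2.0}
with swap_weights(policy, reference) as active:
    weight_seen_inside = active["w"]  # reference weight is active here
```

Restoring in a finally clause matters because a failure during the reference forward pass would otherwise leave the policy running with reference weights.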
