Implementation:NVIDIA NeMo Aligner MegatronGPT SPIN Model
| Knowledge Sources | |
|---|---|
| Domains | NLP, Alignment |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
MegatronGPTSPINModel is the model class for SPIN (Self-Play Fine-Tuning) training, extending MegatronGPTModel with a DPO-style loss function, reference policy management, and self-play specific forward passes.
Description
The MegatronGPTSPINModel class extends MegatronGPTModel and implements the SupervisedInterface. It is adapted from the paper "Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models" (Chen et al., 2024, https://arxiv.org/abs/2401.01335).
Key capabilities include:
- DPO-style loss: The `loss_func()` method computes `-log sigmoid(kl_penalty * (chosen_rewards - reject_rewards))`, where each reward is the sum of masked log-probability differences between the policy and the reference policy. "Chosen" corresponds to ground-truth responses and "rejected" corresponds to model-generated responses.
- KL penalty scheduling: The `ref_policy_kl_penalty` parameter can be a scalar or a list (one value per iteration). The `set_KL_penalty_by_iteration()` method selects the appropriate penalty for the current iteration.
- Reference policy management: Maintains `ref_policy_state_dict` in CPU memory. The `get_ref_policy_logprobs()` method temporarily swaps in the reference weights using `cpu_weight_swap()` to compute log-probabilities under the reference policy for both actual and generated responses.
- Forward pass: The `get_forward_output_and_loss_func()` method concatenates actual and generated responses into a single batch (doubling the effective micro-batch size), runs the forward pass, and computes per-token log-probabilities and the SPIN loss.
- Validation: The `get_loss_and_metrics_vanilla_sft()` method provides an efficient validation path that uses the standard SFT loss instead of the SPIN loss, avoiding costly generation during validation.
- Distributed checkpointing: Implements custom `sharded_state_dict()` and `on_load_checkpoint()` methods to save and restore both the main model weights and the reference policy weights using Megatron's distributed checkpointing.
- Adam state offloading: Supports offloading distributed Adam optimizer states to CPU during inference via `offload_adam_states()` and `onload_adam_states()`.
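As a rough illustration of the DPO-style loss and the SFT validation loss described above, here is a minimal pure-Python sketch. It assumes the batch is laid out with the B actual responses first and the B generated responses second, that `mask` is 1 on response tokens and 0 elsewhere, and that the SFT loss is a masked mean of negative log-likelihoods. All function names are hypothetical, not the class's API; the real methods operate on Megatron tensors.

```python
import math
from typing import List, Sequence, Tuple


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def masked_reward(pi_logps: Sequence[float], ref_logps: Sequence[float], mask: Sequence[int]) -> float:
    """Sum of masked per-token log-prob differences (policy minus reference)."""
    return sum((p - r) * m for p, r, m in zip(pi_logps, ref_logps, mask))


def spin_loss(
    pi_logps: List[List[float]],
    ref_logps: List[List[float]],
    mask: List[List[int]],
    kl_penalty: float,
) -> Tuple[float, List[float], List[float]]:
    """Mean of -log sigmoid(kl_penalty * (chosen - rejected)) over B pairs.

    Assumed layout: rows [0, B) are actual (chosen) responses and rows
    [B, 2B) are the generated (rejected) responses for the same prompts.
    """
    rewards = [masked_reward(p, r, m) for p, r, m in zip(pi_logps, ref_logps, mask)]
    half = len(rewards) // 2
    chosen, rejected = rewards[:half], rewards[half:]
    losses = [-log_sigmoid(kl_penalty * (c - g)) for c, g in zip(chosen, rejected)]
    return sum(losses) / len(losses), chosen, rejected


def sft_loss(logps: List[List[float]], mask: List[List[int]]) -> float:
    """Masked mean negative log-likelihood, a stand-in for the SFT validation loss."""
    total = sum(-lp * m for row_lp, row_m in zip(logps, mask) for lp, m in zip(row_lp, row_m))
    count = sum(m for row in mask for m in row)
    return total / count
```

While the policy still equals the reference, every reward is 0 and the loss is log 2 (about 0.693); it drops below that as the policy assigns relatively more probability to the actual responses than to its own generations.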
Usage
Import MegatronGPTSPINModel when loading the pretrained model for SPIN training. It is loaded from a NeMo checkpoint in the training entry script, and the reference policy state dict is initialized from the initial model weights.
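The reference-policy handling (snapshot the initial weights into `ref_policy_state_dict`, then temporarily swap them in to score responses, as `get_ref_policy_logprobs()` does via `cpu_weight_swap()`) follows a swap-and-restore pattern. The sketch below shows that pattern on a plain dictionary standing in for model parameters. `swap_weights` is a hypothetical name, and the real helper moves tensors between CPU and GPU rather than rebinding dictionary entries.

```python
from contextlib import contextmanager
from typing import Dict, Iterator, List

Params = Dict[str, List[float]]


@contextmanager
def swap_weights(model_params: Params, ref_state: Params) -> Iterator[None]:
    """Temporarily install ref_state into model_params, restoring the originals on exit."""
    saved = {name: model_params[name] for name in ref_state}
    try:
        model_params.update(ref_state)
        yield
    finally:
        # Runs even if scoring raises, so the policy weights are never lost.
        model_params.update(saved)


# Toy usage: snapshot the initial weights as the reference, then "train" the policy.
policy: Params = {"w": [0.5, 0.5]}
reference: Params = {name: list(values) for name, values in policy.items()}
policy["w"] = [0.9, 0.1]  # pretend an optimizer step changed the weights

with swap_weights(policy, reference):
    ref_w0 = policy["w"][0]  # 0.5: scored under the reference weights
policy_w0 = policy["w"][0]  # 0.9: policy weights restored
```

Using a context manager keeps the restore step tied to the scoring block, so the trainable weights come back even when log-probability computation fails partway through.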
Code Reference
Source Location
- Repository: NVIDIA_NeMo_Aligner
- File: nemo_aligner/models/nlp/gpt/megatron_gpt_spin_model.py
- Lines: 51-639
Signature
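```python
from omegaconf import DictConfig
from pytorch_lightning.trainer.trainer import Trainer
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo_aligner.models.alignable_interface import SupervisedInterface

class MegatronGPTSPINModel(MegatronGPTModel, SupervisedInterface):
    """
    Megatron GPT SPIN Model Training
    Adapted from the paper Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models
    (Chen, et al, 2024)
    https://arxiv.org/abs/2401.01335
    """
    def __init__(self, cfg: DictConfig, trainer: Trainer):
        ...  # implementation elided; see megatron_gpt_spin_model.py
```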
Import
```python
from nemo_aligner.models.nlp.gpt.megatron_gpt_spin_model import MegatronGPTSPINModel
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| cfg | DictConfig | Yes | Model configuration, including spin.ref_policy_kl_penalty (scalar or list), spin.offload_adam_states, spin.length_params, spin.sampling_params, spin.rollout_micro_batch_size, and spin.log_prob_forward_micro_batch_size |
| trainer | pytorch_lightning.Trainer | Yes | PyTorch Lightning trainer instance |
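Because `spin.ref_policy_kl_penalty` may be a scalar or a per-iteration list, the selection performed by `set_KL_penalty_by_iteration()` can be sketched as below. The dictionary mirrors a few of the `cfg.spin` keys from the table with made-up values, and `kl_penalty_for_iteration` is a hypothetical helper; raising on an out-of-range iteration is an assumption about how a too-short list should be handled.

```python
from typing import Any, Dict, Sequence, Union

# Made-up values; only the key names come from the cfg.spin fields listed above.
spin_cfg: Dict[str, Any] = {
    "ref_policy_kl_penalty": [0.1, 0.2, 0.4],  # one entry per SPIN iteration
    "offload_adam_states": True,
    "rollout_micro_batch_size": 8,
    "log_prob_forward_micro_batch_size": 8,
}


def kl_penalty_for_iteration(spin_cfg: Dict[str, Any], iteration: int) -> float:
    """Return the KL penalty to use for a given SPIN iteration (0-based)."""
    penalty: Union[float, Sequence[float]] = spin_cfg["ref_policy_kl_penalty"]
    if isinstance(penalty, (int, float)):
        return float(penalty)  # a scalar applies to every iteration
    if not 0 <= iteration < len(penalty):
        raise IndexError(f"no KL penalty configured for iteration {iteration}")
    return float(penalty[iteration])
```

With the values above, iteration 1 uses 0.2, while a scalar penalty is returned unchanged for every iteration. Treating anything non-scalar as a sequence means the same logic also works for list-like config containers, not just Python lists.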
Key Method: get_loss_and_metrics()
| Return value | Type | Description |
|---|---|---|
| loss_mean | float | Mean SPIN loss value |
| metrics | dict | Dictionary containing loss, acc (fraction of pairs in which the actual response receives a higher reward than the generated one), rewards_actual_mean, rewards_generated_mean, rewards_all_mean, and rewards_all_std |
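The metrics in this table can be derived from per-pair rewards roughly as follows. This sketch assumes `chosen` and `rejected` hold one reward per actual/generated response pair and that `rewards_all_std` is the Bessel-corrected standard deviation (the default of `torch.Tensor.std`); `spin_metrics` is a hypothetical name, and any distributed averaging across ranks is omitted.

```python
import statistics
from typing import Dict, Sequence


def spin_metrics(chosen: Sequence[float], rejected: Sequence[float], loss: float) -> Dict[str, float]:
    """Summarize rewards for actual (chosen) and generated (rejected) responses."""
    all_rewards = list(chosen) + list(rejected)
    return {
        "loss": loss,
        # Fraction of pairs in which the actual response out-scores the generated one.
        "acc": sum(c > g for c, g in zip(chosen, rejected)) / len(chosen),
        "rewards_actual_mean": statistics.fmean(chosen),
        "rewards_generated_mean": statistics.fmean(rejected),
        "rewards_all_mean": statistics.fmean(all_rewards),
        # Sample (n - 1) standard deviation; needs at least two rewards.
        "rewards_all_std": statistics.stdev(all_rewards),
    }
```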
Key Method: get_ref_policy_logprobs()
| Return value | Type | Description |
|---|---|---|
| ref_log_probs | Tensor | Per-token log-probabilities under the reference policy for the concatenated actual and generated responses, shape [2*B, seq_len-1] |
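The [2*B, seq_len-1] shape comes from two steps: stacking the B actual and B generated sequences into one batch, and scoring each token with the logits from the previous position, so the first token gets no log-probability. Below is a single-process, pure-Python sketch of those two steps, assuming all sequences are padded to the same length; `log_softmax` and `per_token_logprobs` are illustrative names, not the module's functions.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Log-softmax over one vocabulary row, shifted by the max for stability."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def per_token_logprobs(logits: List[List[List[float]]], tokens: List[List[int]]) -> List[List[float]]:
    """Log-prob of token t+1 given the logits at position t.

    logits: [N, T, V], tokens: [N, T]  ->  result: [N, T-1]
    """
    out = []
    for seq_logits, seq_tokens in zip(logits, tokens):
        row = [log_softmax(seq_logits[t])[seq_tokens[t + 1]] for t in range(len(seq_tokens) - 1)]
        out.append(row)
    return out


# Hypothetical B = 1 batch: one actual and one generated sequence of length 3.
actual = [[0, 1, 2]]
generated = [[0, 2, 2]]
tokens = actual + generated  # 2*B rows, actual responses first
logits = [[[0.0, 0.0, 0.0] for _ in range(3)] for _ in tokens]  # uniform over V = 3
lp = per_token_logprobs(logits, tokens)  # 2 rows of 2 values, each -log(3)
```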
Outputs
| Name | Type | Description |
|---|---|---|
| loss_mean | float | Scalar loss value for the training step |
| metrics | dict | Training metrics dictionary |
Usage Examples
```python
from nemo_aligner.models.nlp.gpt.megatron_gpt_spin_model import MegatronGPTSPINModel
from nemo_aligner.utils.utils import load_from_nemo, retrieve_model_state_dict_in_cpu

# `cfg` (the training script's config) and `trainer` (a PyTorch Lightning Trainer)
# are assumed to have been created earlier in the training entry script.
ptl_model = load_from_nemo(
    MegatronGPTSPINModel,
    cfg.model,
    trainer,
    strict=True,
    restore_path=cfg.pretrained_checkpoint.restore_from_path,
)

# Initialize the reference policy from the pretrained model weights (kept in CPU memory)
ref_policy_state_dict = retrieve_model_state_dict_in_cpu(
    ptl_model, megatron_amp_O2=cfg.model.get("megatron_amp_O2", False)
)
ptl_model.ref_policy_state_dict = ref_policy_state_dict

# Select the KL penalty for iteration 0 (relevant when ref_policy_kl_penalty is a per-iteration list)
ptl_model.set_KL_penalty_by_iteration(0)
```