Implementation:NVIDIA NeMo Aligner Train GPT SPIN

Knowledge Sources
Domains NLP, Alignment
Last Updated 2026-02-08 00:00 GMT

Overview

train_gpt_spin.py is the entry point script for launching SPIN (Self-Play Fine-Tuning) training of a GPT model using NeMo Aligner.

Description

This script wires together all components required for SPIN training:

  1. Configuration loading: Uses Hydra (@hydra_runner) with config path conf and config name gpt_spin. Loads the model config stored in the pretrained checkpoint and applies the user's cfg.model overrides on top of it. cfg.model.encoder_seq_length is also set from the checkpoint.
  2. Trainer and experiment setup: Creates a PyTorch Lightning trainer via resolve_and_create_trainer(cfg, "spin") and initializes experiment management.
  3. Model loading: Loads a pretrained MegatronGPTSPINModel from a NeMo checkpoint. If the model does not already have a reference policy state dict (i.e., fresh start), it initializes one from the current model weights via retrieve_model_state_dict_in_cpu().
  4. Data preparation: Builds SFT train/validation datasets using build_sft_dataset() with chat format support, and uses spin_custom_collate as the collate function for the training data. Optionally limits the training set to max_steps * global_batch_size samples, the number consumed by max_steps steps at the global batch size.
  5. Optimizer and scheduler: Extracts the optimizer and learning-rate scheduler configured on the PyTorch Lightning (PTL) model.
  6. SPINTrainer instantiation: Creates the SPINTrainer with all dependencies, optionally restores trainer state from a checkpoint, and calls spin_trainer.fit().
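Three of the setup steps above (config override, reference-policy initialization, and training-set sizing) can be sketched in plain Python. This is a minimal illustration under stated assumptions, not NeMo Aligner code: merge_overrides, maybe_init_reference_policy, and num_train_samples are hypothetical helpers, the model state is a plain dict, copy.deepcopy stands in for retrieve_model_state_dict_in_cpu(), and a negative max_steps is assumed to mean "use the whole dataset".

```python
import copy
from typing import Any, Dict, Optional

# All helpers here are HYPOTHETICAL illustrations, not NeMo Aligner APIs.


def merge_overrides(base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively apply user overrides on top of a checkpoint's model config."""
    merged = copy.deepcopy(base)  # never mutate the checkpoint config
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_overrides(merged[key], value)
        else:
            merged[key] = value
    return merged


def maybe_init_reference_policy(state: Dict[str, Any]) -> None:
    """On a fresh start, snapshot the current weights as the reference policy."""
    if state.get("ref_policy_state_dict") is None:
        # deepcopy stands in for retrieve_model_state_dict_in_cpu()
        state["ref_policy_state_dict"] = copy.deepcopy(state["weights"])


def num_train_samples(max_steps: int, global_batch_size: int) -> Optional[int]:
    """Samples consumed by max_steps steps; None (ASSUMED for max_steps < 0) means all data."""
    if max_steps < 0:
        return None
    return max_steps * global_batch_size


# Usage with made-up values:
checkpoint_cfg = {"encoder_seq_length": 4096, "optim": {"name": "adam", "lr": 1e-5}}
user_cfg = {"encoder_seq_length": 2048, "optim": {"lr": 5e-7}}
model_cfg = merge_overrides(checkpoint_cfg, user_cfg)
# encoder_seq_length is taken from the checkpoint, as in step 1
model_cfg["encoder_seq_length"] = checkpoint_cfg["encoder_seq_length"]

state = {"weights": {"layer.weight": [0.5, -0.25]}, "ref_policy_state_dict": None}
maybe_init_reference_policy(state)
```

Snapshotting with a deep copy matters: the reference policy must stay frozen while the trainable weights change during the iteration.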

The script registers custom OmegaConf resolvers for multiply, int_div, and subtract operations used in configuration interpolation.
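In OmegaConf, such resolvers are registered with OmegaConf.register_new_resolver(name, fn) and referenced in YAML as ${multiply:${a},${b}}. The toy evaluator below does not use OmegaConf; it only mimics that syntax for flat, two-argument calls so the arithmetic is visible. It assumes int_div is floor division, and every config key in the example is hypothetical.

```python
import operator
import re
from typing import Any, Dict

# Arithmetic behind the resolver names; int_div is ASSUMED to be floor division.
RESOLVERS = {
    "multiply": operator.mul,
    "int_div": operator.floordiv,
    "subtract": operator.sub,
}

_CALL = re.compile(r"^\$\{(\w+):(.+)\}$")  # ${name:arg1,arg2}
_REF = re.compile(r"^\$\{([\w.]+)\}$")     # ${dotted.key}


def _get(cfg: Dict[str, Any], dotted: str) -> Any:
    """Follow a dotted key such as 'trainer.devices' into a nested dict."""
    node: Any = cfg
    for part in dotted.split("."):
        node = node[part]
    return node


def resolve(cfg: Dict[str, Any], value: Any) -> Any:
    """Evaluate a toy interpolation string; other values pass through unchanged."""
    if not isinstance(value, str):
        return value
    ref = _REF.match(value)
    if ref:
        return resolve(cfg, _get(cfg, ref.group(1)))
    call = _CALL.match(value)
    if call:
        name, raw_args = call.groups()
        args = [resolve(cfg, arg.strip()) for arg in raw_args.split(",")]
        return RESOLVERS[name](*(int(a) for a in args))
    return value


# Hypothetical config using the three resolvers:
cfg = {
    "trainer": {"num_nodes": 2, "devices": 8},
    "model": {"global_batch_size": 64, "micro_batch_size": 2},
    "world_size": "${multiply:${trainer.num_nodes},${trainer.devices}}",
    "micro_batches_per_step": "${int_div:${model.global_batch_size},${model.micro_batch_size}}",
    "world_size_minus_one": "${subtract:${world_size},1}",
}
```

Deriving values such as batch sizes this way keeps related settings consistent when a single override changes one of them.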

Usage

Run this script via the command line with Hydra configuration overrides to launch SPIN training. It requires a pretrained NeMo GPT checkpoint and SFT-format chat data.

Code Reference

Source Location

Signature

@hydra_runner(config_path="conf", config_name="gpt_spin")
def main(cfg) -> None:

Import

from nemo.core.config import hydra_runner
from nemo_aligner.algorithms.spin import SPINTrainer, spin_custom_collate
from nemo_aligner.models.nlp.gpt.megatron_gpt_spin_model import MegatronGPTSPINModel

I/O Contract

Inputs

Name | Type | Required | Description
cfg | DictConfig | Yes | Hydra configuration object containing pretrained_checkpoint.restore_from_path, model, trainer, exp_manager, and data configuration
cfg.pretrained_checkpoint.restore_from_path | str | Yes | Path to the pretrained NeMo GPT model checkpoint
cfg.model.data.chat | bool | Yes | Whether the data is in chat format
cfg.model.data.chat_prompt_tokens | dict | No | Special tokens for chat prompt formatting
cfg.model.data.train_ds | DictConfig | Yes | Training dataset configuration
cfg.model.data.validation_ds | DictConfig | Yes | Validation dataset configuration
cfg.trainer.spin.max_iterations | int | Yes | Number of SPIN iterations (outer loop); after each iteration, the reference policy is updated
cfg.trainer.spin.max_epochs | int | Yes | Number of epochs per iteration
cfg.model.spin.ref_policy_kl_penalty | float or list | Yes | KL penalty scalar or per-iteration schedule
cfg.model.spin.rollout_micro_batch_size | int | Yes | Micro batch size for generation rollouts
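How max_iterations, max_epochs, and ref_policy_kl_penalty interact can be sketched as follows. This assumes the DPO-style logistic objective described in the SPIN paper, with ref_policy_kl_penalty acting as the coefficient beta and a list penalty indexed by 0-based iteration. kl_penalty_for_iteration, spin_loss, and spin_schedule are hypothetical names, and MegatronGPTSPINModel may differ in details such as how log-probabilities are reduced over tokens and micro batches.

```python
import math
from typing import List, Sequence, Tuple, Union

Penalty = Union[float, Sequence[float]]


def kl_penalty_for_iteration(penalty: Penalty, iteration: int) -> float:
    """Beta for a 0-based SPIN iteration: a scalar, or one entry per iteration (ASSUMED)."""
    if isinstance(penalty, (int, float)):
        return float(penalty)
    return float(penalty[iteration])


def spin_loss(policy_real: float, ref_real: float,
              policy_gen: float, ref_gen: float, beta: float) -> float:
    """Logistic loss that prefers the real (SFT) response over the self-generated one.

    Each argument is the summed log-probability of a response under the current
    policy or under the frozen reference policy.
    """
    margin = (policy_real - ref_real) - (policy_gen - ref_gen)
    z = beta * margin
    # -log(sigmoid(z)) computed stably as softplus(-z)
    return max(-z, 0.0) + math.log1p(math.exp(-abs(z)))


def spin_schedule(max_iterations: int, max_epochs: int,
                  penalty: Penalty) -> List[Tuple[int, int, float]]:
    """(iteration, epoch, beta) in visiting order. Between iterations the reference
    policy would be refreshed from the current weights (not modeled here)."""
    return [
        (it, ep, kl_penalty_for_iteration(penalty, it))
        for it in range(max_iterations)
        for ep in range(max_epochs)
    ]
```

With a scalar penalty every iteration uses the same beta; a list lets the penalty change each time the reference policy is refreshed.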

Outputs

Name | Type | Description
None (side effects) | N/A | Trains the model in-place, saves checkpoints, and logs metrics. No return value.

Usage Examples

# Command-line invocation:
python examples/nlp/gpt/train_gpt_spin.py \
    pretrained_checkpoint.restore_from_path=/path/to/model.nemo \
    trainer.spin.max_iterations=3 \
    trainer.spin.max_epochs=2 \
    model.spin.ref_policy_kl_penalty=0.1 \
    model.data.chat=True
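When launching runs from a script, the Hydra overrides can be assembled in Python. The sketch below uses only the standard library; build_command and format_value are hypothetical helpers. The list-valued ref_policy_kl_penalty follows the "float or list" type in the I/O table, but whether a given config accepts a list override there is an assumption to verify. shlex.join quotes the bracketed value so the shell passes it through intact.

```python
import shlex
from typing import Any, Dict, List

SCRIPT = "examples/nlp/gpt/train_gpt_spin.py"


def format_value(value: Any) -> str:
    """Render a Python value in Hydra override syntax (bool, list, or scalar)."""
    if isinstance(value, bool):
        return "True" if value else "False"
    if isinstance(value, (list, tuple)):
        return "[" + ",".join(format_value(v) for v in value) + "]"
    return str(value)


def build_command(overrides: Dict[str, Any]) -> str:
    """HYPOTHETICAL helper: a shell-safe command line for the SPIN entry point."""
    argv: List[str] = ["python", SCRIPT]
    argv += [f"{key}={format_value(value)}" for key, value in overrides.items()]
    return shlex.join(argv)


cmd = build_command({
    "pretrained_checkpoint.restore_from_path": "/path/to/model.nemo",
    "trainer.spin.max_iterations": 3,
    "model.spin.ref_policy_kl_penalty": [0.1, 0.2, 0.3],  # one beta per iteration (ASSUMED)
    "model.data.chat": True,
})
print(cmd)
```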
