Implementation:NVIDIA NeMo Aligner Train GPT SPIN
| Knowledge Sources | |
|---|---|
| Domains | NLP, Alignment |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
train_gpt_spin.py is the entry point script for launching SPIN (Self-Play Fine-Tuning) training of a GPT model using NeMo Aligner.
Description
This script wires together all components required for SPIN training:
- Configuration loading: Uses Hydra (@hydra_runner) with config path "conf" and config name "gpt_spin". Loads the model config from the pretrained checkpoint and overrides it with the values in cfg.model. Also overrides cfg.model.encoder_seq_length from the checkpoint.
- Trainer and experiment setup: Creates a PyTorch Lightning trainer via resolve_and_create_trainer(cfg, "spin") and initializes experiment management.
- Model loading: Loads a pretrained MegatronGPTSPINModel from a NeMo checkpoint. If the model does not yet have a reference policy state dict (i.e., on a fresh start), it initializes one from the current model weights via retrieve_model_state_dict_in_cpu().
- Data preparation: Builds the SFT train and validation datasets with build_sft_dataset(), which supports chat-format data, and uses spin_custom_collate as the collate function for training data. Optionally samples a subset of the data sized by max_steps * global_batch_size.
- Optimizer and scheduler: Extracts the optimizer and scheduler from the PTL model.
- SPINTrainer instantiation: Creates the SPINTrainer with all dependencies, optionally restores trainer state from a checkpoint, and calls spin_trainer.fit().
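The optional data subsampling in the data preparation step comes down to simple arithmetic on max_steps and global_batch_size. The sketch below illustrates it in plain Python. Both helpers (resolve_num_samples, sample_records) are hypothetical and not part of NeMo Aligner; the rules that a negative max_steps means "use the whole dataset" and that the data is repeated in fresh shuffled passes when the requested count exceeds the dataset size are assumptions made for illustration.

```python
import random
from typing import List, Optional, Sequence, TypeVar

T = TypeVar("T")


def resolve_num_samples(sample: bool, max_steps: int, global_batch_size: int) -> Optional[int]:
    """Return how many training records to draw, or None to use all of them.

    Hypothetical helper. Assumption: sampling is opt-in, and a negative
    max_steps (no step limit) also means "use the whole dataset".
    """
    if not sample or max_steps < 0:
        return None
    # One optimizer step consumes global_batch_size records.
    return max_steps * global_batch_size


def sample_records(records: Sequence[T], num_samples: Optional[int], seed: int = 1234) -> List[T]:
    """Deterministically shuffle and take num_samples records.

    Hypothetical helper. Assumption: if num_samples exceeds len(records),
    the data is repeated in fresh shuffled passes until enough are collected.
    """
    if num_samples is None:
        return list(records)
    if num_samples > 0 and not records:
        raise ValueError("cannot sample from an empty dataset")
    rng = random.Random(seed)
    out: List[T] = []
    while len(out) < num_samples:
        order = list(range(len(records)))
        rng.shuffle(order)
        # Take only as many as are still needed from this pass.
        out.extend(records[i] for i in order[: num_samples - len(out)])
    return out


if __name__ == "__main__":
    n = resolve_num_samples(sample=True, max_steps=3, global_batch_size=4)
    print(n, len(sample_records(list(range(5)), n)))
```

Seeding a private random.Random keeps the subset reproducible across restarts without touching the global random state.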
The script registers custom OmegaConf resolvers for multiply, int_div, and subtract operations used in configuration interpolation.
Usage
Run this script via the command line with Hydra configuration overrides to launch SPIN training. It requires a pretrained NeMo GPT checkpoint and SFT-format chat data.
Code Reference
Source Location
- Repository: NVIDIA_NeMo_Aligner
- File: examples/nlp/gpt/train_gpt_spin.py
- Lines: 1-192
Signature
@hydra_runner(config_path="conf", config_name="gpt_spin")
def main(cfg) -> None:
Import
from nemo_aligner.algorithms.spin import SPINTrainer, spin_custom_collate
from nemo_aligner.models.nlp.gpt.megatron_gpt_spin_model import MegatronGPTSPINModel
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| cfg | DictConfig | Yes | Hydra configuration object containing pretrained_checkpoint.restore_from_path, model, trainer, exp_manager, and data configuration |
| cfg.pretrained_checkpoint.restore_from_path | str | Yes | Path to the pretrained NeMo GPT model checkpoint |
| cfg.model.data.chat | bool | Yes | Whether the data is in chat format |
| cfg.model.data.chat_prompt_tokens | dict | No | Special tokens for chat prompt formatting |
| cfg.model.data.train_ds | DictConfig | Yes | Training dataset configuration |
| cfg.model.data.validation_ds | DictConfig | Yes | Validation dataset configuration |
| cfg.trainer.spin.max_iterations | int | Yes | Number of SPIN iterations (outer loop); after each, the reference policy is updated |
| cfg.trainer.spin.max_epochs | int | Yes | Number of epochs per iteration |
| cfg.model.spin.ref_policy_kl_penalty | float or list | Yes | KL penalty scalar or per-iteration schedule |
| cfg.model.spin.rollout_micro_batch_size | int | Yes | Micro batch size for generation rollouts |
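The interplay of max_iterations (outer loop), max_epochs (inner loop), the per-iteration KL penalty, and the reference-policy refresh can be illustrated with a toy simulation. Everything here is hypothetical: kl_penalty_for_iteration and run_toy_spin are not NeMo Aligner APIs, the "policy" is a single float, and the requirement that a list schedule has at least max_iterations entries is an assumption.

```python
from typing import Dict, List, Sequence, Tuple, Union

Penalty = Union[float, Sequence[float]]


def kl_penalty_for_iteration(penalty: Penalty, iteration: int) -> float:
    """Pick the KL penalty for a 0-indexed SPIN iteration.

    Assumption: a scalar applies to every iteration; a list gives one value
    per iteration and must be long enough for the requested iteration.
    """
    if isinstance(penalty, (int, float)):
        return float(penalty)
    if iteration >= len(penalty):
        raise ValueError(f"KL schedule has {len(penalty)} entries; iteration {iteration} requested")
    return float(penalty[iteration])


def run_toy_spin(max_iterations: int, max_epochs: int, penalty: Penalty) -> List[Tuple[int, int, float, float]]:
    """Record (iteration, epoch, kl_penalty, reference_weight) for each epoch."""
    policy: Dict[str, float] = {"w": 0.0}
    reference = dict(policy)  # fresh start: reference initialized from current weights
    history = []
    for iteration in range(max_iterations):  # outer loop: SPIN iterations
        kl = kl_penalty_for_iteration(penalty, iteration)
        for epoch in range(max_epochs):  # inner loop: epochs per iteration
            policy["w"] += 1.0  # stand-in for one epoch of SPIN updates
            history.append((iteration, epoch, kl, reference["w"]))
        reference = dict(policy)  # after each iteration, refresh the reference policy
    return history


if __name__ == "__main__":
    for row in run_toy_spin(max_iterations=3, max_epochs=2, penalty=[0.1, 0.05, 0.01]):
        print(row)
```

Within an iteration the reference stays frozen, so the policy is always compared against the snapshot taken at the end of the previous iteration.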
Outputs
| Name | Type | Description |
|---|---|---|
| None (side effects) | N/A | Trains the model in-place, saves checkpoints, and logs metrics. No return value. |
Usage Examples
# Command-line invocation:
python examples/nlp/gpt/train_gpt_spin.py \
    pretrained_checkpoint.restore_from_path=/path/to/model.nemo \
    trainer.spin.max_iterations=3 \
    trainer.spin.max_epochs=2 \
    model.spin.ref_policy_kl_penalty=0.1 \
    model.data.chat=True
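When launching from a Python driver, for example to pass a per-iteration KL schedule, the same overrides can be assembled as an argument list. This is a sketch with hypothetical helpers (hydra_override, build_spin_command). It relies only on Hydra's key=value and key=[a,b,c] override syntax, and the schedule values are illustrative.

```python
import shlex
from typing import Any, Dict, List


def hydra_override(key: str, value: Any) -> str:
    """Format one Hydra command-line override as key=value (hypothetical helper)."""
    if isinstance(value, bool):
        rendered = "True" if value else "False"
    elif isinstance(value, (list, tuple)):
        # Hydra list syntax: key=[a,b,c]
        rendered = "[" + ",".join(str(v) for v in value) + "]"
    else:
        rendered = str(value)
    return f"{key}={rendered}"


def build_spin_command(overrides: Dict[str, Any]) -> List[str]:
    """Build an argv list for train_gpt_spin.py (hypothetical helper)."""
    argv = ["python", "examples/nlp/gpt/train_gpt_spin.py"]
    argv.extend(hydra_override(k, v) for k, v in overrides.items())
    return argv


cmd = build_spin_command(
    {
        "pretrained_checkpoint.restore_from_path": "/path/to/model.nemo",
        "trainer.spin.max_iterations": 3,
        "trainer.spin.max_epochs": 2,
        # Illustrative per-iteration schedule: one value per SPIN iteration.
        "model.spin.ref_policy_kl_penalty": [0.1, 0.05, 0.01],
        "model.data.chat": True,
    }
)

# shlex.join quotes the bracketed list so the printed command is shell-safe.
print(shlex.join(cmd))
```

Passing cmd directly to subprocess.run avoids shell quoting altogether, while shlex.join is only needed to show or log a copy-pasteable command.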