Principle:NVIDIA NeMo Aligner Checkpoint Management
| Principle: Checkpoint Management | |
|---|---|
| Type | Principle |
| Project | NVIDIA NeMo Aligner |
| Domains | MLOps, Training |
| Related Implementations | Implementation:NVIDIA_NeMo_Aligner_Custom_Checkpoint_Callback |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
Mechanism for persisting model weights, optimizer state, and training progress to enable resumable training and model deployment.
Description
Checkpoint management in NeMo Aligner handles saving and restoring the full training state:
- Model parameters -- all learnable weights of the neural network
- Optimizer states -- momentum buffers, adaptive learning rate accumulators (e.g., Adam first/second moments)
- Learning rate scheduler state -- current step, warmup progress, decay schedule position
- Custom trainer metadata -- step count, consumed samples, epoch number, and algorithm-specific counters (e.g., ppo_optimization_step, consumed_global_samples)
It integrates with NeMo's NeMoModelCheckpoint callback, which manages:
- Checkpoint rotation -- keeping the top-K checkpoints ranked by validation metric, deleting older ones
- Distributed saving -- correctly sharding and saving weights across tensor/pipeline parallel ranks
- The .nemo archive format -- bundling model weights, configuration, and tokenizer into a single deployable artifact
On resume, the trainer state is parsed from the checkpoint path to restore exact training progress, ensuring no training steps are repeated or skipped.
Usage
Use checkpointing in every training workflow for fault tolerance and training resumption. The checkpoint callback is:
- Added to the PyTorch Lightning trainer's callback list
- Triggered by the algorithm trainer at configurable save intervals (every N steps or N epochs)
- Responsible for both periodic saves and best-model saves (based on validation metric)
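In NeMo-based training scripts the callback is typically configured through the exp_manager section of the experiment YAML rather than constructed by hand. The fragment below is a representative sketch of NeMo's checkpoint_callback_params; exact field names and defaults may vary by NeMo version:

```yaml
exp_manager:
  create_checkpoint_callback: True
  checkpoint_callback_params:
    monitor: val_loss        # validation metric used to rank checkpoints
    mode: min                # lower val_loss is better
    save_top_k: 3            # keep only the 3 best checkpoints on disk
    always_save_nemo: True   # also bundle a deployable .nemo archive
```

With create_checkpoint_callback enabled, exp_manager instantiates the checkpoint callback and appends it to the Lightning trainer's callback list automatically.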
The saved .nemo checkpoints serve a dual purpose:
- Training artifact -- enables resumption after interruption or failure
- Deployment artifact -- the same checkpoint format is used for inference and serving
Theoretical Basis
The principle is grounded in stateful training resumption. The complete state that must be persisted includes:
Checkpoint State Dict:
1. Model weights (sharded across TP/PP ranks)
- Each rank saves only its shard of the parameters
- Sharding is determined by tensor_model_parallel_size and pipeline_model_parallel_size
2. Optimizer state (including momentum buffers)
- For Adam: first moment (m), second moment (v), step count per parameter
- Must match the model parameter sharding
3. Trainer metadata
- global_step: total optimizer steps completed
- consumed_samples: total training examples processed
- epoch: current epoch number
- Algorithm-specific: ppo_optimization_step, dpo_step, etc.
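The full state dict above can be sketched as a plain-Python round trip. The key names here are illustrative, not NeMo Aligner's actual checkpoint layout, and a real checkpoint shards the weight and optimizer tensors across TP/PP ranks via torch's distributed checkpointing rather than pickling plain dicts:

```python
import io
import pickle

# Hypothetical checkpoint contents: every component listed above appears,
# but the keys and values are simplified stand-ins for real sharded tensors.
checkpoint = {
    "state_dict": {"layer.weight": [[0.1, 0.2], [0.3, 0.4]]},  # model shard
    "optimizer_state": {                                        # Adam buffers
        "layer.weight": {
            "exp_avg": [[0.0, 0.0], [0.0, 0.0]],     # first moment (m)
            "exp_avg_sq": [[0.0, 0.0], [0.0, 0.0]],  # second moment (v)
            "step": 1000,                            # per-parameter step count
        },
    },
    "lr_scheduler": {"last_step": 1000, "warmup_steps": 100},
    "trainer_metadata": {
        "global_step": 1000,
        "consumed_samples": 32000,
        "epoch": 1,
        "ppo_optimization_step": 250,  # algorithm-specific counter
    },
}

# Round-trip through a buffer to show the save/restore contract:
# everything needed to resume must survive serialization.
buf = io.BytesIO()
pickle.dump(checkpoint, buf)
buf.seek(0)
restored = pickle.load(buf)
```

Note that the optimizer buffers mirror the model parameter keys; this is what "must match the model parameter sharding" means in practice, since each rank can only restore optimizer state for the parameter shard it owns.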
The checkpoint path encodes metadata for parsing on restore:
Path format:
checkpoints/megatron_gpt--val_loss=0.432-step=1000-consumed_samples=32000.0.ckpt
Parsed on restore to recover:
- step = 1000
- consumed_samples = 32000
- val_loss = 0.432 (for checkpoint ranking)
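The parsing step can be sketched with a small regex over the filename. This is a hypothetical parser that mirrors the example path above, not NeMo's actual restore code:

```python
import re
from pathlib import Path

# Matches the metadata fields embedded in a NeMo-style checkpoint filename,
# e.g. "megatron_gpt--val_loss=0.432-step=1000-consumed_samples=32000.0.ckpt".
CKPT_RE = re.compile(
    r"val_loss=(?P<val_loss>[\d.]+)-step=(?P<step>\d+)"
    r"-consumed_samples=(?P<consumed_samples>[\d.]+)"
)

def parse_checkpoint_name(path: str) -> dict:
    """Recover trainer progress encoded in a checkpoint filename."""
    match = CKPT_RE.search(Path(path).stem)
    if match is None:
        raise ValueError(f"unrecognized checkpoint name: {path}")
    return {
        "val_loss": float(match.group("val_loss")),
        "step": int(match.group("step")),
        # consumed_samples is serialized as a float ("32000.0") but counts
        # whole examples, so normalize it back to an int.
        "consumed_samples": int(float(match.group("consumed_samples"))),
    }

meta = parse_checkpoint_name(
    "checkpoints/megatron_gpt--val_loss=0.432-step=1000"
    "-consumed_samples=32000.0.ckpt"
)
```

Because the progress counters live in the path itself, resumption can position the dataloader and schedulers before any weights are even loaded.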
The checkpoint rotation strategy ensures disk usage is bounded:
- After each save, rank checkpoints by the monitored validation metric
- Keep only the top-K checkpoints (configurable via save_top_k)
- Always retain the last checkpoint regardless of metric ranking
- Optionally save a best checkpoint symlink for easy access
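The rotation policy above can be sketched in a few lines. This is an illustrative re-implementation of the ranking logic, not NeMo's code:

```python
def rotate_checkpoints(checkpoints, save_top_k, mode="min"):
    """Return the checkpoint names to keep after rotation.

    checkpoints: list of (name, metric) pairs in save order.
    mode: "min" if a lower monitored metric is better, "max" otherwise.
    """
    if not checkpoints:
        return []
    last_name = checkpoints[-1][0]
    # Rank all checkpoints by the monitored validation metric.
    ranked = sorted(checkpoints, key=lambda c: c[1], reverse=(mode == "max"))
    kept = {name for name, _ in ranked[:save_top_k]}
    # The most recent checkpoint always survives, even if its metric
    # would otherwise rank it outside the top-K.
    kept.add(last_name)
    return [name for name, _ in checkpoints if name in kept]
```

For example, with save_top_k=1 and mode="min", saves ("a", 0.5), ("b", 0.3), ("c", 0.6), ("d", 0.4) keep "b" (best val_loss) plus "d" (most recent), bounding disk usage at roughly save_top_k + 1 checkpoints.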