Overview
Utilities for managing model checkpoint saving and training-state restoration, provided by the NeMo Aligner training utilities.
Description
The add_custom_checkpoint_callback function finds the NeMoModelCheckpoint callback in the PyTorch Lightning (PTL) trainer's callback list and attaches a custom_save method for use by algorithm trainers. The retrieve_custom_trainer_state_dict function parses the checkpoint path to extract saved training state (step, consumed_samples, epoch, and algorithm-specific counters). Together they enable consistent checkpoint save/restore across the alignment training workflows.
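The attachment mechanism described above can be illustrated with a minimal, self-contained sketch. The stand-in classes and the attach_custom_save helper below are hypothetical (plain-Python substitutes for pytorch_lightning.Trainer and NeMoModelCheckpoint); only the overall pattern of scanning the callback list and binding a method mirrors the documented behavior.

```python
from types import MethodType

class FakeCheckpointCallback:
    """Hypothetical stand-in for NeMoModelCheckpoint (illustration only)."""
    def __init__(self):
        self.saved = []

class FakeTrainer:
    """Hypothetical stand-in for pytorch_lightning.Trainer with a callbacks list."""
    def __init__(self, callbacks):
        self.callbacks = callbacks

def attach_custom_save(trainer):
    # Find the checkpoint callback in the trainer's callback list ...
    callback = next(c for c in trainer.callbacks
                    if isinstance(c, FakeCheckpointCallback))

    # ... and bind a custom_save method onto that instance.
    def custom_save(self, metrics, is_train_end=False):
        self.saved.append((metrics, is_train_end))

    callback.custom_save = MethodType(custom_save, callback)
    return callback

trainer = FakeTrainer([FakeCheckpointCallback()])
cb = attach_custom_save(trainer)
cb.custom_save({"step": 10}, is_train_end=False)
print(cb.saved)  # [({'step': 10}, False)]
```

The same callback instance that PTL already drives for scheduled saves is returned, so algorithm trainers can trigger an explicit save at any point without constructing a second checkpointer.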
Usage
Import these helpers in training scripts during initialization: call add_custom_checkpoint_callback after the PTL trainer is set up, and call retrieve_custom_trainer_state_dict during checkpoint restoration to resume training from the correct step.
Code Reference
Source Location
- Repository: NeMo Aligner
- File:
nemo_aligner/utils/train_script_utils.py
- Lines: L34-149
Signature
def retrieve_custom_trainer_state_dict(ptl_trainer) -> Optional[dict]:
    """Parse trainer state from loaded checkpoint path.
    Returns dict with keys: step, consumed_samples, epoch,
    ppo_optimization_step (PPO), reinforce_optimization_step (REINFORCE)."""

def add_custom_checkpoint_callback(ptl_trainer, ptl_model) -> NeMoModelCheckpoint:
    """Get a function to conveniently save checkpoints within the trainer.
    Returns the checkpoint callback with custom_save method attached."""
Import
from nemo_aligner.utils.train_script_utils import (
    add_custom_checkpoint_callback,
    retrieve_custom_trainer_state_dict,
)
I/O Contract
Inputs (add_custom_checkpoint_callback)
| Name | Type | Required | Description |
|------|------|----------|-------------|
| ptl_trainer | pytorch_lightning.Trainer | Yes | PTL trainer with callbacks list |
| ptl_model | Model | Yes | Model to save |
Outputs (add_custom_checkpoint_callback)
| Name | Type | Description |
|------|------|-------------|
| callback | NeMoModelCheckpoint | Checkpoint callback with custom_save(metrics, is_train_end) method |
Inputs (retrieve_custom_trainer_state_dict)
| Name | Type | Required | Description |
|------|------|----------|-------------|
| ptl_trainer | pytorch_lightning.Trainer | Yes | PTL trainer with restored checkpoint path |
Outputs (retrieve_custom_trainer_state_dict)
| Name | Type | Description |
|------|------|-------------|
| state_dict | Optional[dict] | Dict with step, consumed_samples, epoch, and algorithm-specific fields |
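The contract above (a dict of counters recovered from a checkpoint path, or None when there is nothing to restore) can be sketched with a toy parser. The checkpoint-name format and the parse_state_from_ckpt_path helper are assumptions for illustration; the real function operates on NeMo's own checkpoint naming, which may differ.

```python
import re
from typing import Optional

def parse_state_from_ckpt_path(ckpt_path: Optional[str]) -> Optional[dict]:
    """Extract key=value counters (e.g. step=100-consumed_samples=800)
    from a checkpoint path. Assumed name format, for illustration only."""
    if ckpt_path is None:
        return None
    pairs = re.findall(r"([a-z_]+)=([0-9]+)", ckpt_path)
    if not pairs:
        return None  # nothing recoverable from this path
    return {key: int(value) for key, value in pairs}

state = parse_state_from_ckpt_path(
    "checkpoints/megatron--step=100-consumed_samples=800-epoch=1"
)
print(state)  # {'step': 100, 'consumed_samples': 800, 'epoch': 1}
```

Returning None rather than an empty dict lets callers use a single truthiness check to distinguish "fresh run" from "resumed run", as the usage example below does.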
Usage Examples
Setting Up Checkpoint Callback and Restoring State
from nemo_aligner.utils.train_script_utils import (
    add_custom_checkpoint_callback,
    retrieve_custom_trainer_state_dict,
)

# Setup checkpoint callback
ckpt_callback = add_custom_checkpoint_callback(ptl_trainer, model)

# Restore training state if resuming
trainer_state = retrieve_custom_trainer_state_dict(ptl_trainer)
if trainer_state:
    step = trainer_state["step"]
    consumed_samples = trainer_state["consumed_samples"]
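To show how the restored counters feed back into a training loop, here is a self-contained resume sketch. The resume_training function and the global batch size of 8 are hypothetical; only the pattern of falling back to zeroed counters when no state is restored reflects the documented contract.

```python
def resume_training(trainer_state, total_steps):
    """Run (or resume) a toy loop from restored counters.
    trainer_state is the dict returned by retrieve_custom_trainer_state_dict,
    or None on a fresh run. All other names here are illustrative."""
    start_step = trainer_state["step"] if trainer_state else 0
    consumed = trainer_state["consumed_samples"] if trainer_state else 0

    executed = []
    for step in range(start_step, total_steps):
        executed.append(step)  # one optimization step would run here
        consumed += 8          # assumed global batch size of 8

    return executed, consumed

# Resuming from step 3 of a 5-step run executes only steps 3 and 4.
steps, consumed = resume_training({"step": 3, "consumed_samples": 24}, 5)
print(steps, consumed)  # [3, 4] 40
```

Tracking consumed_samples alongside step matters because data loaders must skip already-seen samples on resume, not just restart the step counter.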