
Implementation:NVIDIA NeMo Aligner Custom Checkpoint Callback

From Leeroopedia


Implementation Metadata
Name Custom_Checkpoint_Callback
Type API Doc
Implements Principle Checkpoint_Management
Repository NeMo Aligner
File nemo_aligner/utils/train_script_utils.py
Lines L34-149
Domains MLOps, Training
Last Updated 2026-02-07 00:00 GMT

Overview

A concrete utility pair for saving model checkpoints and restoring training state, provided by the NeMo Aligner training utilities.

Description

The add_custom_checkpoint_callback function finds the NeMoModelCheckpoint callback in the PTL trainer's callback list and attaches a custom_save method for use by algorithm trainers. The retrieve_custom_trainer_state_dict function parses checkpoint paths to extract saved training state (step, consumed_samples, epoch, algorithm-specific counters). Together they enable seamless checkpoint save/restore across all alignment training workflows.
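The find-and-attach pattern described above can be illustrated with a minimal, self-contained sketch. Note that `ExampleCheckpointCallback`, `ExampleTrainer`, and the body of `custom_save` below are hypothetical stand-ins for illustration, not NeMo Aligner code:

```python
import types

class ExampleCheckpointCallback:
    """Hypothetical stand-in for NeMoModelCheckpoint."""
    def __init__(self):
        self.saved = []

class ExampleTrainer:
    """Hypothetical stand-in for pytorch_lightning.Trainer."""
    def __init__(self, callbacks):
        self.callbacks = callbacks

def add_custom_checkpoint_callback_sketch(trainer, model):
    # Locate the checkpoint callback in the trainer's callback list ...
    callback = next(
        c for c in trainer.callbacks if isinstance(c, ExampleCheckpointCallback)
    )

    # ... and attach a bound custom_save method for algorithm trainers to call.
    def custom_save(self, metrics, is_train_end=False):
        self.saved.append({"model": model, "metrics": metrics, "final": is_train_end})

    callback.custom_save = types.MethodType(custom_save, callback)
    return callback

trainer = ExampleTrainer(callbacks=[ExampleCheckpointCallback()])
cb = add_custom_checkpoint_callback_sketch(trainer, model="my_model")
cb.custom_save({"loss": 0.1}, is_train_end=True)
```

The design keeps checkpoint saving on the single callback object PTL already owns, so algorithm trainers trigger saves without re-implementing checkpoint logic.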

Usage

Import both functions in training scripts during initialization. Call add_custom_checkpoint_callback after the PTL trainer is set up; call retrieve_custom_trainer_state_dict during checkpoint restoration to resume training from the correct step.

Code Reference

Source Location

  • Repository: NeMo Aligner
  • File: nemo_aligner/utils/train_script_utils.py
  • Lines: L34-149

Signature

def retrieve_custom_trainer_state_dict(ptl_trainer) -> Optional[dict]:
    """Parse trainer state from loaded checkpoint path.
    Returns dict with keys: step, consumed_samples, epoch,
    ppo_optimization_step (PPO), reinforce_optimization_step (REINFORCE)."""

def add_custom_checkpoint_callback(ptl_trainer, ptl_model) -> NeMoModelCheckpoint:
    """Get a function to conveniently save checkpoints within the trainer.
    Returns the checkpoint callback with custom_save method attached."""

Import

from nemo_aligner.utils.train_script_utils import (
    add_custom_checkpoint_callback,
    retrieve_custom_trainer_state_dict,
)

I/O Contract

Inputs (add_custom_checkpoint_callback)

  • ptl_trainer (pytorch_lightning.Trainer, required): PTL trainer with its callbacks list
  • ptl_model (Model, required): Model to save

Outputs (add_custom_checkpoint_callback)

  • callback (NeMoModelCheckpoint): Checkpoint callback with a custom_save(metrics, is_train_end) method attached

Inputs (retrieve_custom_trainer_state_dict)

  • ptl_trainer (pytorch_lightning.Trainer, required): PTL trainer with a restored checkpoint path

Outputs (retrieve_custom_trainer_state_dict)

  • state_dict (Optional[dict]): Dict with step, consumed_samples, epoch, and algorithm-specific fields; None when no checkpoint was restored
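retrieve_custom_trainer_state_dict recovers these counters by parsing the restored checkpoint path. A minimal sketch of that idea, assuming NeMo-style checkpoint names that embed key=value pairs; the exact name format and the parse_state_from_path helper below are illustrative, not the actual NeMo Aligner implementation:

```python
import re
from typing import Optional

def parse_state_from_path(ckpt_path: Optional[str]) -> Optional[dict]:
    # No restored checkpoint means training starts fresh.
    if ckpt_path is None:
        return None
    # Pull integer key=value pairs (e.g. step=100, consumed_samples=3200)
    # out of the checkpoint directory name.
    pairs = re.findall(r"(\w+)=(\d+)", ckpt_path)
    return {key: int(value) for key, value in pairs}

state = parse_state_from_path(
    "checkpoints/megatron_gpt--step=100-consumed_samples=3200-epoch=0"
)
```

Parsing state from the path keeps resumption independent of the checkpoint's tensor contents, so counters are available before any weights are loaded.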

Usage Examples

Setting Up Checkpoint Callback and Restoring State

from nemo_aligner.utils.train_script_utils import (
    add_custom_checkpoint_callback,
    retrieve_custom_trainer_state_dict,
)

# Setup checkpoint callback
ckpt_callback = add_custom_checkpoint_callback(ptl_trainer, model)

# Restore training state if resuming
trainer_state = retrieve_custom_trainer_state_dict(ptl_trainer)
if trainer_state:
    step = trainer_state["step"]
    consumed_samples = trainer_state["consumed_samples"]
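Once restored, these values are typically fed back into the training loop so the step counter and data consumption resume where they left off. A hedged sketch of that resume logic; the loop structure and the save_every, max_steps, and batch_size names are illustrative, not NeMo Aligner API:

```python
def resume_training_loop(trainer_state, max_steps=5, save_every=2):
    # Resume counters from the restored state, or start fresh when None.
    step = trainer_state["step"] if trainer_state else 0
    consumed_samples = trainer_state["consumed_samples"] if trainer_state else 0

    batch_size = 8  # samples consumed per optimizer step (illustrative)
    saved_steps = []
    while step < max_steps:
        step += 1
        consumed_samples += batch_size
        if step % save_every == 0:
            # This is where ckpt_callback.custom_save(metrics, is_train_end)
            # would be invoked in a real training script.
            saved_steps.append(step)
    return step, consumed_samples, saved_steps

step, consumed, saves = resume_training_loop({"step": 2, "consumed_samples": 16})
```

Resuming from step 2 with 16 consumed samples, the loop runs steps 3 through 5 and records a save at step 4.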
