Implementation:NVIDIA NeMo Aligner Retrieve Model State Dict

From Leeroopedia


Implementation Details
Name: Retrieve_Model_State_Dict
Type: API Doc
Implements Principle: DPO_Reference_Policy_Management
Module: nemo_aligner.utils
Repository: NeMo Aligner
Last Updated: 2026-02-07 00:00 GMT

Overview

A NeMo Aligner utility function that creates a CPU copy of a model's weights for use as the frozen reference policy in DPO training.

Description

The retrieve_model_state_dict_in_cpu function copies the model's entire state dict to CPU memory, detaching every tensor and copying it into new storage, so later updates to the training weights do not affect the snapshot. This CPU state dict serves as the frozen reference policy for DPO training. When megatron_amp_O2 is True, the copy is converted to the Megatron AMP O2 state dict format so that it matches the model it will later be loaded back into. During training, the model temporarily loads these reference weights to compute the reference log-probabilities (log π_ref), then restores its training weights.
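For context, in the standard DPO objective the reference policy enters only through its log-probabilities on the chosen response $y_w$ and the rejected response $y_l$ for a prompt $x$. Here $\pi_\theta$ is the policy being trained, $\beta$ a scaling coefficient and $\sigma$ the logistic sigmoid; this is the usual DPO notation, not symbols defined elsewhere on this page. Because the $\pi_{\text{ref}}$ terms receive no gradient, the reference weights must stay fixed for the whole run, which is why a single snapshot taken before training is sufficient.

$$
\mathcal{L}_{\text{DPO}}(\theta) = -\log \sigma\!\left(\beta\left[\log\frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \log\frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}\right]\right)
$$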

Usage

Used in DPO training initialization for full-parameter training (not PEFT). Called once before training begins to snapshot the initial model weights.

Code Reference

Source Location

  • Repository: NeMo Aligner
  • File: nemo_aligner/utils/utils.py
  • Lines: L369-384

Signature

def retrieve_model_state_dict_in_cpu(
    model,
    megatron_amp_O2: bool = True,
) -> dict:
    """Get a copy of the model states in CPU.

    Args:
        model: The model to copy state from
        megatron_amp_O2: Whether to convert to AMP O2 format

    Returns:
        CPU state dict with all tensors detached and copied
    """

Import

from nemo_aligner.utils.utils import retrieve_model_state_dict_in_cpu

I/O Contract

Inputs

Name | Type | Required | Description
model | nn.Module | Yes | Model whose state dict is copied
megatron_amp_O2 | bool | No | Convert the copy to the Megatron AMP O2 format (default: True)

Outputs

Name | Type | Description
cpu_dict | dict | Complete state dict with all tensors on CPU

Usage Examples

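from nemo_aligner.utils.utils import retrieve_model_state_dict_in_cpu

# `model` is the initialized DPO model and `cfg` the training config (assumed to exist).
# Snapshot the initial weights as the reference policy before DPO training starts.
ref_policy_state_dict = retrieve_model_state_dict_in_cpu(
    model,
    megatron_amp_O2=cfg.model.get("megatron_amp_O2", False),
)

# Attach the snapshot to the model; during training the model swaps these
# weights in to compute reference log-probabilities, then restores its own.
model.ref_policy_state_dict = ref_policy_state_dict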
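The temporary weight swap described on this page can be sketched with a small context manager, assuming plain PyTorch modules. `swap_in_weights` is a hypothetical helper, not the NeMo Aligner API; it assumes the snapshot's keys match `model.state_dict()` exactly, which is why the AMP O2 key conversion matters. A toy `nn.Linear` stands in for the policy.

```python
from contextlib import contextmanager

import torch
from torch import nn


@contextmanager
def swap_in_weights(model: nn.Module, cpu_weights: dict):
    """Hypothetical helper: temporarily load `cpu_weights` into `model`, then restore."""
    # Save the current (training) weights to CPU before overwriting them.
    saved = {k: v.detach().to(device="cpu", copy=True) for k, v in model.state_dict().items()}
    model.load_state_dict(cpu_weights)  # copies values into the existing parameters
    try:
        yield model
    finally:
        model.load_state_dict(saved)  # restore the training weights


torch.manual_seed(0)
policy = nn.Linear(3, 1)
# Snapshot the initial weights once, before any training step.
ref_weights = {k: v.detach().to(device="cpu", copy=True) for k, v in policy.state_dict().items()}

# Simulate a training update that moves the policy away from the reference.
with torch.no_grad():
    policy.weight.add_(0.5)

x = torch.ones(1, 3)
with swap_in_weights(policy, ref_weights), torch.no_grad():
    ref_out = policy(x)  # computed with the frozen reference weights
policy_out = policy(x).detach()  # training weights are back in place
```

Restoring inside `finally` guarantees the training weights come back even if the reference forward pass raises; the cost is one extra CPU copy of the current weights per swap.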

Related Pages

Knowledge Sources

NLP, Alignment
