Implementation:NVIDIA NeMo Aligner Retrieve Model State Dict
| Implementation Details | |
|---|---|
| Name | Retrieve_Model_State_Dict |
| Type | API Doc |
| Implements Principle | DPO_Reference_Policy_Management |
| Module | nemo_aligner.utils |
| Repository | NeMo Aligner |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
A utility in the NeMo Aligner utils module that creates a CPU copy of the model weights for use as the reference policy in DPO training.
Description
The retrieve_model_state_dict_in_cpu function copies the model's entire state dict to CPU memory, replacing every tensor with a detached CPU copy. The resulting dict serves as the frozen reference policy for DPO training. When megatron_amp_O2 is True, the copied state dict is also converted to the Megatron AMP O2 format. During training, the model temporarily swaps in these reference weights to compute the reference-policy (pi_ref) log probabilities, then restores its training weights.
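The copy step can be illustrated in plain PyTorch. This is a minimal sketch assuming a standard `torch.nn.Module`, not the NeMo Aligner source: the helper name `copy_state_dict_to_cpu` is hypothetical, and the Megatron AMP O2 conversion is deliberately omitted.

```python
import torch
from torch import nn


def copy_state_dict_to_cpu(model: nn.Module) -> dict:
    """Hypothetical helper: return a detached CPU copy of model.state_dict().

    The AMP O2 format conversion done by the real function is not shown here.
    """
    cpu_dict = {}
    for name, item in model.state_dict().items():
        if isinstance(item, torch.Tensor):
            # detach() drops autograd history; copy=True forces a new buffer
            # even when the tensor already lives on the CPU.
            item = item.detach().to(device="cpu", copy=True)
        # Non-tensor entries are passed through unchanged.
        cpu_dict[name] = item
    return cpu_dict


model = nn.Linear(4, 2)
snapshot = copy_state_dict_to_cpu(model)

# Simulate a training update: the live weights change, the snapshot does not.
with torch.no_grad():
    model.weight.add_(1.0)
```

Because the snapshot owns its own storage, later optimizer steps on the live model leave the reference weights untouched, which is what makes the copy usable as a frozen pi_ref.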
Usage
Used in DPO training initialization for full-parameter training (not PEFT). Called once before training begins to snapshot the initial model weights.
Code Reference
Source Location
- Repository: NeMo Aligner
- File: nemo_aligner/utils/utils.py
- Lines: L369-384
Signature
```python
def retrieve_model_state_dict_in_cpu(
    model,
    megatron_amp_O2: bool = True,
) -> dict:
    """Get a copy of the model states in CPU.

    Args:
        model: The model to copy state from
        megatron_amp_O2: Whether to convert to AMP O2 format

    Returns:
        CPU state dict with all tensors detached and copied
    """
```
Import
```python
from nemo_aligner.utils.utils import retrieve_model_state_dict_in_cpu
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | nn.Module | Yes | Model whose state dict is copied |
| megatron_amp_O2 | bool | No | Convert to AMP O2 format (default True) |
Outputs
| Name | Type | Description |
|---|---|---|
| cpu_dict | dict | Complete state dict; every tensor is a detached copy on CPU |
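A returned dict can be checked against this output contract with a small helper like the one below. The name `is_detached_cpu_state_dict` is hypothetical and not part of NeMo Aligner; non-tensor entries are skipped because the contract only constrains tensors.

```python
import torch


def is_detached_cpu_state_dict(state_dict: dict) -> bool:
    """Hypothetical check: every tensor value is on the CPU and detached.

    A tensor attached to an autograd graph always has requires_grad=True,
    so checking requires_grad also rules out tensors with autograd history.
    """
    return all(
        t.device.type == "cpu" and not t.requires_grad
        for t in state_dict.values()
        if isinstance(t, torch.Tensor)
    )
```

For example, `is_detached_cpu_state_dict(torch.nn.Linear(2, 2).state_dict())` holds, because `state_dict()` returns detached tensors by default, while a dict containing a parameter-like tensor with `requires_grad=True` fails the check.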
Usage Examples
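```python
from nemo_aligner.utils.utils import retrieve_model_state_dict_in_cpu

# `model` (the policy model) and `cfg` (the training config) are assumed to
# have been built earlier in the DPO training script.

# Snapshot the initial weights as the frozen reference policy before DPO training.
ref_policy_state_dict = retrieve_model_state_dict_in_cpu(
    model,
    megatron_amp_O2=cfg.model.megatron_amp_O2,
)

# Attach the snapshot to the model; during training the model temporarily
# swaps these weights in to compute reference (pi_ref) log probabilities.
model.ref_policy_state_dict = ref_policy_state_dict
```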
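The swap-in/restore pattern from the Description can be sketched as a context manager. This is an illustration under stated assumptions (a plain `nn.Module`, weights restored via `load_state_dict`); the name `swap_weights_temporarily` is hypothetical and this is not how NeMo Aligner necessarily implements the swap.

```python
from contextlib import contextmanager

import torch
from torch import nn


@contextmanager
def swap_weights_temporarily(model: nn.Module, ref_state_dict: dict):
    """Hypothetical context manager: load reference weights, restore on exit."""
    # Save the current (training) weights on the CPU before overwriting them.
    saved = {
        k: v.detach().to(device="cpu", copy=True) if isinstance(v, torch.Tensor) else v
        for k, v in model.state_dict().items()
    }
    model.load_state_dict(ref_state_dict)
    try:
        yield model
    finally:
        # Restore the training weights even if the body raised.
        model.load_state_dict(saved)


model = nn.Linear(4, 2)
ref_state_dict = {k: v.detach().clone() for k, v in model.state_dict().items()}
with torch.no_grad():
    model.weight.add_(1.0)  # simulate a training update

x = torch.ones(3, 4)
with torch.no_grad(), swap_weights_temporarily(model, ref_state_dict):
    ref_out = model(x)  # computed with the reference weights
train_out = model(x)  # training weights are back in place
```

The `try`/`finally` guarantees that the training weights come back even if the reference forward pass raises, so an error cannot leave the model silently running on pi_ref weights.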
Related Pages
- Principle:NVIDIA_NeMo_Aligner_DPO_Reference_Policy_Management
- Environment:NVIDIA_NeMo_Aligner_NeMo_Framework_GPU_Environment