

Implementation:Eric mitchell Direct preference optimization Torch Load State Dict

From Leeroopedia


Knowledge Sources
Domains: Transfer_Learning, Checkpointing
Last Updated: 2026-02-08 02:00 GMT

Overview

Wrapper for torch.load and load_state_dict as used in this repository for loading SFT checkpoints into policy and reference models.

Description

This repository calls torch.load with map_location='cpu' to read an SFT checkpoint file onto the CPU, then applies the stored weights to the policy via load_state_dict and, when config.loss.name is 'dpo' or 'ipo', to the reference model as well. The checkpoint (produced by BasicTrainer.save) is a dict with three keys: step_idx, state, and metrics. The state entry holds the model weights.
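
To make the checkpoint layout concrete, the following is a minimal sketch of a save and load round trip. Only the three top-level keys (step_idx, state, metrics) come from this page; the tiny nn.Linear stand-in model, the metric name, and the temporary file path are assumptions made for illustration, and the real BasicTrainer.save may differ in detail.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in for the SFT policy (assumption); the real model is a language model.
model = nn.Linear(4, 2)

# Only the three top-level keys come from this page; the values are made up.
checkpoint = {
    "step_idx": 1000,
    "state": model.state_dict(),
    "metrics": {"loss/eval": 1.23},
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "policy.pt")
    torch.save(checkpoint, path)
    # map_location='cpu' places every tensor on the CPU, whatever device it was saved from.
    loaded = torch.load(path, map_location="cpu")

loaded_keys = sorted(loaded.keys())
state_devices = {tensor.device.type for tensor in loaded["state"].values()}
print(loaded_keys, state_devices)
```

Loading onto the CPU first means the file can be read on a machine whose GPU layout differs from the one used for SFT; the models can be moved to their target devices afterwards.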

Usage

Used in train.py when config.model.archive is not None. This occurs at the beginning of DPO training to initialize both models from the SFT checkpoint.

Code Reference

Source Location

Signature

# Checkpoint loading sequence (train.py:L96-103)
if config.model.archive is not None:
    state_dict = torch.load(config.model.archive, map_location='cpu')
    step, metrics = state_dict['step_idx'], state_dict['metrics']
    print(f'loading pre-trained weights at step {step} from {config.model.archive} '
          f'with metrics {json.dumps(metrics, indent=2)}')
    policy.load_state_dict(state_dict['state'])
    if config.loss.name in {'dpo', 'ipo'}:
        reference_model.load_state_dict(state_dict['state'])
    print('loaded pre-trained weights')

Import

import torch
import json

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| config.model.archive | Optional[str] | No | Path to SFT checkpoint file (e.g., "path/to/LATEST/policy.pt"); if None, checkpoint loading is skipped |
| policy | nn.Module | Yes | Pre-loaded model to receive SFT weights |
| reference_model | Optional[nn.Module] | Only when config.loss.name is 'dpo' or 'ipo' | Pre-loaded reference model (also receives SFT weights) |

Outputs

| Name | Type | Description |
|---|---|---|
| policy (modified) | nn.Module | Model with SFT weights loaded in-place |
| reference_model (modified) | Optional[nn.Module] | Reference model with the same SFT weights (for DPO or IPO) |
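
The in-place behaviour in the table above can be checked with a small sketch. It uses tiny nn.Linear modules as stand-ins for the real policy and reference models (an assumption for illustration). With its default arguments, load_state_dict copies values into each module's existing parameters, so both models end up with equal weights but separate storage, and a later update to the policy leaves the reference untouched.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for state_dict['state'] from an SFT checkpoint (assumption).
sft_state = nn.Linear(4, 2).state_dict()

# Stand-ins for the pre-constructed policy and reference models (assumption).
policy = nn.Linear(4, 2)
reference_model = nn.Linear(4, 2)
policy_weight_before = policy.weight

policy.load_state_dict(sft_state)
reference_model.load_state_dict(sft_state)

# The Parameter object is reused: values were copied into it in place.
same_object = policy.weight is policy_weight_before
weights_match = torch.equal(policy.weight, reference_model.weight)

# Simulate a training update that touches the policy only.
with torch.no_grad():
    policy.weight.add_(1.0)

reference_unchanged = torch.equal(reference_model.weight, sft_state["weight"])
policy_changed = not torch.equal(policy.weight, reference_model.weight)
print(same_object, weights_match, reference_unchanged, policy_changed)
```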

Usage Examples

Loading SFT Checkpoint for DPO

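import json

import torch

# `policy` and `reference_model` are assumed to be already-constructed nn.Module
# instances with the same architecture as the SFT model (they are not defined here).
archive_path = "/path/to/sft_run/LATEST/policy.pt"
state_dict = torch.load(archive_path, map_location='cpu')  # load all tensors onto the CPU

# The checkpoint records the training step and metrics at save time.
print(f"Loading from step {state_dict['step_idx']}")
print(f"Metrics: {json.dumps(state_dict['metrics'], indent=2)}")

# Both models start DPO from the same SFT weights.
policy.load_state_dict(state_dict['state'])
reference_model.load_state_dict(state_dict['state'])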
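
The calls above use the default strict=True, so load_state_dict raises a RuntimeError if the checkpoint keys do not match the model's. As a hedged sketch of how such a mismatch could be diagnosed, the example below builds a deliberately incomplete state dict for a small stand-in model (the model, the dropped layer, and the "extra.bias" key are all hypothetical) and inspects the result of a non-strict load. This is a debugging aid, not part of the repository's loading sequence.

```python
import torch
import torch.nn as nn

# Small stand-in model (assumption); submodules are named "0" and "1".
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))

# Hypothetical mismatched state: second layer missing, one unknown key added.
state = {k: v for k, v in model.state_dict().items() if k.startswith("0.")}
state["extra.bias"] = torch.zeros(1)

# The default strict=True refuses to load a mismatched state dict.
try:
    model.load_state_dict(state)
    strict_failed = False
except RuntimeError:
    strict_failed = True

# strict=False loads the matching keys and reports the rest.
result = model.load_state_dict(state, strict=False)
missing = sorted(result.missing_keys)
unexpected = sorted(result.unexpected_keys)
print(strict_failed, missing, unexpected)
```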

Related Pages

Implements Principle

Requires Environment
