Implementation: Torch Load State Dict (Eric Mitchell's direct-preference-optimization)
| Knowledge Sources | |
|---|---|
| Domains | Transfer_Learning, Checkpointing |
| Last Updated | 2026-02-08 02:00 GMT |
Overview
Wrapper for torch.load and load_state_dict as used in this repository for loading SFT checkpoints into policy and reference models.
Description
This repository uses torch.load with map_location='cpu' to load SFT checkpoint files, then applies the state dict to both the policy and reference models via load_state_dict. The checkpoint format (produced by BasicTrainer.save) contains step_idx, state, and metrics.
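The checkpoint format can be sketched as follows. This is a minimal illustration using a toy `nn.Linear` in place of the real policy model; the three field names (`step_idx`, `state`, `metrics`) come from the description above, while the concrete values are hypothetical.

```python
import io
import torch
import torch.nn as nn

# Toy module standing in for the policy model (illustrative only).
model = nn.Linear(4, 2)

# Checkpoint dict mirroring the format described above: step_idx, state, metrics.
checkpoint = {
    'step_idx': 1000,
    'state': model.state_dict(),
    'metrics': {'loss': 0.42},
}

# Round-trip through an in-memory buffer instead of a file on disk.
buf = io.BytesIO()
torch.save(checkpoint, buf)
buf.seek(0)

# Loading with map_location='cpu' recovers the same three fields.
loaded = torch.load(buf, map_location='cpu')
assert set(loaded.keys()) == {'step_idx', 'state', 'metrics'}
```

Because the dict contains only tensors and plain Python values, it also loads cleanly under newer PyTorch defaults (`weights_only=True`).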
Usage
Used in train.py when config.model.archive is not None. This occurs at the beginning of DPO training to initialize both models from the SFT checkpoint.
Code Reference
Source Location
- Repository: direct-preference-optimization
- File: train.py
- Lines: 96-103
Signature
```python
# Checkpoint loading sequence (train.py:L96-103)
if config.model.archive is not None:
    state_dict = torch.load(config.model.archive, map_location='cpu')
    step, metrics = state_dict['step_idx'], state_dict['metrics']
    print(f'loading pre-trained weights at step {step} from {config.model.archive} '
          f'with metrics {json.dumps(metrics, indent=2)}')
    policy.load_state_dict(state_dict['state'])
    if config.loss.name in {'dpo', 'ipo'}:
        reference_model.load_state_dict(state_dict['state'])
    print('loaded pre-trained weights')
```
Import
import torch
import json
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| config.model.archive | Optional[str] | No | Path to SFT checkpoint file (e.g., "path/to/LATEST/policy.pt"); if None, step is skipped |
| policy | nn.Module | Yes | Pre-loaded model to receive SFT weights |
| reference_model | Optional[nn.Module] | DPO/IPO only | Pre-loaded reference model (also receives SFT weights when config.loss.name is 'dpo' or 'ipo') |
Outputs
| Name | Type | Description |
|---|---|---|
| policy (modified) | nn.Module | Model with SFT weights loaded in-place |
| reference_model (modified) | Optional[nn.Module] | Reference model with same SFT weights (DPO/IPO losses only) |
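The "modified in-place" contract above can be demonstrated directly: `load_state_dict` copies tensors into the existing module and returns a result naming any missing or unexpected keys. The toy `nn.Linear` modules below stand in for the real policy and reference models (illustrative, not from the repo).

```python
import torch
import torch.nn as nn

# Toy stand-ins for the policy and reference models.
policy = nn.Linear(3, 3)
reference_model = nn.Linear(3, 3)

# A hypothetical SFT state dict with keys matching the module's parameters.
sft_state = {'weight': torch.ones(3, 3), 'bias': torch.zeros(3)}

# load_state_dict copies the tensors into each module in place; with the
# default strict=True, any key mismatch would raise instead.
result = policy.load_state_dict(sft_state)
reference_model.load_state_dict(sft_state)

# Strict loading succeeded: no missing or unexpected keys reported.
assert result.missing_keys == [] and result.unexpected_keys == []

# Both modules now hold identical SFT weights.
assert torch.equal(policy.weight, reference_model.weight)
```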
Usage Examples
Loading SFT Checkpoint for DPO
```python
import torch
import json

# `policy` and `reference_model` are assumed to be already-constructed
# models whose architecture matches the checkpoint.
archive_path = "/path/to/sft_run/LATEST/policy.pt"
state_dict = torch.load(archive_path, map_location='cpu')
print(f"Loading from step {state_dict['step_idx']}")
print(f"Metrics: {json.dumps(state_dict['metrics'], indent=2)}")
policy.load_state_dict(state_dict['state'])
reference_model.load_state_dict(state_dict['state'])
```
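After both models load the same SFT state dict, they should produce identical outputs; in DPO the reference model is then typically frozen so only the policy trains. A minimal sketch of that invariant, using toy `nn.Linear` modules in place of the real models:

```python
import torch
import torch.nn as nn

# Toy stand-ins for the policy and reference models (illustrative only).
policy = nn.Linear(8, 8)
reference_model = nn.Linear(8, 8)

# Pretend the policy's state dict is the SFT checkpoint's 'state' field,
# and copy it into the reference model as train.py does.
sft_state = policy.state_dict()
reference_model.load_state_dict(sft_state)

# Both models now compute identical outputs.
x = torch.randn(2, 8)
assert torch.allclose(policy(x), reference_model(x))

# Freeze the reference model so only the policy receives gradients
# (the standard DPO setup).
for p in reference_model.parameters():
    p.requires_grad_(False)
```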