Implementation: Eric Mitchell's direct-preference-optimization, BasicTrainer.save
| Knowledge Sources | |
|---|---|
| Domains | Checkpointing, Training |
| Last Updated | 2026-02-08 02:00 GMT |
Overview
A concrete tool, provided by the direct-preference-optimization repository, for saving policy, optimizer, and scheduler state to disk.
Description
The BasicTrainer.save method saves three checkpoint files to disk: policy.pt, optimizer.pt, and scheduler.pt. Each file contains a dictionary with the step index, state dictionary, and optional evaluation metrics. The companion write_state_dict method handles the actual file I/O.
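The on-disk layout can be sketched as follows. This is a minimal, illustrative stand-in: the names write_state_dict, step_idx, state, and metrics follow the source, but pickle is used here in place of torch.save so the sketch runs without PyTorch, and the free-function signatures are simplified from the actual methods.

```python
import os
import pickle
from typing import Dict, Optional

def write_state_dict(step: int, state: Dict, metrics: Optional[Dict],
                     filename: str, dir_name: str) -> None:
    """Sketch of BasicTrainer.write_state_dict: wrap the state in a
    {step_idx, state, metrics} dict and write it to dir_name/filename
    (pickle stands in for torch.save here)."""
    os.makedirs(dir_name, exist_ok=True)
    payload = {'step_idx': step, 'state': state, 'metrics': metrics}
    with open(os.path.join(dir_name, filename), 'wb') as f:
        pickle.dump(payload, f)

def save(run_dir: str, policy_state: Dict, optimizer_state: Dict,
         scheduler_state: Dict, step: int,
         metrics: Optional[Dict] = None) -> None:
    """Sketch of BasicTrainer.save: write the three checkpoint files
    to {run_dir}/LATEST, the default output directory."""
    out = os.path.join(run_dir, 'LATEST')
    write_state_dict(step, policy_state, metrics, 'policy.pt', out)
    write_state_dict(step, optimizer_state, metrics, 'optimizer.pt', out)
    write_state_dict(step, scheduler_state, metrics, 'scheduler.pt', out)
```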
Usage
Called automatically during the training loop at evaluation checkpoints, and once at the end of training. The output policy.pt file is used as input to DPO training via the config.model.archive parameter.
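As a sketch of the hand-off to DPO training, a run can point config.model.archive at the saved SFT checkpoint. The command below is illustrative: the Hydra-style overrides assume the repository's train.py entry point, and the checkpoint path is a placeholder for your own run directory.

```shell
# Illustrative: start DPO training from the saved SFT policy checkpoint.
# The path is a placeholder; other required overrides (model, datasets, etc.)
# are omitted for brevity.
python -u train.py \
    model.archive=/path/to/sft_run/LATEST/policy.pt \
    loss=dpo
```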
Code Reference
Source Location
- Repository: direct-preference-optimization
- File: trainers.py
- Lines: 401-428 (write_state_dict at L401-413, save at L415-428)
Signature
class BasicTrainer(object):
    def write_state_dict(
        self,
        step: int,
        state: Dict[str, torch.Tensor],
        metrics: Dict,
        filename: str,
        dir_name: Optional[str] = None,
    ) -> None:
        """Write a checkpoint to disk."""

    def save(
        self,
        output_dir: Optional[str] = None,
        metrics: Optional[Dict] = None,
    ) -> None:
        """Save policy, optimizer, and scheduler state to disk."""
Import
from trainers import BasicTrainer
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| output_dir | Optional[str] | No | Checkpoint directory (defaults to {run_dir}/LATEST) |
| metrics | Optional[Dict] | No | Evaluation metrics to store alongside checkpoint |
Outputs
| Name | Type | Description |
|---|---|---|
| policy.pt | File | Dict with step_idx (int), state (OrderedDict of model weights), metrics (Dict) |
| optimizer.pt | File | Dict with step_idx, state (optimizer state dict), metrics |
| scheduler.pt | File | Dict with step_idx, state (scheduler state dict), metrics |
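All three files share the same top-level layout. A small helper (hypothetical, not part of the repository) can sanity-check a loaded checkpoint dict against the table above:

```python
from typing import Any, Dict

def check_checkpoint(ckpt: Dict[str, Any]) -> None:
    """Hypothetical helper: verify a loaded checkpoint dict has the
    {step_idx, state, metrics} layout described in the Outputs table."""
    missing = {'step_idx', 'state', 'metrics'} - set(ckpt)
    if missing:
        raise KeyError(f'checkpoint missing keys: {sorted(missing)}')
    if not isinstance(ckpt['step_idx'], int):
        raise TypeError('step_idx should be an int')
```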
Usage Examples
Saving During Training
# Called internally by BasicTrainer.train() at eval checkpoints
import os

output_dir = os.path.join(run_dir, f'step-{example_counter}')
trainer.save(output_dir, mean_eval_metrics)
Saving Final Checkpoint
# Called after training completes (in worker_main)
trainer.train()
trainer.save() # Saves to {run_dir}/LATEST/
Loading a Saved Checkpoint
import torch
# Load SFT checkpoint for DPO training
state_dict = torch.load("path/to/LATEST/policy.pt", map_location='cpu')
step = state_dict['step_idx']
metrics = state_dict['metrics']
policy.load_state_dict(state_dict['state'])