
Implementation:Eric Mitchell Direct preference optimization BasicTrainer Save

From Leeroopedia


Knowledge Sources
Domains Checkpointing, Training
Last Updated 2026-02-08 02:00 GMT

Overview

A concrete tool, provided by the direct-preference-optimization repository, for saving policy, optimizer, and scheduler state to disk.

Description

The BasicTrainer.save method saves three checkpoint files to disk: policy.pt, optimizer.pt, and scheduler.pt. Each file contains a dictionary with the step index, state dictionary, and optional evaluation metrics. The companion write_state_dict method handles the actual file I/O.
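The mechanics described above can be sketched as follows. This is a minimal, hypothetical reconstruction based on the description, not the repository's exact code; attribute names such as `run_dir`, `example_counter`, `policy`, `optimizer`, and `scheduler`, and the `LATEST` default directory, are assumptions.

```python
import os
from typing import Dict, Optional

import torch


class BasicTrainer:
    def write_state_dict(self, step: int, state: Dict, metrics: Optional[Dict],
                         filename: str, dir_name: Optional[str] = None) -> None:
        """Bundle the step index, a state dict, and optional metrics into one file."""
        if dir_name is None:
            dir_name = os.path.join(self.run_dir, 'LATEST')  # assumed default
        os.makedirs(dir_name, exist_ok=True)
        torch.save({'step_idx': step, 'state': state, 'metrics': metrics},
                   os.path.join(dir_name, filename))

    def save(self, output_dir: Optional[str] = None,
             metrics: Optional[Dict] = None) -> None:
        """Write policy, optimizer, and scheduler checkpoints, all at the same step."""
        self.write_state_dict(self.example_counter, self.policy.state_dict(),
                              metrics, 'policy.pt', output_dir)
        self.write_state_dict(self.example_counter, self.optimizer.state_dict(),
                              metrics, 'optimizer.pt', output_dir)
        self.write_state_dict(self.example_counter, self.scheduler.state_dict(),
                              metrics, 'scheduler.pt', output_dir)
```

Each of the three files shares one layout, so a single `write_state_dict` helper covers all of them; only the state object and filename differ between calls.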

Usage

The save method is called automatically during the training loop at evaluation checkpoints, and once more at the end of training. The resulting policy.pt file serves as the input to DPO training via the config.model.archive parameter.

Code Reference

Source Location

Defined on the BasicTrainer class in the trainers module (trainers.py) of the repository.

Signature

class BasicTrainer(object):
    def write_state_dict(
        self,
        step: int,
        state: Dict[str, torch.Tensor],
        metrics: Dict,
        filename: str,
        dir_name: Optional[str] = None,
    ) -> None:
        """Write a checkpoint to disk."""

    def save(
        self,
        output_dir: Optional[str] = None,
        metrics: Optional[Dict] = None,
    ) -> None:
        """Save policy, optimizer, and scheduler state to disk."""

Import

from trainers import BasicTrainer

I/O Contract

Inputs

Name Type Required Description
output_dir Optional[str] No Checkpoint directory (defaults to {run_dir}/LATEST)
metrics Optional[Dict] No Evaluation metrics to store alongside checkpoint

Outputs

Name Type Description
policy.pt File Dict with step_idx (int), state (OrderedDict of model weights), metrics (Dict)
optimizer.pt File Dict with step_idx, state (optimizer state dict), metrics
scheduler.pt File Dict with step_idx, state (scheduler state dict), metrics
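Since all three output files share the `step_idx` / `state` / `metrics` layout, a checkpoint directory can be validated generically. The helper below is a hypothetical sketch of consuming this contract, not part of the repository:

```python
import os
from typing import Dict

import torch

CHECKPOINT_FILES = ('policy.pt', 'optimizer.pt', 'scheduler.pt')


def inspect_checkpoint_dir(dir_name: str) -> Dict[str, int]:
    """Return the step_idx stored in each checkpoint file, after
    verifying that every file carries the expected three keys."""
    steps = {}
    for name in CHECKPOINT_FILES:
        ckpt = torch.load(os.path.join(dir_name, name), map_location='cpu')
        assert {'step_idx', 'state', 'metrics'} <= set(ckpt.keys())
        steps[name] = ckpt['step_idx']
    return steps
```

Because save writes all three files in one call, the returned step indices should agree; a mismatch suggests a partially written or mixed checkpoint directory.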

Usage Examples

Saving During Training

import os

# Called internally by BasicTrainer.train() at eval checkpoints
output_dir = os.path.join(run_dir, f'step-{example_counter}')
trainer.save(output_dir, mean_eval_metrics)

Saving Final Checkpoint

# Called after training completes (in worker_main)
trainer.train()
trainer.save()  # Saves to {run_dir}/LATEST/

Loading a Saved Checkpoint

import torch

# Load SFT checkpoint for DPO training
state_dict = torch.load("path/to/LATEST/policy.pt", map_location='cpu')
step = state_dict['step_idx']
metrics = state_dict['metrics']
policy.load_state_dict(state_dict['state'])
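To resume training rather than only initialize a policy, the optimizer.pt and scheduler.pt files can be restored the same way. A sketch assuming standard PyTorch objects (the helper name is hypothetical):

```python
import torch


def restore_training_state(ckpt_dir: str, policy, optimizer, scheduler) -> int:
    """Restore model, optimizer, and scheduler state from the three
    checkpoint files and return the step at which they were saved."""
    step = None
    for obj, name in ((policy, 'policy.pt'),
                      (optimizer, 'optimizer.pt'),
                      (scheduler, 'scheduler.pt')):
        ckpt = torch.load(f'{ckpt_dir}/{name}', map_location='cpu')
        obj.load_state_dict(ckpt['state'])
        step = ckpt['step_idx']  # identical across the three files
    return step
```

Restoring the optimizer and scheduler alongside the policy preserves momentum buffers and the learning-rate schedule position, which plain policy.pt loading discards.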

Related Pages

Implements Principle

Requires Environment
