
Implementation:Microsoft Onnxruntime Checkpoint Save Load

From Leeroopedia


Field Value
Implementation Name Checkpoint_Save_Load
Overview Saving and loading training checkpoints for fault tolerance and training resumption in distributed training.
Type API Doc
Language C++
Domains Distributed_Training, Training_Infrastructure
Source Repository microsoft/onnxruntime
Last Updated 2026-02-10

Overview

Training checkpoints provide fault tolerance and allow training to be resumed after interruption in distributed training. Two checkpoint APIs are provided: the TrainingRunner methods for the legacy distributed pipeline, and the onnxruntime::training::api functions for the modern training API.

API

TrainingRunner Checkpoint API

// From training_runner.h (private methods called internally)
Status TrainingRunner::SaveCheckpoint(const PathString& checkpoint_path);
Status TrainingRunner::LoadCheckpoint(const PathString& checkpoint_path);

Training API Checkpoint Functions

// From checkpoint.h (namespace onnxruntime::training::api)
Status SaveCheckpoint(const CheckpointState& state,
                      const PathString& checkpoint_path,
                      const bool include_optimizer_state);

Status LoadCheckpoint(const PathString& checkpoint_path,
                      CheckpointState& checkpoint_state);

Key Parameters

Parameter Type Required Description
checkpoint_path PathString Yes File system path for saving/loading the checkpoint file
state CheckpointState Yes (Save) Training state containing module, optimizer, and user properties
checkpoint_state CheckpointState& Yes (Load) Output parameter populated with loaded training state
include_optimizer_state bool Yes (Save) Whether to include optimizer momentum and state in the checkpoint

I/O Contract

Direction Name Type Description
Input (Save) state CheckpointState Current training state (model params, optimizer state, properties)
Input (Save) checkpoint_path PathString Destination file path for the checkpoint
Input (Save) include_optimizer_state bool Whether to include optimizer momentum in the checkpoint
Input (Load) checkpoint_path PathString Source file path of an existing checkpoint
Output (Load) checkpoint_state CheckpointState Restored training state
Output Status common::Status OK on success, error on failure

Usage Examples

Saving a Checkpoint (Training API)

#include "orttraining/training_api/checkpoint.h"

using namespace onnxruntime::training::api;

CheckpointState state;
// ... populate state from training module and optimizer ...

// Save checkpoint with optimizer state
ORT_THROW_IF_ERROR(SaveCheckpoint(
    state,
    ORT_TSTR("/checkpoints/step_1000.ckpt"),
    true));  // include_optimizer_state

// Save checkpoint without optimizer state (smaller file, for inference only)
ORT_THROW_IF_ERROR(SaveCheckpoint(
    state,
    ORT_TSTR("/checkpoints/step_1000_weights_only.ckpt"),
    false));

Loading a Checkpoint (Training API)

#include "orttraining/training_api/checkpoint.h"

using namespace onnxruntime::training::api;

CheckpointState loaded_state;
ORT_THROW_IF_ERROR(LoadCheckpoint(
    ORT_TSTR("/checkpoints/step_1000.ckpt"),
    loaded_state));

// Access loaded state components
auto& module_state = loaded_state.module_checkpoint_state;
auto& optimizer_state = loaded_state.optimizer_checkpoint_state;
auto& properties = loaded_state.property_bag;

Automatic Checkpointing in TrainingRunner

// Checkpointing is configured via Parameters
TrainingRunner::Parameters params;
params.checkpoints_dir = ORT_TSTR("/shared/checkpoints/");
params.checkpoint_period = 500;       // Save every 500 weight-update steps
params.max_num_checkpoints = 3;       // Keep at most 3 checkpoint files

// On Initialize(), the runner automatically loads the latest checkpoint
// During Run(), checkpoints are saved at the configured interval
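The "load the latest checkpoint" behavior can be sketched as a small self-contained helper that selects the highest-numbered checkpoint among candidate file names. Note this is an illustration of the policy only: the checkpoint_&lt;step&gt; naming scheme and the PickLatestCheckpoint helper are assumptions for this sketch, not the runner's actual file format or implementation.

```cpp
#include <optional>
#include <regex>
#include <string>
#include <vector>

// Illustrative helper (hypothetical, not part of onnxruntime): given checkpoint
// file names assumed to look like "checkpoint_<step>", return the name with the
// highest step number, i.e. the one a resume would pick as "latest".
std::optional<std::string> PickLatestCheckpoint(const std::vector<std::string>& names) {
  const std::regex pattern(R"(checkpoint_(\d+))");
  std::optional<std::string> best;
  long best_step = -1;
  for (const auto& name : names) {
    std::smatch match;
    if (std::regex_match(name, match, pattern)) {
      const long step = std::stol(match[1].str());
      if (step > best_step) {  // keep the highest step seen so far
        best_step = step;
        best = name;
      }
    }
  }
  return best;  // empty if no name matched the assumed scheme
}
```

Under this sketch, PickLatestCheckpoint({"checkpoint_500", "checkpoint_1000"}) selects "checkpoint_1000"; files that do not match the assumed naming scheme are ignored.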

CheckpointState Structure

struct CheckpointState {
    ModuleCheckpointState module_checkpoint_state;      // Model parameters
    OptimizerCheckpointState optimizer_checkpoint_state; // Optimizer momentum, etc.
    PropertyBag property_bag;                           // User-defined properties
    bool has_external_data = false;                     // External data file flag
};

Key Details

  • Checkpoints are stored in flatbuffer format (schema: ort_training_checkpoint.fbs).
  • For large models exceeding 1.8 GB, external data files are used alongside the main checkpoint file.
  • The CheckpointRegistry in TrainingRunner manages checkpoint rotation, automatically removing the oldest when max_num_checkpoints is reached.
  • LoadCheckpoint() must be called after session_.Initialize() because the session graph must be finalized before parameter values can be loaded.
  • The TrainingRunner::Initialize() method automatically attempts to load the latest checkpoint from checkpoints_dir if no specific checkpoint_to_load_path is provided.
  • An additional LoadCheckpointFromBuffer() function supports loading from in-memory bytes (gsl::span<const uint8_t>).
  • SaveCheckpoint also has an overload that accepts TensorProto spans for saving ONNX initializers directly.
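The rotation rule described above (drop the oldest checkpoint once max_num_checkpoints is reached) can be illustrated with a minimal, self-contained policy class. This is a sketch of the described behavior only, not the actual CheckpointRegistry implementation; the RotationPolicy name and its Add method are invented for this example.

```cpp
#include <cstddef>
#include <deque>
#include <string>
#include <vector>

// Illustrative sketch (hypothetical) of checkpoint rotation: track saved
// checkpoint paths in save order and report which ones should be deleted
// once the configured maximum is exceeded.
class RotationPolicy {
 public:
  explicit RotationPolicy(std::size_t max_num_checkpoints)
      : max_(max_num_checkpoints) {}

  // Register a newly saved checkpoint path; returns the paths that have
  // fallen out of the retention window (oldest first) and should be removed.
  std::vector<std::string> Add(const std::string& path) {
    paths_.push_back(path);
    std::vector<std::string> evicted;
    while (paths_.size() > max_) {
      evicted.push_back(paths_.front());  // oldest checkpoint is dropped first
      paths_.pop_front();
    }
    return evicted;
  }

 private:
  std::size_t max_;
  std::deque<std::string> paths_;  // retained checkpoints, oldest at front
};
```

With max_num_checkpoints = 3, as in the TrainingRunner example above, the fourth saved checkpoint evicts the first, so at most three checkpoint files remain on disk.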
