Implementation: Microsoft ONNX Runtime Checkpoint Save/Load
| Field | Value |
|---|---|
| Implementation Name | Checkpoint_Save_Load |
| Overview | Saving and loading training checkpoints for fault tolerance and training resumption in distributed training. |
| Type | API Doc |
| Language | C++ |
| Domains | Distributed_Training, Training_Infrastructure |
| Source Repository | microsoft/onnxruntime |
| Last Updated | 2026-02-10 |
Overview
Saving and loading training checkpoints for fault tolerance and training resumption in distributed training. Two checkpoint APIs are provided: the TrainingRunner methods for the legacy distributed pipeline, and the onnxruntime::training::api functions for the modern training API.
API
TrainingRunner Checkpoint API
```cpp
// From training_runner.h (private methods called internally)
Status TrainingRunner::SaveCheckpoint(const PathString& checkpoint_path);
Status TrainingRunner::LoadCheckpoint(const PathString& checkpoint_path);
```
Training API Checkpoint Functions
```cpp
// From checkpoint.h (namespace onnxruntime::training::api)
Status SaveCheckpoint(const CheckpointState& state,
                      const PathString& checkpoint_path,
                      const bool include_optimizer_state);
Status LoadCheckpoint(const PathString& checkpoint_path,
                      CheckpointState& checkpoint_state);
```
Source Code Reference
- Repository: microsoft/onnxruntime
- TrainingRunner declarations: orttraining/orttraining/models/runner/training_runner.h:L252-253
- Training API SaveCheckpoint: orttraining/orttraining/training_api/checkpoint.h:L50-51
- Training API LoadCheckpoint: orttraining/orttraining/training_api/checkpoint.h:L82-83
- Implementation: orttraining/orttraining/training_api/checkpoint.cc
Key Parameters
| Parameter | Type | Required | Description |
|---|---|---|---|
| checkpoint_path | PathString | Yes | File system path for saving/loading the checkpoint file |
| state | CheckpointState | Yes (Save) | Training state containing module, optimizer, and user properties |
| checkpoint_state | CheckpointState& | Yes (Load) | Output parameter populated with loaded training state |
| include_optimizer_state | bool | Yes (Save) | Whether to include optimizer momentum and state in the checkpoint |
I/O Contract
| Direction | Name | Type | Description |
|---|---|---|---|
| Input (Save) | state | CheckpointState | Current training state (model params, optimizer state, properties) |
| Input (Save) | checkpoint_path | PathString | Destination file path for the checkpoint |
| Input (Save) | include_optimizer_state | bool | Whether to include optimizer momentum in the checkpoint |
| Input (Load) | checkpoint_path | PathString | Source file path of an existing checkpoint |
| Output (Load) | checkpoint_state | CheckpointState | Restored training state |
| Output | Status | common::Status | OK on success, error on failure |
Usage Examples
Saving a Checkpoint (Training API)
```cpp
#include "orttraining/training_api/checkpoint.h"

using namespace onnxruntime::training::api;

CheckpointState state;
// ... populate state from training module and optimizer ...

// Save a checkpoint with optimizer state
ORT_THROW_IF_ERROR(SaveCheckpoint(
    state,
    ORT_TSTR("/checkpoints/step_1000.ckpt"),
    /*include_optimizer_state=*/true));

// Save a checkpoint without optimizer state (smaller file, for inference only)
ORT_THROW_IF_ERROR(SaveCheckpoint(
    state,
    ORT_TSTR("/checkpoints/step_1000_weights_only.ckpt"),
    /*include_optimizer_state=*/false));
```
Loading a Checkpoint (Training API)
```cpp
#include "orttraining/training_api/checkpoint.h"

using namespace onnxruntime::training::api;

CheckpointState loaded_state;
ORT_THROW_IF_ERROR(LoadCheckpoint(
    ORT_TSTR("/checkpoints/step_1000.ckpt"),
    loaded_state));

// Access the loaded state components
auto& module_state = loaded_state.module_checkpoint_state;
auto& optimizer_state = loaded_state.optimizer_checkpoint_state;
auto& properties = loaded_state.property_bag;
```
Automatic Checkpointing in TrainingRunner
```cpp
// Checkpointing is configured via TrainingRunner::Parameters
TrainingRunner::Parameters params;
params.checkpoints_dir = ORT_TSTR("/shared/checkpoints/");
params.checkpoint_period = 500;   // save every 500 weight-update steps
params.max_num_checkpoints = 3;   // keep at most 3 checkpoint files

// On Initialize(), the runner automatically loads the latest checkpoint.
// During Run(), checkpoints are saved at the configured interval.
```
CheckpointState Structure
```cpp
struct CheckpointState {
  ModuleCheckpointState module_checkpoint_state;        // model parameters
  OptimizerCheckpointState optimizer_checkpoint_state;  // optimizer momentum, etc.
  PropertyBag property_bag;                             // user-defined properties
  bool has_external_data = false;                       // external data file flag
};
```
Key Details
- Checkpoints are stored in FlatBuffers format (schema: ort_training_checkpoint.fbs).
- For large models exceeding 1.8 GB, external data files are used alongside the main checkpoint file.
- The CheckpointRegistry in TrainingRunner manages checkpoint rotation, automatically removing the oldest when max_num_checkpoints is reached.
- LoadCheckpoint() must be called after session_.Initialize() because the session graph must be finalized before parameter values can be loaded.
- The TrainingRunner::Initialize() method automatically attempts to load the latest checkpoint from checkpoints_dir if no specific checkpoint_to_load_path is provided.
- An additional LoadCheckpointFromBuffer() function supports loading from in-memory bytes (gsl::span<const uint8_t>).
- SaveCheckpoint also has an overload that accepts TensorProto spans for saving ONNX initializers directly.