Principle:Microsoft Onnxruntime Checkpoint Loading
Overview
Restoring training state from a serialized checkpoint file to resume or initialize training.
Metadata
| Field | Value |
|---|---|
| Principle Name | Checkpoint_Loading |
| Category | API Doc |
| Domain | On_Device_Training, Training_Infrastructure |
| Repository | microsoft/onnxruntime |
| Source Reference | orttraining/orttraining/training_api/checkpoint.cc:L957 (C++), docs/python/on_device_training/training_api.rst:L20 (Python) |
| Last Updated | 2026-02-10 |
Description
Checkpoint loading deserializes a FlatBuffers-encoded file containing model parameters, optimizer states, and user properties into an in-memory CheckpointState object. This enables training resumption from a saved state or initialization from generated artifacts.
The CheckpointState object is composed of three main components:
- ModuleCheckpointState -- Contains `named_parameters`, a map from parameter names to `Parameter` objects. Each `Parameter` holds the parameter tensor data and a flag indicating whether it requires gradient computation.
- OptimizerCheckpointState -- Contains `group_named_optimizer_states`, a map of optimizer group states. For AdamW, each parameter's optimizer state includes first-order and second-order momentum tensors (`momentum0`, `momentum1`), along with the current step count and learning rate.
- PropertyBag -- A key-value store for user-defined metadata such as the current epoch number, best validation score, or custom training metrics.
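This three-part composition can be sketched as a plain-Python mock-up. The class and field names below mirror the ORT C++ types described above for illustration only; they are not the real bindings, and the list-based tensor stand-ins are an assumption of this sketch:

```python
from dataclasses import dataclass, field

@dataclass
class Parameter:
    data: list           # stand-in for the parameter tensor payload
    requires_grad: bool  # whether gradients are computed for this parameter

@dataclass
class ModuleCheckpointState:
    named_parameters: dict = field(default_factory=dict)  # name -> Parameter

@dataclass
class OptimizerCheckpointState:
    # group name -> {param name -> {"momentum0": ..., "momentum1": ...}}
    group_named_optimizer_states: dict = field(default_factory=dict)

@dataclass
class CheckpointState:
    module_checkpoint_state: ModuleCheckpointState
    optimizer_checkpoint_state: OptimizerCheckpointState
    property_bag: dict  # user-defined metadata, e.g. {"epoch": 3}

# Assemble a tiny state with one parameter, its AdamW momenta, and a property.
state = CheckpointState(
    ModuleCheckpointState({"fc.weight": Parameter([0.1, 0.2], True)}),
    OptimizerCheckpointState(
        {"group0": {"fc.weight": {"momentum0": [0.0, 0.0],
                                  "momentum1": [0.0, 0.0]}}}),
    {"epoch": 3},
)
print(state.property_bag["epoch"])  # 3
```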
The checkpoint file format uses FlatBuffers serialization, which enables efficient deserialization without parsing overhead. When checkpoint data exceeds size thresholds, an external data file is created alongside the main checkpoint file.
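The external-data behavior can be illustrated with a toy sizing rule. The per-tensor threshold, its default value, and the split logic below are assumptions for illustration; ORT's actual threshold and external-data layout differ:

```python
# Illustrative sketch: decide which tensor payloads stay inline in the
# checkpoint and which are routed to a side file. The threshold value and
# per-tensor decision are assumptions, not ORT's actual serialization rule.
def plan_serialization(tensors: dict, threshold_bytes: int = 1024):
    """Split raw tensor payloads into inline and external groups by size."""
    inline, external = {}, {}
    for name, payload in tensors.items():
        target = external if len(payload) > threshold_bytes else inline
        target[name] = payload
    return inline, external

inline, external = plan_serialization(
    {"small.bias": b"\x00" * 64, "big.weight": b"\x00" * 4096}
)
print(sorted(external))  # ['big.weight']
```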
Theoretical Basis
Checkpointing is essential for fault tolerance and incremental training. The checkpoint contains the complete training state: parameter values (ModuleCheckpointState), optimizer momentum states (OptimizerCheckpointState), and metadata (PropertyBag).
Key theoretical aspects include:
- Fault Tolerance -- By periodically saving checkpoints, training can be resumed from the last saved state after hardware failures or interruptions, avoiding the need to restart from scratch.
- Transfer Learning -- A checkpoint from one training run can serve as the initialization point for a different task, enabling knowledge transfer between models.
- Nominal Checkpoints -- ONNX Runtime supports nominal checkpoints that contain parameter shapes and types but not full tensor data. This allows creating lightweight checkpoints for initialization purposes, with actual parameter values loaded later via `CopyBufferToParameters`.
- Endianness Constraint -- The FlatBuffers format used by ORT training checkpoints only supports little-endian machines. This is enforced at both save and load time.
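The fault-tolerance aspect can be sketched as a generic save-every-N training loop that resumes from the last saved state. Here `pickle` stands in for ORT's FlatBuffers checkpoint format, and the state dictionary is a simplified stand-in; this is a sketch of the pattern, not ORT's API:

```python
import os
import pickle
import tempfile

def train(workdir: str, total_steps: int, save_every: int = 5):
    """Run (or resume) a toy training loop, checkpointing every few steps."""
    ckpt = os.path.join(workdir, "checkpoint.pkl")
    if os.path.exists(ckpt):
        # Resume from the last saved state after an interruption.
        with open(ckpt, "rb") as f:
            state = pickle.load(f)
    else:
        state = {"step": 0, "weights": [0.0]}
    start_step = state["step"]
    while state["step"] < total_steps:
        state["weights"][0] -= 0.1  # stand-in for an optimizer update
        state["step"] += 1
        if state["step"] % save_every == 0:
            with open(ckpt, "wb") as f:  # periodic save for fault tolerance
                pickle.dump(state, f)
    return start_step, state

with tempfile.TemporaryDirectory() as d:
    # First run reaches step 7, but the last checkpoint was written at step 5.
    _, interrupted = train(d, total_steps=7)
    # The second run resumes from step 5, not step 7: only saved state survives.
    resumed_from, final = train(d, total_steps=10)
    print(resumed_from, final["step"])  # 5 10
```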
Usage
Loading a checkpoint is typically the first step when setting up a training session:
```python
from onnxruntime.training.api import CheckpointState

# Load from a checkpoint file
state = CheckpointState.load_checkpoint("output_artifacts/checkpoint")

# Access model parameters
for name, param in state.parameters.items():
    print(f"Parameter: {name}, requires_grad: {param.requires_grad}")

# Access user-defined properties
if "epoch" in state.properties:
    current_epoch = state.properties["epoch"]
```
In C++, the equivalent operation is:
```cpp
CheckpointState checkpoint_state;
Status status = LoadCheckpoint(checkpoint_path, checkpoint_state);
// checkpoint_state now contains module state, optimizer state, and properties
```
Implemented By
Implementation:Microsoft_Onnxruntime_CheckpointState_Load
Related Pages
- Training Artifact Generation -- Produces the initial checkpoint file
- Training Component Assembly -- Uses the loaded checkpoint state to create training components
- Checkpoint Saving -- The inverse operation of checkpoint loading