Implementation:Microsoft Onnxruntime CheckpointState Load
Overview
Deserializes a FlatBuffers-encoded checkpoint file into an in-memory CheckpointState object containing model parameters, optimizer states, and user properties for training initialization or resumption.
Metadata
| Field | Value |
|---|---|
| Implementation Name | CheckpointState_Load |
| Type | API Doc |
| Language | C++ and Python |
| API | Python: CheckpointState.load_checkpoint(path_to_checkpoint); C++: LoadCheckpoint(const PathString& checkpoint_path, CheckpointState& checkpoint_state) -> Status |
| Import | from onnxruntime.training.api import CheckpointState |
| Domain | On_Device_Training, Training_Infrastructure |
| Repository | microsoft/onnxruntime |
| Source Reference | orttraining/orttraining/training_api/checkpoint.cc:L957 (C++), docs/python/on_device_training/training_api.rst:L20 (Python) |
| Last Updated | 2026-02-10 |
Description
The checkpoint loading implementation reads a FlatBuffers-encoded checkpoint file from disk, deserializes it, and populates a CheckpointState object. In C++, LoadCheckpoint reads the whole file into a byte buffer (load::FromFile) and passes those bytes to the FlatBuffers deserialization logic (load::ToCheckpointState). The Python API exposes this C++ implementation as the class method CheckpointState.load_checkpoint, which returns the populated CheckpointState.
Before deserializing, the C++ LoadCheckpoint function checks that the host is little-endian and returns an error Status if it is not. A separate function, LoadCheckpointFromBuffer, accepts an already loaded byte buffer instead of a file path.
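FlatBuffers stores multi-byte scalars in little-endian byte order, so a big-endian host that interpreted the stored bytes natively would read different values. The short Python illustration below is not ORT code: decode_uint32, host_is_little_endian, and require_little_endian are hypothetical names. It decodes the same four bytes under both byte orders and uses sys.byteorder as a rough runtime analogue of the compile-time FLATBUFFERS_LITTLEENDIAN macro.

```python
import struct
import sys
from typing import Tuple


def decode_uint32(raw: bytes) -> Tuple[int, int]:
    """Decode four bytes as an unsigned 32-bit integer under both byte orders."""
    as_little = struct.unpack("<I", raw)[0]  # the order FlatBuffers writes scalars in
    as_big = struct.unpack(">I", raw)[0]     # what a big-endian host would see reading natively
    return as_little, as_big


def host_is_little_endian() -> bool:
    """Rough runtime analogue of the compile-time FLATBUFFERS_LITTLEENDIAN macro."""
    return sys.byteorder == "little"


def require_little_endian() -> None:
    """Hypothetical guard mirroring ORT_RETURN_IF_NOT(FLATBUFFERS_LITTLEENDIAN, ...)."""
    if not host_is_little_endian():
        raise RuntimeError("ORT training checkpoint format only supports little-endian machines")


raw = struct.pack("<I", 1)  # b"\x01\x00\x00\x00"
print(decode_uint32(raw))   # (1, 16777216)
```

The value 1 written little-endian comes back as 16777216 when the bytes are read big-endian, which is the kind of silent corruption the guard prevents.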
API Signature
Python
```python
from onnxruntime.training.api import CheckpointState

state = CheckpointState.load_checkpoint(path_to_checkpoint)
```
C++
```cpp
namespace onnxruntime::training::api {

Status LoadCheckpoint(const PathString& checkpoint_path,
                      CheckpointState& checkpoint_state);

Status LoadCheckpointFromBuffer(gsl::span<const uint8_t> checkpoint_bytes,
                                CheckpointState& checkpoint_state);

}  // namespace onnxruntime::training::api
```
Key Parameters
| Parameter | Type | Description |
|---|---|---|
| checkpoint_path (C++) / path_to_checkpoint (Python) | PathString / str | File system path to the checkpoint file to load |
| checkpoint_state (C++) | CheckpointState& | Output parameter populated with the deserialized state (C++ only; the Python method returns the object instead) |
| checkpoint_bytes (C++, LoadCheckpointFromBuffer) | gsl::span<const uint8_t> | Pre-loaded byte buffer containing the checkpoint data |
I/O Contract
| Direction | Type | Description |
|---|---|---|
| Input | Checkpoint file path or byte buffer | FlatBuffers-encoded checkpoint data (a file on disk or bytes already in memory) |
| Output | CheckpointState | Contains ModuleCheckpointState (named_parameters map), OptimizerCheckpointState (group_named_optimizer_states), and PropertyBag (user-defined metadata) |
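To make the I/O Contract concrete, here is a simplified Python model of the object the loader populates. It is a sketch, not the ORT type. named_parameters, group_named_optimizer_states, and module_checkpoint_state follow the C++ names used on this page, while the field names optimizer_checkpoint_state and property_bag are assumed. The value types are also assumptions: parameter data is a flat list of floats rather than a tensor, per-group optimizer state is reduced to nested dictionaries, and property values are limited to int, float, and str on the assumption that PropertyBag holds integer, float, and string values.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Union

# Assumed property value types (integer, float, string).
PropertyValue = Union[int, float, str]


@dataclass
class Parameter:
    """Simplified parameter: flat float values plus a shape, instead of a tensor."""
    data: List[float]
    shape: Tuple[int, ...]
    requires_grad: bool = True


@dataclass
class ModuleCheckpointState:
    # parameter name -> parameter
    named_parameters: Dict[str, Parameter] = field(default_factory=dict)


@dataclass
class OptimizerCheckpointState:
    # optimizer group name -> parameter name -> optimizer state name -> values
    # (a simplification of the per-group state the C++ type holds)
    group_named_optimizer_states: Dict[str, Dict[str, Dict[str, List[float]]]] = field(
        default_factory=dict
    )


@dataclass
class CheckpointState:
    module_checkpoint_state: ModuleCheckpointState = field(default_factory=ModuleCheckpointState)
    optimizer_checkpoint_state: OptimizerCheckpointState = field(default_factory=OptimizerCheckpointState)
    property_bag: Dict[str, PropertyValue] = field(default_factory=dict)  # user-defined metadata


# Example: what a loader might populate (names and values are hypothetical)
state = CheckpointState()
state.module_checkpoint_state.named_parameters["fc1.weight"] = Parameter([0.0] * 6, (2, 3))
state.property_bag["epoch"] = 3
```

The three top-level members line up with the three parts listed in the Output row: model parameters, optimizer state, and user properties.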
Code Reference
From orttraining/orttraining/training_api/checkpoint.cc:
```cpp
Status LoadCheckpoint(const PathString& checkpoint_path, CheckpointState& checkpoint_states) {
  // Reject big-endian hosts before touching the file.
  ORT_RETURN_IF_NOT(FLATBUFFERS_LITTLEENDIAN, "ORT training checkpoint format only supports little-endian machines");
  // Read the entire checkpoint file into memory.
  InlinedVector<uint8_t> checkpoint_bytes;
  ORT_RETURN_IF_ERROR(load::FromFile(checkpoint_path, checkpoint_bytes));
  // Deserialize the FlatBuffers data; the file path is forwarded as well.
  return load::ToCheckpointState(checkpoint_bytes, checkpoint_states, checkpoint_path);
}
```
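The Code Reference follows a guard, read, delegate shape. The Python sketch below mirrors that control flow under stated assumptions: load_checkpoint, load_checkpoint_from_buffer, and the to_checkpoint_state callback are hypothetical stand-ins (the real FlatBuffers deserialization is not reproduced), errors are raised as exceptions rather than returned as a Status, and the buffer variant is assumed to apply the same endianness guard.

```python
import sys
import tempfile
from pathlib import Path
from typing import Any, Callable, Optional, Union

# Hypothetical deserializer: (checkpoint bytes, source path or None) -> state object.
Deserializer = Callable[[bytes, Optional[Path]], Any]


def _require_little_endian() -> None:
    # Analogue of ORT_RETURN_IF_NOT(FLATBUFFERS_LITTLEENDIAN, ...), raising instead of returning a Status.
    if sys.byteorder != "little":
        raise RuntimeError("ORT training checkpoint format only supports little-endian machines")


def load_checkpoint(checkpoint_path: Union[str, Path], to_checkpoint_state: Deserializer) -> Any:
    """Path variant: guard, read the whole file into memory, then deserialize."""
    _require_little_endian()
    path = Path(checkpoint_path)
    checkpoint_bytes = path.read_bytes()  # plays the role of load::FromFile
    # The path is forwarded as well, as in the C++ call to load::ToCheckpointState.
    return to_checkpoint_state(checkpoint_bytes, path)


def load_checkpoint_from_buffer(checkpoint_bytes: bytes, to_checkpoint_state: Deserializer) -> Any:
    """Buffer variant: same guard (assumed) and deserializer, but no file read and no path."""
    _require_little_endian()
    return to_checkpoint_state(bytes(checkpoint_bytes), None)


# Usage with placeholder bytes and a toy deserializer that reports what it received.
with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "checkpoint"
    ckpt.write_bytes(b"\x01\x02\x03")  # not a real checkpoint
    file_result = load_checkpoint(ckpt, lambda data, src: (len(data), src.name))
buffer_result = load_checkpoint_from_buffer(b"\x01\x02", lambda data, src: (len(data), src))
print(file_result, buffer_result)  # (3, 'checkpoint') (2, None)
```

Sharing one deserializer between the two entry points means the buffer variant differs only in skipping the file read; in the C++ code the path variant additionally forwards checkpoint_path to load::ToCheckpointState.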
Usage Example
Python
```python
from onnxruntime.training.api import CheckpointState

# Load the checkpoint file written by artifact generation
# (onnxruntime.training.artifacts.generate_artifacts)
state = CheckpointState.load_checkpoint("training_artifacts/checkpoint")

# Inspect the model parameters
for name, param in state.parameters.items():
    print(f"{name}: requires_grad={param.requires_grad}, shape={param.data.shape}")

# Read user-defined properties, e.g. to resume training
if "epoch" in state.properties:
    print(f"Resuming from epoch {state.properties['epoch']}")
```
C++
```cpp
#include <iostream>

#include "orttraining/training_api/checkpoint.h"

using namespace onnxruntime::training::api;

int main() {
  CheckpointState checkpoint_state;
  auto status = LoadCheckpoint(ORT_TSTR("training_artifacts/checkpoint"), checkpoint_state);
  if (!status.IsOK()) {
    std::cerr << "Failed to load checkpoint: " << status.ErrorMessage() << std::endl;
    return 1;
  }

  // Module state: parameter name -> std::shared_ptr<Parameter>
  const auto& named_params = checkpoint_state.module_checkpoint_state.named_parameters;
  for (const auto& [name, param] : named_params) {
    std::cout << "Parameter: " << name
              << ", requires_grad: " << std::boolalpha << param->RequiresGrad() << std::endl;
  }
  return 0;
}
```
Implements
Principle:Microsoft_Onnxruntime_Checkpoint_Loading
Related Pages
- Generate Artifacts -- Produces the checkpoint file loaded here
- SaveCheckpoint -- The inverse operation for serializing state
- Module Optimizer Scheduler Init -- Uses the loaded state to initialize training components