
Implementation:Microsoft Onnxruntime CheckpointState Load

From Leeroopedia


Overview

Deserializes a FlatBuffers-encoded checkpoint file into an in-memory CheckpointState object that holds model parameters, optimizer states, and user properties, which is then used to initialize or resume training.

Metadata

Field | Value
Implementation Name | CheckpointState_Load
Type | API Doc
Language | C++ and Python
API | Python: CheckpointState.load_checkpoint(path_to_checkpoint); C++: LoadCheckpoint(const PathString& checkpoint_path, CheckpointState& checkpoint_state) -> Status
Import | from onnxruntime.training.api import CheckpointState
Domain | On_Device_Training, Training_Infrastructure
Repository | microsoft/onnxruntime
Source Reference | orttraining/orttraining/training_api/checkpoint.cc:L957 (C++); docs/python/on_device_training/training_api.rst:L20 (Python)
Last Updated | 2026-02-10

Description

The checkpoint loading implementation reads a FlatBuffers-encoded file from disk, deserializes it, and populates a CheckpointState object. In C++, LoadCheckpoint reads the entire file into a byte buffer (load::FromFile) and then passes that buffer to the FlatBuffers deserialization routine (load::ToCheckpointState). The Python API wraps this C++ implementation and exposes it as the class method CheckpointState.load_checkpoint.

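Before deserializing anything, the C++ LoadCheckpoint function checks FLATBUFFERS_LITTLEENDIAN and returns an error Status on big-endian platforms, because the ORT training checkpoint format only supports little-endian machines. A separate function, LoadCheckpointFromBuffer, accepts a pre-loaded byte buffer instead of a file path.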
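The control flow described above (byte-order guard, read the whole file into memory, hand the bytes to a deserializer, plus a second entry point for bytes already in memory) can be sketched in Python. This is a minimal illustration, not ONNX Runtime code: `load_checkpoint_from_path`, `load_checkpoint_from_buffer`, and the `deserialize` callback are hypothetical names, and `sys.byteorder` is a runtime stand-in for the compile-time `FLATBUFFERS_LITTLEENDIAN` macro.

```python
import sys
from pathlib import Path
from typing import Callable, TypeVar, Union

T = TypeVar("T")

LITTLE_ENDIAN_ERROR = "ORT training checkpoint format only supports little-endian machines"


def _ensure_little_endian() -> None:
    # Runtime stand-in for ORT's compile-time FLATBUFFERS_LITTLEENDIAN check.
    if sys.byteorder != "little":
        raise RuntimeError(LITTLE_ENDIAN_ERROR)


def load_checkpoint_from_buffer(checkpoint_bytes: bytes,
                                deserialize: Callable[[bytes], T]) -> T:
    """Hypothetical analogue of LoadCheckpointFromBuffer: bytes are already in memory."""
    _ensure_little_endian()
    return deserialize(checkpoint_bytes)


def load_checkpoint_from_path(checkpoint_path: Union[str, Path],
                              deserialize: Callable[[bytes], T]) -> T:
    """Hypothetical analogue of LoadCheckpoint: guard, read the whole file, deserialize."""
    _ensure_little_endian()
    checkpoint_bytes = Path(checkpoint_path).read_bytes()  # plays the role of load::FromFile
    return deserialize(checkpoint_bytes)  # plays the role of load::ToCheckpointState
```

Keeping deserialization behind a bytes-based step is what allows a caller that already holds the checkpoint in memory to skip the file system, which is the role LoadCheckpointFromBuffer plays in the C++ API.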

API Signature

Python

from onnxruntime.training.api import CheckpointState

state = CheckpointState.load_checkpoint(path_to_checkpoint)

C++

namespace onnxruntime::training::api {

Status LoadCheckpoint(const PathString& checkpoint_path,
                      CheckpointState& checkpoint_state);

Status LoadCheckpointFromBuffer(gsl::span<const uint8_t> checkpoint_bytes,
                                CheckpointState& checkpoint_state);

}  // namespace onnxruntime::training::api

Key Parameters

Parameter | Type | Description
checkpoint_path (C++) / path_to_checkpoint (Python) | PathString / str | File-system path of the checkpoint file to load
checkpoint_state (C++ only) | CheckpointState& | Output parameter populated with the deserialized state; the Python API returns a CheckpointState object instead
checkpoint_bytes (LoadCheckpointFromBuffer only) | gsl::span<const uint8_t> | Pre-loaded byte buffer containing the checkpoint data

I/O Contract

Direction | Type | Description
Input | Checkpoint file path or byte buffer | FlatBuffers-encoded checkpoint data
Output | CheckpointState | ModuleCheckpointState (named_parameters map), OptimizerCheckpointState (group_named_optimizer_states), and PropertyBag (user-defined properties)
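To make the shape of the output concrete, here is a simplified, hypothetical Python mirror of the C++ CheckpointState layout. It is not part of the onnxruntime.training API. Only `module_checkpoint_state` and `named_parameters` appear in the C++ example on this page; the member names `optimizer_checkpoint_state` and `property_bag`, the `Parameter` fields, the value types allowed in the property bag, and the inner optimizer-state layout are assumptions made for illustration.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Union

# Assumption: property values are integers, floats, or strings.
PropertyValue = Union[int, float, str]


@dataclass
class Parameter:
    """One named model parameter; values kept as a flat list for illustration."""
    data: List[float]
    requires_grad: bool = True


@dataclass
class ModuleCheckpointState:
    # parameter name -> Parameter
    named_parameters: Dict[str, Parameter] = field(default_factory=dict)


@dataclass
class OptimizerCheckpointState:
    # group name -> per-group optimizer state (inner layout is an assumption)
    group_named_optimizer_states: Dict[str, Dict[str, object]] = field(default_factory=dict)


@dataclass
class CheckpointState:
    module_checkpoint_state: ModuleCheckpointState = field(default_factory=ModuleCheckpointState)
    optimizer_checkpoint_state: OptimizerCheckpointState = field(
        default_factory=OptimizerCheckpointState)
    property_bag: Dict[str, PropertyValue] = field(default_factory=dict)


def trainable_parameter_names(state: CheckpointState) -> List[str]:
    """Names of parameters that would receive gradients when training resumes."""
    return sorted(
        name
        for name, param in state.module_checkpoint_state.named_parameters.items()
        if param.requires_grad
    )
```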

Code Reference

From orttraining/orttraining/training_api/checkpoint.cc:

Status LoadCheckpoint(const PathString& checkpoint_path, CheckpointState& checkpoint_states) {
  ORT_RETURN_IF_NOT(FLATBUFFERS_LITTLEENDIAN, "ORT training checkpoint format only supports little-endian machines");

  InlinedVector<uint8_t> checkpoint_bytes;
  ORT_RETURN_IF_ERROR(load::FromFile(checkpoint_path, checkpoint_bytes));
  return load::ToCheckpointState(checkpoint_bytes, checkpoint_states, checkpoint_path);
}

Usage Example

Python

from onnxruntime.training.api import CheckpointState

# Load a checkpoint file produced during training artifact generation
state = CheckpointState.load_checkpoint("training_artifacts/checkpoint")

# Access parameters
for name, param in state.parameters.items():
    print(f"{name}: requires_grad={param.requires_grad}, shape={param.data.shape}")

# Access user properties
if "epoch" in state.properties:
    print(f"Resuming from epoch {state.properties['epoch']}")

C++

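#include <iostream>

#include "orttraining/training_api/checkpoint.h"

using namespace onnxruntime::training::api;

int main() {
  CheckpointState checkpoint_state;
  auto status = LoadCheckpoint(ORT_TSTR("training_artifacts/checkpoint"), checkpoint_state);
  if (!status.IsOK()) {
    // Loading failed, e.g. missing file, malformed checkpoint, or big-endian host
    std::cerr << "Failed to load checkpoint: " << status.ErrorMessage() << std::endl;
    return 1;
  }

  // Access module state: map of parameter name -> Parameter
  const auto& named_params = checkpoint_state.module_checkpoint_state.named_parameters;
  for (const auto& [name, param] : named_params) {
    std::cout << "Parameter: " << name << ", requires_grad: " << std::boolalpha
              << param->RequiresGrad() << std::endl;
  }
  return 0;
}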

Implements

Principle:Microsoft_Onnxruntime_Checkpoint_Loading
