Principle:Microsoft Onnxruntime Distributed Checkpoint Management
| Field | Value |
|---|---|
| Principle Name | Distributed_Checkpoint_Management |
| Overview | Periodic saving and loading of training state for fault tolerance in long-running distributed training. |
| Category | API Doc |
| Domains | Distributed_Training, Training_Infrastructure |
| Source Repository | microsoft/onnxruntime |
| Last Updated | 2026-02-10 |
Overview
Distributed checkpoint management periodically saves the training state of long-running distributed jobs at a configurable interval. After a failure, training resumes from the latest checkpoint instead of starting over.
Description
Distributed checkpoint management saves the complete training state (model parameters, optimizer momentum, training step, learning rate) at configurable intervals. On failure, training can resume from the latest checkpoint, minimizing lost computation.
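The sketch below shows what "complete training state" covers, using a toy model whose parameters and momentum are plain lists of floats. The names `TrainingState`, `save_state`, and `load_state` are hypothetical, and so is the JSON format; ONNX Runtime itself writes flatbuffer files. The sketch writes to a temporary file and then renames it over the target, so a crash during saving leaves the previous checkpoint intact. That write-then-rename step is an assumed design choice for illustration, not documented ONNX Runtime behavior.

```python
import json
import os
import tempfile
from dataclasses import asdict, dataclass
from typing import List


@dataclass
class TrainingState:
    """Hypothetical container for the state listed above."""
    params: List[float]      # model parameters
    momentum: List[float]    # optimizer momentum buffers
    step: int                # training step counter
    learning_rate: float


def save_state(state: TrainingState, path: str) -> None:
    """Write to a temp file in the target directory, then rename over `path`."""
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(asdict(state), f)
            f.flush()
            os.fsync(f.fileno())  # push the bytes to disk before the rename
        # os.replace is atomic on POSIX when both paths share a filesystem.
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def load_state(path: str) -> TrainingState:
    with open(path) as f:
        return TrainingState(**json.load(f))


# Usage: save at step 200, then "restart" and resume from the file.
with tempfile.TemporaryDirectory() as run_dir:
    ckpt = os.path.join(run_dir, "state.json")
    save_state(TrainingState([0.5, -1.0], [0.0, 0.25], step=200, learning_rate=1e-3), ckpt)
    resumed = load_state(ckpt)
    print(resumed.step, resumed.learning_rate)  # 200 0.001
```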
ONNX Runtime exposes checkpointing through two APIs:
TrainingRunner Checkpoint (Legacy API)
The TrainingRunner provides checkpoint support for the distributed training pipeline:
- SaveCheckpoint(checkpoint_path): Saves the current training state including model weights, optimizer state, training step counter, and learning rate.
- LoadCheckpoint(checkpoint_path): Restores training state from a previously saved checkpoint.
- CheckpointRegistry: Manages checkpoint files in a directory, tracking the latest checkpoint and enforcing max_num_checkpoints retention.
The checkpoint period is configured via Parameters::checkpoint_period (in weight-update steps). When set to 0, no checkpoints are saved.
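The save schedule and the retention policy described above can be modeled in a few lines. This is a simplified Python sketch, not the `TrainingRunner` or `CheckpointRegistry` code. `should_save` assumes a checkpoint is taken whenever the weight-update step is a positive multiple of `checkpoint_period`. The hypothetical `RetentionRegistry` only tracks paths in memory and reports which old checkpoints to delete once `max_num_checkpoints` is exceeded.

```python
from collections import deque
from typing import Deque, List, Optional, Tuple


def should_save(weight_update_step: int, checkpoint_period: int) -> bool:
    # A period of 0 disables checkpointing entirely.
    if checkpoint_period == 0:
        return False
    return weight_update_step > 0 and weight_update_step % checkpoint_period == 0


class RetentionRegistry:
    """Hypothetical stand-in for checkpoint bookkeeping with max_num_checkpoints."""

    def __init__(self, max_num_checkpoints: int):
        if max_num_checkpoints < 1:
            raise ValueError("max_num_checkpoints must be at least 1")
        self.max_num_checkpoints = max_num_checkpoints
        self._entries: Deque[Tuple[int, str]] = deque()  # (step, path), oldest first

    def add(self, step: int, path: str) -> List[str]:
        """Record a new checkpoint and return the paths that should now be deleted."""
        self._entries.append((step, path))
        evicted = []
        while len(self._entries) > self.max_num_checkpoints:
            evicted.append(self._entries.popleft()[1])
        return evicted

    def latest(self) -> Optional[str]:
        return self._entries[-1][1] if self._entries else None


# Usage: period 200, keep at most 2 checkpoints over 700 weight-update steps.
registry = RetentionRegistry(max_num_checkpoints=2)
for step in range(1, 701):
    if should_save(step, checkpoint_period=200):
        for stale in registry.add(step, f"checkpoint_{step}"):
            print("delete", stale)  # delete checkpoint_200
print("latest:", registry.latest())  # latest: checkpoint_600
```

A bounded deque keeps eviction O(1) per save, and the oldest checkpoint is always the first one dropped.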
Training API Checkpoint (Modern API)
The onnxruntime::training::api namespace provides a more general checkpoint API:
- SaveCheckpoint(state, path, include_optimizer_state): Saves a CheckpointState object containing module state, optimizer state, and user-defined properties to a flatbuffer file.
- LoadCheckpoint(path, checkpoint_state): Loads a CheckpointState from a checkpoint file.
The CheckpointState struct contains:
- module_checkpoint_state: Model trainable and non-trainable parameters.
- optimizer_checkpoint_state: Optimizer state (learning rate, step, momentum values).
- property_bag: User-defined properties (e.g., epoch number, best score).
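The following is a hypothetical Python mirror of `CheckpointState` that makes its structure concrete. Only the three field names come from the text above. The class name, the flat `dict` layouts, and the `to_saved_form` helper are illustrative assumptions. The helper models the effect of `include_optimizer_state`: when it is false, only the module state and the property bag are kept.

```python
import copy
from dataclasses import dataclass, field
from typing import Dict, List, Union

PropertyValue = Union[int, float, str]


@dataclass
class CheckpointStateSketch:
    # Parameter name -> values (trainable and non-trainable).
    module_checkpoint_state: Dict[str, List[float]] = field(default_factory=dict)
    # Optimizer state such as learning rate, step, and momentum buffers.
    optimizer_checkpoint_state: Dict[str, object] = field(default_factory=dict)
    # User-defined properties, e.g. epoch number or best score.
    property_bag: Dict[str, PropertyValue] = field(default_factory=dict)


def to_saved_form(state: CheckpointStateSketch,
                  include_optimizer_state: bool) -> CheckpointStateSketch:
    """Return a deep copy of what would be persisted (hypothetical helper)."""
    saved = copy.deepcopy(state)
    if not include_optimizer_state:
        saved.optimizer_checkpoint_state = {}
    return saved


state = CheckpointStateSketch(
    module_checkpoint_state={"fc.weight": [0.1, 0.2], "fc.bias": [0.0]},
    optimizer_checkpoint_state={"learning_rate": 1e-3, "step": 500,
                                "momentum": {"fc.weight": [0.01, 0.02]}},
    property_bag={"epoch": 3, "best_score": 0.91},
)
# A weights-and-properties-only checkpoint, e.g. for inference export.
weights_only = to_saved_form(state, include_optimizer_state=False)
```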
Checkpoint File Format
Checkpoints are stored as flatbuffer files (schema defined in ort_training_checkpoint.fbs). For large models, external data files may be used when the checkpoint exceeds a configurable threshold (default: 1.8 GB).
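A back-of-the-envelope estimate shows when the external-data threshold comes into play. The sketch makes three assumptions: storage is fp32 (4 bytes per value), "1.8 GB" means 1.8 × 10^9 bytes, and flatbuffer metadata is negligible. `estimate_checkpoint_bytes` and its `optimizer_copies` parameter (for example 2 for Adam's two moment buffers) are hypothetical.

```python
BYTES_PER_FP32 = 4
EXTERNAL_DATA_THRESHOLD = int(1.8e9)  # assumption: "1.8 GB" read as decimal bytes


def estimate_checkpoint_bytes(num_params: int, optimizer_copies: int = 0) -> int:
    """Parameters plus per-parameter optimizer buffers, all assumed to be fp32."""
    return num_params * BYTES_PER_FP32 * (1 + optimizer_copies)


def needs_external_data(num_bytes: int, threshold: int = EXTERNAL_DATA_THRESHOLD) -> bool:
    return num_bytes > threshold


for n_params, copies in [(100_000_000, 2), (500_000_000, 0)]:
    size = estimate_checkpoint_bytes(n_params, copies)
    print(f"{n_params:,} params, {copies} optimizer copies: "
          f"{size / 1e9:.1f} GB, external data: {needs_external_data(size)}")
# 100,000,000 params, 2 optimizer copies: 1.2 GB, external data: False
# 500,000,000 params, 0 optimizer copies: 2.0 GB, external data: True
```

Including optimizer state can multiply checkpoint size, so a model that fits inline for weights-only saves may cross the threshold once momentum buffers are included.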
Theoretical Basis
Fault tolerance in distributed training is critical as the probability of any single node failing increases with scale. Periodic checkpointing bounds the maximum lost work to one checkpoint interval.
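The bound on lost work can be checked with simple arithmetic. Assume a checkpoint is written when every positive multiple of the period completes, and that a failure before the first checkpoint restarts from step 0. A failure after step s then rolls training back to the largest multiple of the period at or below s. The helper names are hypothetical.

```python
def last_checkpoint_step(failure_step: int, checkpoint_period: int) -> int:
    """Step of the most recent checkpoint at or before failure_step (0 = start of training)."""
    return (failure_step // checkpoint_period) * checkpoint_period


def steps_lost(failure_step: int, checkpoint_period: int) -> int:
    """Completed steps that must be recomputed after resuming."""
    return failure_step - last_checkpoint_step(failure_step, checkpoint_period)


period = 500
for failure_step in (499, 500, 1234):
    print(failure_step, last_checkpoint_step(failure_step, period), steps_lost(failure_step, period))
# 499 0 499
# 500 500 0
# 1234 1000 234
```

The worst case is period - 1 completed steps (plus any partially finished step), which is the "one checkpoint interval" bound stated above.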
Key theoretical considerations:
- Failure probability: If each of N nodes fails independently with probability p per hour, the probability that at least one node fails in a given hour is 1 - (1-p)^N, which approaches 1 rapidly as N grows.
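- Checkpoint frequency trade-off: More frequent checkpoints reduce the work lost to a failure but add I/O overhead on every save. The optimal interval minimizes the sum of these two costs. It therefore grows with the time needed to write a checkpoint and shrinks as failures become more frequent.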
- Consistent state: All processes must checkpoint at the same training step to ensure consistent model state across the distributed system.
- Storage management: The max_num_checkpoints parameter prevents disk space exhaustion during long training runs by removing old checkpoints.
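The first two considerations can be made quantitative. The sketch below evaluates 1 - (1-p)^N, assuming node failures are independent. It also evaluates Young's first-order approximation for the checkpoint interval, T_opt ≈ sqrt(2 · C · M), where C is the time to write one checkpoint and M is the job's mean time between failures. Young's formula is a classic heuristic from the checkpointing literature, not something ONNX Runtime computes. The inputs are illustrative values, not measurements.

```python
import math


def prob_any_failure(p: float, n_nodes: int) -> float:
    """P(at least one of n_nodes fails) when each fails independently with probability p."""
    return 1.0 - (1.0 - p) ** n_nodes


def young_interval(checkpoint_cost: float, mtbf: float) -> float:
    """Young's approximation of the optimal checkpoint interval (same time unit as inputs)."""
    return math.sqrt(2.0 * checkpoint_cost * mtbf)


# Per-node failure probability of 0.1% per hour at increasing scale.
for n in (8, 64, 1000):
    print(n, round(prob_any_failure(0.001, n), 3))
# 8 0.008
# 64 0.062
# 1000 0.632

# Example: 2-minute checkpoint writes and a 600-minute (10-hour) job MTBF.
print(round(young_interval(2.0, 600.0), 1))  # 49.0 (minutes between checkpoints)
```

To turn a time interval into `checkpoint_period`, divide it by the average wall-clock time of one weight-update step.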
Usage
Checkpoint management is configured during training setup and operates automatically during the training loop:
- Set checkpoints_dir to a shared filesystem accessible by all ranks.
- Set checkpoint_period to the desired interval (in weight-update steps).
- Set max_num_checkpoints to limit storage usage.
- Optionally set checkpoint_to_load_path to resume from a specific checkpoint.
- The training loop automatically saves checkpoints at the configured interval.
- On restart, Initialize() automatically loads the latest checkpoint from checkpoints_dir.
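The restart behavior in the last two items can be sketched as a path-resolution function. Consistent with the list above, the sketch assumes that an explicitly set `checkpoint_to_load_path` takes precedence and that otherwise the newest checkpoint in `checkpoints_dir` is used. The `checkpoint_<step>` file-naming scheme and the function name are hypothetical. The step is parsed as an integer so that `checkpoint_1000` ranks above `checkpoint_200`; plain string sorting would get this wrong.

```python
import os
import re
import tempfile
from typing import Optional

_CKPT_NAME = re.compile(r"^checkpoint_(\d+)$")  # hypothetical naming scheme


def resolve_checkpoint_to_load(checkpoint_to_load_path: str,
                               checkpoints_dir: str) -> Optional[str]:
    """Return the checkpoint to resume from, or None to start from scratch."""
    if checkpoint_to_load_path:
        return checkpoint_to_load_path  # explicit choice wins
    if not checkpoints_dir or not os.path.isdir(checkpoints_dir):
        return None
    candidates = []
    for name in os.listdir(checkpoints_dir):
        match = _CKPT_NAME.match(name)
        if match:
            candidates.append((int(match.group(1)), name))
    if not candidates:
        return None
    _, newest = max(candidates)  # highest step number
    return os.path.join(checkpoints_dir, newest)


with tempfile.TemporaryDirectory() as ckpt_dir:
    for step in (200, 1000, 400):
        open(os.path.join(ckpt_dir, f"checkpoint_{step}"), "w").close()
    print(os.path.basename(resolve_checkpoint_to_load("", ckpt_dir)))  # checkpoint_1000
```

Because every rank reads the same shared `checkpoints_dir`, every rank resolves the same file and resumes at the same training step.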