Implementation:Deepspeedai DeepSpeed DeepSpeedEngine Checkpoint
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Training, Checkpointing, Fault_Tolerance |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Concrete tool, provided by the DeepSpeed library, for saving and loading distributed training checkpoints under ZeRO optimization.
Description
The DeepSpeed checkpoint system consists of three key functions:
DeepSpeedEngine.save_checkpoint() saves model state, optimizer state, scheduler state, and user-provided client state to disk. Each rank saves its local shard. The function:
- Ensures save_dir exists (rank 0 creates it, then a barrier synchronizes all ranks)
- Generates a tag from global_steps if not provided
- Validates tag consistency across all ranks
- Handles MoE (Mixture of Experts) layers with separate checkpoint logic
- Writes a "latest" file pointing to the most recent checkpoint tag
- Supports decoupled checkpoint commits via the checkpoint engine
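The tag and "latest" conventions in the list above can be illustrated without DeepSpeed. The following is a minimal sketch, not the library's code: `default_tag`, `write_latest`, and `demo_save_layout` are hypothetical helper names, and the sketch assumes the "latest" file is a plain text file holding only the tag.

```python
import os
import tempfile


def default_tag(global_steps: int) -> str:
    """Tag used when the caller passes tag=None (documented default: "global_step{N}")."""
    return f"global_step{global_steps}"


def write_latest(save_dir: str, tag: str) -> str:
    """Record `tag` as the most recent checkpoint (assumed plain-text format)."""
    # save_dir/tag/ is where the per-rank shard files would be written.
    os.makedirs(os.path.join(save_dir, tag), exist_ok=True)
    latest_path = os.path.join(save_dir, "latest")
    with open(latest_path, "w") as fd:
        fd.write(tag)
    return latest_path


def demo_save_layout() -> tuple:
    """Build the layout in a temporary directory and report what was written."""
    with tempfile.TemporaryDirectory() as save_dir:
        tag = default_tag(1000)
        latest_path = write_latest(save_dir, tag)
        with open(latest_path) as fd:
            return tag, fd.read(), sorted(os.listdir(save_dir))
```

In a real run only one rank needs to write the "latest" file, while every rank writes its own shard under save_dir/tag/.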
DeepSpeedEngine.load_checkpoint() restores all state from a checkpoint directory. It:
- Resolves the tag from the "latest" file if not explicitly provided
- Supports "latest_universal" for universal checkpoints
- Loads module, optimizer, and LR scheduler states based on boolean flags
- Supports strict and non-strict module loading
- Supports custom load functions via custom_load_fn
- Returns the load path and client state dictionary
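Tag resolution on the load side can be sketched in the same spirit. This is an illustration under stated assumptions rather than DeepSpeed's implementation: `resolve_tag` and `demo_resolve_tag` are hypothetical names, the "latest" file is assumed to contain only the tag, and the "latest_universal" variant is not covered.

```python
import os
import tempfile
from typing import Optional


def resolve_tag(load_dir: str, tag: Optional[str] = None) -> Optional[str]:
    """Return the explicit tag, else the tag stored in load_dir/latest, else None."""
    if tag is not None:
        return tag
    latest_path = os.path.join(load_dir, "latest")
    if not os.path.isfile(latest_path):
        return None  # nothing to resume from
    with open(latest_path) as fd:
        return fd.read().strip()


def demo_resolve_tag() -> list:
    """Exercise the three cases: no file, file present, explicit tag."""
    results = []
    with tempfile.TemporaryDirectory() as load_dir:
        results.append(resolve_tag(load_dir))               # no "latest" file yet
        with open(os.path.join(load_dir, "latest"), "w") as fd:
            fd.write("global_step2000\n")
        results.append(resolve_tag(load_dir))               # read from "latest"
        results.append(resolve_tag(load_dir, "step_1000"))  # explicit tag wins
    return results
```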
get_fp32_state_dict_from_zero_checkpoint() consolidates ZeRO-sharded checkpoints into a single fp32 state dict. It:
- Reads all rank checkpoint files from the checkpoint directory
- Reconstructs full parameters from partitioned slices
- Converts to fp32 for deployment compatibility
- Supports lazy mode for memory-efficient loading
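The reconstruction step can be pictured with plain Python lists. This is a conceptual sketch only, not the code in zero_to_fp32.py: real checkpoints hold torch tensors, and the exact shard layout differs between ZeRO stages. The function name `consolidate` and the assumption that each rank holds an equally sized, zero-padded slice of one flat buffer are illustrative.

```python
from math import prod
from typing import Dict, List, Sequence, Tuple


def consolidate(
    rank_shards: Sequence[Sequence[float]],
    param_shapes: Dict[str, Tuple[int, ...]],
) -> Dict[str, List[float]]:
    """Rebuild per-parameter flat fp32 values from equally sized rank shards."""
    # 1. Concatenate the shards in rank order to recover the flat buffer,
    #    casting every element to a Python float (stand-in for the fp32 cast).
    flat = [float(x) for shard in rank_shards for x in shard]
    # 2. Carve the buffer into parameters; trailing padding is never read.
    state, offset = {}, 0
    for name, shape in param_shapes.items():
        numel = prod(shape)
        state[name] = flat[offset:offset + numel]
        offset += numel
    if offset > len(flat):
        raise ValueError("shards hold fewer elements than the parameters need")
    return state


# 8 real elements spread over 3 ranks: padded to 9, so 3 elements per rank.
shards = [[1, 2, 3], [4, 5, 6], [7, 8, 0]]
shapes = {"weight": (2, 3), "bias": (2,)}
fp32_state = consolidate(shards, shapes)
```

Each entry of `fp32_state` would then be reshaped to the shape recorded in `shapes`.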
Usage
Call engine.save_checkpoint() from all ranks periodically during training. Call engine.load_checkpoint() from all ranks at the start of training to resume. Use get_fp32_state_dict_from_zero_checkpoint() offline or on a single process for model export.
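A periodic save inside a training loop might look like the sketch below. It assumes the wrapped model's forward pass returns the loss directly, and `train_with_checkpoints` is a hypothetical helper name; the `step_{N}` tag format and the `save_every` interval are illustrative choices, not DeepSpeed defaults.

```python
def train_with_checkpoints(engine, batches, save_dir, save_every=1000):
    """Run training steps and checkpoint every `save_every` global steps.

    `engine` is the object returned by deepspeed.initialize(). Every rank
    runs this same function, so every rank reaches save_checkpoint().
    """
    saved_tags = []
    last_saved = None
    for batch in batches:
        loss = engine(batch)   # forward pass (assumed to return the loss)
        engine.backward(loss)  # backward pass
        engine.step()          # optimizer step; may advance engine.global_steps
        steps = engine.global_steps
        # With gradient accumulation, global_steps can stay unchanged across
        # several micro-batches, so guard against saving the same step twice.
        if steps > 0 and steps % save_every == 0 and steps != last_saved:
            tag = f"step_{steps}"
            engine.save_checkpoint(save_dir, tag=tag, client_state={"step": steps})
            saved_tags.append(tag)
            last_saved = steps
    return saved_tags
```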
Code Reference
Source Location
- Repository: DeepSpeed
- File: deepspeed/runtime/engine.py
  - save_checkpoint: Lines 3695-3744
  - load_checkpoint: Lines 3347-3394
- File: deepspeed/utils/zero_to_fp32.py
  - get_fp32_state_dict_from_zero_checkpoint: Lines 533-570
Signature
def save_checkpoint(self, save_dir, tag=None, client_state={},
                    save_latest=True, exclude_frozen_parameters=False):
    """Save training checkpoint

    Arguments:
        save_dir: Required. Directory for saving the checkpoint
        tag: Optional. Checkpoint tag used as a unique identifier;
            global step is used if not provided.
        client_state: Optional. State dictionary for client training states.
        save_latest: Optional. Save a file 'latest' pointing to this checkpoint.
        exclude_frozen_parameters: Optional. Exclude frozen parameters.

    Important: all processes must call this method and not just rank 0.
    """
def load_checkpoint(self, load_dir, tag=None, load_module_strict=True,
                    load_optimizer_states=True, load_lr_scheduler_states=True,
                    load_module_only=False, custom_load_fn=None):
    """Load training checkpoint

    Returns:
        A tuple of (load_path, client_state).
    """
def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None,
                                             exclude_frozen_parameters=False,
                                             lazy_mode=False):
    """Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated
    state_dict that can be loaded with load_state_dict().

    Returns:
        pytorch state_dict
    """
Import
# save/load via engine returned by deepspeed.initialize()
import deepspeed
engine, _, _, _ = deepspeed.initialize(model=model, config=config)
engine.save_checkpoint(...)
engine.load_checkpoint(...)
# FP32 consolidation utility
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
I/O Contract
Inputs (save_checkpoint)
| Name | Type | Required | Description |
|---|---|---|---|
| save_dir | str | Yes | Directory for saving the checkpoint |
| tag | str | No | Unique identifier for the checkpoint; defaults to "global_step{N}", where N is the current global step |
| client_state | dict | No | User-defined state dictionary (e.g., epoch, custom metrics) |
| save_latest | bool | No | Write a "latest" file pointing to this checkpoint (default: True) |
| exclude_frozen_parameters | bool | No | Exclude frozen (non-trainable) parameters from checkpoint (default: False) |
Inputs (load_checkpoint)
| Name | Type | Required | Description |
|---|---|---|---|
| load_dir | str | Yes | Directory to load the checkpoint from |
| tag | str | No | Checkpoint tag; reads from "latest" file if not provided |
| load_module_strict | bool | No | Strictly enforce key matching between module and checkpoint (default: True) |
| load_optimizer_states | bool | No | Load optimizer states from checkpoint (default: True) |
| load_lr_scheduler_states | bool | No | Load LR scheduler states from checkpoint (default: True) |
| load_module_only | bool | No | Load only model weights, not optimizer or scheduler (default: False) |
| custom_load_fn | Callable | No | Custom function for loading the model state (default: None) |
Inputs (get_fp32_state_dict_from_zero_checkpoint)
| Name | Type | Required | Description |
|---|---|---|---|
| checkpoint_dir | str | Yes | Path to the checkpoint folder |
| tag | str | No | Checkpoint tag; reads from "latest" file if not provided |
| exclude_frozen_parameters | bool | No | Exclude frozen parameters from the consolidated state dict (default: False) |
| lazy_mode | bool | No | Return pseudo tensors for memory-efficient loading; call .contiguous() to materialize (default: False) |
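With lazy_mode=True, entries are materialized only when .contiguous() is called, so parameters can be processed one at a time instead of holding the whole fp32 state dict in memory. The sketch below shows one way to do that; `iter_materialized` and `stream_fp32_params` are hypothetical helper names, and only the call to get_fp32_state_dict_from_zero_checkpoint() comes from DeepSpeed.

```python
def iter_materialized(lazy_state_dict):
    """Yield (name, tensor) pairs, materializing one entry at a time."""
    for name, lazy_tensor in lazy_state_dict.items():
        # .contiguous() turns the pseudo tensor into a real tensor.
        yield name, lazy_tensor.contiguous()


def stream_fp32_params(checkpoint_dir, tag=None):
    """Yield consolidated fp32 parameters one at a time from a ZeRO checkpoint."""
    # Imported here so iter_materialized stays usable without DeepSpeed installed.
    from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

    lazy_state_dict = get_fp32_state_dict_from_zero_checkpoint(
        checkpoint_dir, tag=tag, lazy_mode=True
    )
    yield from iter_materialized(lazy_state_dict)
```

A caller could, for example, loop over stream_fp32_params("checkpoints/", tag="step_1000") and write each tensor to disk before requesting the next one.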
Outputs
| Name | Type | Description |
|---|---|---|
| (save_checkpoint) | None | Checkpoint files written to disk at save_dir/tag/ |
| (load_checkpoint) | Tuple[str, dict] | Tuple of (load_path, client_state_dict); load_path is None if loading failed |
| (get_fp32_state_dict_from_zero_checkpoint) | dict | Python dictionary with full fp32 model state, compatible with model.load_state_dict() |
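Because load_path is None when nothing was loaded, callers usually branch on it before reading client_state. The following sketch wraps that check; `stored_epoch` is a hypothetical helper, and the "epoch" key is an assumed entry that the caller placed in client_state at save time.

```python
def stored_epoch(engine, load_dir, tag=None):
    """Return the epoch recorded in the checkpoint, or 0 if none was loaded."""
    load_path, client_state = engine.load_checkpoint(load_dir, tag=tag)
    if load_path is None:
        return 0  # fresh start: no checkpoint found or loading failed
    # client_state only contains what the caller passed to save_checkpoint().
    return (client_state or {}).get("epoch", 0)
```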
Usage Examples
import deepspeed
import torch.nn as nn
model = nn.Linear(1024, 10)
engine, optimizer, _, _ = deepspeed.initialize(
    model=model, config="ds_config.json",
    model_parameters=model.parameters()
)
# === Save checkpoint during training ===
# All ranks must call this (not just rank 0)
engine.save_checkpoint(
    "checkpoints/",
    tag="step_1000",
    client_state={"epoch": 5, "best_loss": 0.42}
)
# === Resume training from checkpoint ===
load_path, client_state = engine.load_checkpoint(
    "checkpoints/",
    tag="step_1000"
)
if load_path:
    print(f"Resumed from {load_path}")
    print(f"Epoch: {client_state['epoch']}")
# === Load latest checkpoint (tag resolved automatically) ===
load_path, client_state = engine.load_checkpoint("checkpoints/")
# === Warm-start from checkpoint (model only, no optimizer) ===
engine.load_checkpoint(
    "checkpoints/",
    tag="step_1000",
    load_module_only=True
)
# === Convert ZeRO checkpoint for deployment ===
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
# checkpoint_dir is the folder passed to save_checkpoint(); tag selects the subfolder
state_dict = get_fp32_state_dict_from_zero_checkpoint(
    "checkpoints/",
    tag="step_1000"
)
standalone_model = nn.Linear(1024, 10)
standalone_model.load_state_dict(state_dict)
# standalone_model is now a standard PyTorch model for inference