Principle:Deepspeedai DeepSpeed Checkpoint Management
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Training, Checkpointing, Fault_Tolerance |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
A distributed checkpointing system that saves and restores model parameters, optimizer states, and training metadata across ZeRO-partitioned ranks.
Description
Checkpoint Management handles the complexity of saving and loading distributed training state when parameters and optimizer states are partitioned across multiple GPUs (ZeRO stages 1-3). The system addresses several challenges:
- Distributed save: Each rank saves its local shard of parameters and optimizer states. All ranks must participate in the save operation to ensure consistency.
- Distributed load: Each rank reads its shard and restores local state. The checkpoint tag (or "latest" file) identifies which checkpoint to load.
- Client state: User-defined training state (epoch number, global step, custom metrics) can be saved alongside model state.
- FP32 consolidation: Utility to convert ZeRO-sharded checkpoints to standard PyTorch state dictionaries for deployment or model sharing.
- Universal checkpointing: Enables resharding across different GPU counts -- a checkpoint saved on N GPUs can be loaded on M GPUs.
- Frozen parameter handling: Optional exclusion of frozen (non-trainable) parameters from checkpoints to save disk space.
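The resharding idea behind universal checkpointing can be illustrated with a minimal sketch. It assumes the saved state is one flat buffer split into equal, zero-padded slices per rank. This is a simplification for illustration, not DeepSpeed's actual universal checkpoint format, and the helper names (`merge_shards`, `split_into_shards`, `reshard`) are hypothetical.

```python
from typing import List


def merge_shards(shards: List[List[float]], numel: int) -> List[float]:
    """Concatenate per-rank slices in rank order and drop trailing padding."""
    flat = [x for shard in shards for x in shard]
    return flat[:numel]


def split_into_shards(flat: List[float], world_size: int,
                      pad_value: float = 0.0) -> List[List[float]]:
    """Split a flat buffer into `world_size` equal slices, padding the tail."""
    shard_len = -(-len(flat) // world_size)  # ceiling division
    padded = flat + [pad_value] * (shard_len * world_size - len(flat))
    return [padded[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]


def reshard(shards: List[List[float]], numel: int,
            new_world_size: int) -> List[List[float]]:
    """Re-partition a checkpoint saved on N ranks for loading on M ranks."""
    return split_into_shards(merge_shards(shards, numel), new_world_size)
```

Under these assumptions, a buffer saved on 4 ranks can be merged back to its true length and re-split for 3 ranks; only the amount of padding changes, not the stored values.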
Usage
Call engine.save_checkpoint() periodically during training to save state. Call engine.load_checkpoint() at the beginning of training to resume from a checkpoint. Use get_fp32_state_dict_from_zero_checkpoint() to consolidate sharded checkpoints for deployment.
Important: All ranks must call save_checkpoint() -- not just rank 0 -- because each rank saves its own shard of the distributed state.
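A minimal sketch of this save/resume pattern follows. The helpers `save_periodically` and `resume` are hypothetical wrappers; `engine` is assumed to be an object returned by `deepspeed.initialize()`, and the sketch relies on `load_checkpoint()` returning a `(load_path, client_state)` pair with `load_path` set to `None` when nothing could be loaded. The tag naming (`step_<n>`) and the save interval are illustrative choices.

```python
def save_periodically(engine, save_dir, step, epoch, every=1000):
    """Call this on EVERY rank; each rank writes its own shard.

    Returns True if a checkpoint was written at this step.
    """
    if step % every != 0:
        return False
    # User-defined training state travels with the checkpoint as client state.
    client_state = {"step": step, "epoch": epoch}
    engine.save_checkpoint(save_dir, tag=f"step_{step}", client_state=client_state)
    return True


def resume(engine, load_dir):
    """Return (step, epoch) to resume from, or (0, 0) for a fresh start."""
    # With no explicit tag, the engine is assumed to consult the "latest" file.
    load_path, client_state = engine.load_checkpoint(load_dir)
    if load_path is None:
        return 0, 0
    return client_state.get("step", 0), client_state.get("epoch", 0)
```

Calling `resume()` once before the training loop and `save_periodically()` inside it, without any `if rank == 0` guard, keeps the save collective across ranks.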
Theoretical Basis
Distributed checkpointing must maintain consistency across ranks. The complexity depends on the ZeRO stage:
- ZeRO Stage 0: Standard data-parallel checkpointing -- all ranks have identical model state, so only rank 0 needs to save the model (but all save optimizer states).
- ZeRO Stage 1: Optimizer states are partitioned. Each rank saves its 1/N partition of optimizer states alongside the full model state.
- ZeRO Stage 2: Optimizer states and gradients are partitioned. Similar to Stage 1 for checkpointing since gradients are transient.
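- ZeRO Stage 3: Parameters are also partitioned. Each rank saves its 1/N partition of both parameters and optimizer states. Resuming at the same world size lets each rank read back its own shard; rebuilding a full, unpartitioned state dict requires gathering the shards from all ranks.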
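The stage breakdown above can be summarized as a small calculation of how much state each rank holds, and therefore has to persist. This is a sketch under the assumption of even 1/N partitioning with the last slice rounded up; the function name and the element-count inputs are hypothetical. Gradients are left out because they are transient and not checkpointed.

```python
def per_rank_state_elements(stage, n_param_elems, n_optim_elems, world_size):
    """Elements of persistent state held by one rank for a given ZeRO stage."""
    if stage not in (0, 1, 2, 3):
        raise ValueError(f"unknown ZeRO stage: {stage}")

    def shard(n):
        return -(-n // world_size)  # ceiling division: size of a 1/N slice

    # Parameters are only partitioned at stage 3.
    params = shard(n_param_elems) if stage == 3 else n_param_elems
    # Optimizer states are partitioned from stage 1 onward.
    optim = shard(n_optim_elems) if stage >= 1 else n_optim_elems
    return {"params": params, "optimizer": optim}
```

Stages 1 and 2 give the same result here, which mirrors the point that they look alike from a checkpointing perspective.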
Consolidation gathers the shards from all N checkpoint files and produces a single fp32 state dict. This requires:
- Reading the parameter partition metadata (shapes, offsets) from each rank's checkpoint
- Reconstructing full tensors by concatenating partitioned slices
- Converting from fp16/bf16 to fp32 if applicable
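The three steps above can be sketched in plain Python. The layout is an assumption made for illustration: all parameters are flattened in a fixed order into one buffer, which is then cut into equal, padded per-rank slices. `consolidate` and `param_shapes` are hypothetical names, tensors are stood in for by flat lists, and `float(x)` stands in for the cast to fp32.

```python
import math
from typing import Dict, List, Tuple


def consolidate(shards: List[List[float]],
                param_shapes: Dict[str, Tuple[int, ...]]) -> Dict[str, List[float]]:
    """Rebuild per-parameter values from rank-ordered flat slices."""
    # Step 2: concatenate the partitioned slices in rank order.
    # Step 3: cast every element (float() stands in for .float()).
    flat = [float(x) for shard in shards for x in shard]

    state_dict = {}
    offset = 0
    # Step 1: the shape metadata gives each parameter's offset and length.
    for name, shape in param_shapes.items():
        numel = math.prod(shape)
        if offset + numel > len(flat):
            raise ValueError(f"shards too short to hold parameter {name!r}")
        state_dict[name] = flat[offset:offset + numel]  # still flat, not reshaped
        offset += numel
    # Anything left past `offset` is alignment padding and is discarded.
    return state_dict
```

A real implementation would also reshape each slice to its recorded shape; that step is omitted here to keep the sketch free of tensor libraries.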
Pseudo-code:
    # Abstract distributed checkpoint save (called on every rank)
    def save_checkpoint(save_dir, tag, save_latest=True):
        ckpt_dir = f"{save_dir}/{tag}"
        barrier()  # synchronize all ranks before touching the filesystem
        if rank == 0:
            makedirs(ckpt_dir)  # create the tag directory the shards go into
        barrier()  # no rank writes until the directory exists
        # Each rank saves its local shard
        save(model_shard, optimizer_shard, scheduler_state, client_state,
             path=f"{ckpt_dir}/rank_{rank}.pt")
        barrier()  # every shard is on disk before "latest" is updated
        if save_latest and rank == 0:
            write_file(f"{save_dir}/latest", tag)

    # Abstract checkpoint consolidation
    def consolidate_to_fp32(checkpoint_dir):
        partitions = {}  # param_name -> list of slices, in rank order
        for rank in range(world_size):
            shard = load(f"{checkpoint_dir}/rank_{rank}.pt")
            for param_name, partition in shard.items():
                partitions.setdefault(param_name, []).append(partition)
        return {k: concat(v).float() for k, v in partitions.items()}
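For readers who want something executable, the tag and "latest" file handling from the pseudo-code can be simulated in a single process. This sketch makes several assumptions: `pickle` stands in for the real serializer, the `rank_<r>.pt` file name is taken from the pseudo-code rather than from DeepSpeed's on-disk layout, and the function names are hypothetical. There are no barriers because no real ranks are involved.

```python
import os
import pickle


def save_rank_shard(save_dir, tag, rank, shard, save_latest=True):
    """Write one rank's shard under save_dir/tag and optionally update 'latest'."""
    ckpt_dir = os.path.join(save_dir, tag)
    os.makedirs(ckpt_dir, exist_ok=True)
    with open(os.path.join(ckpt_dir, f"rank_{rank}.pt"), "wb") as f:
        pickle.dump(shard, f)
    # Only rank 0 records the tag. In a real multi-rank run this would happen
    # after a barrier, once every rank has finished writing its shard.
    if save_latest and rank == 0:
        with open(os.path.join(save_dir, "latest"), "w") as f:
            f.write(tag)


def resolve_tag(load_dir, tag=None):
    """Use an explicit tag if given, else fall back to the 'latest' file."""
    if tag is not None:
        return tag
    latest = os.path.join(load_dir, "latest")
    if not os.path.isfile(latest):
        return None
    with open(latest) as f:
        return f.read().strip()


def load_rank_shard(load_dir, rank, tag=None):
    """Load one rank's shard, or return None when no checkpoint is recorded."""
    tag = resolve_tag(load_dir, tag)
    if tag is None:
        return None
    with open(os.path.join(load_dir, tag, f"rank_{rank}.pt"), "rb") as f:
        return pickle.load(f)
```

Saving with `save_latest=False` leaves the "latest" pointer on the previous tag, so such a checkpoint can only be loaded by naming its tag explicitly.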