Principle:Deepspeedai DeepSpeed Checkpoint Management

From Leeroopedia


Knowledge Sources
Domains: Distributed_Training, Checkpointing, Fault_Tolerance
Last Updated: 2026-02-09 00:00 GMT

Overview

A distributed checkpointing system that saves and restores model parameters, optimizer states, and training metadata across ZeRO-partitioned ranks.

Description

Checkpoint Management handles the complexity of saving and loading distributed training state when parameters and optimizer states are partitioned across multiple GPUs (ZeRO stages 1-3). The system addresses several challenges:

  • Distributed save: Each rank saves its local shard of parameters and optimizer states. All ranks must participate in the save operation to ensure consistency.
  • Distributed load: Each rank reads its shard and restores local state. The checkpoint tag (or "latest" file) identifies which checkpoint to load.
  • Client state: User-defined training state (epoch number, global step, custom metrics) can be saved alongside model state.
  • FP32 consolidation: Utility to convert ZeRO-sharded checkpoints to standard PyTorch state dictionaries for deployment or model sharing.
  • Universal checkpointing: Enables resharding across different GPU counts -- a checkpoint saved on N GPUs can be loaded on M GPUs.
  • Frozen parameter handling: Optional exclusion of frozen (non-trainable) parameters from checkpoints to save disk space.
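
To make the resharding idea concrete, here is a minimal sketch. It assumes each rank holds an equal-sized contiguous slice of one flat buffer, with zero padding at the tail of the last slice. The names `partition_flat` and `reshard` are hypothetical, and this is not DeepSpeed's universal checkpoint format; it only illustrates why a checkpoint written on N ranks has to be re-split before M ranks can read it.

```python
def partition_flat(flat, world_size, pad_value=0.0):
    """Split a flat list into world_size equal slices, padding the tail."""
    per_rank = -(-len(flat) // world_size)  # ceiling division
    padded = flat + [pad_value] * (per_rank * world_size - len(flat))
    return [padded[r * per_rank:(r + 1) * per_rank] for r in range(world_size)]


def reshard(partitions, numel, new_world_size):
    """Re-split partitions saved on N ranks so that M ranks can load them."""
    # Concatenate the old slices and drop the padding added for the old layout.
    flat = [x for part in partitions for x in part][:numel]
    return partition_flat(flat, new_world_size)


values = [float(i) for i in range(10)]
saved_on_4 = partition_flat(values, 4)   # 4 slices of 3 elements (2 padded)
loaded_on_3 = reshard(saved_on_4, numel=len(values), new_world_size=3)
```

The true element count (`numel`) has to be stored with the checkpoint, because the padding depends on the world size that wrote it.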

Usage

Call engine.save_checkpoint() periodically during training to save state. Call engine.load_checkpoint() at the beginning of training to resume from a checkpoint. Use get_fp32_state_dict_from_zero_checkpoint() (in deepspeed.utils.zero_to_fp32) to consolidate sharded checkpoints into a single fp32 state dict for deployment.

Important: All ranks must call save_checkpoint() -- not just rank 0 -- because each rank saves its own shard of the distributed state.
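
A minimal sketch of this save/resume pattern follows. It assumes `engine` is the object returned by `deepspeed.initialize()`, and that `load_checkpoint()` returns a `(load_path, client_state)` pair whose `load_path` is `None` when nothing was loaded. The helper names `resume` and `checkpoint_if_due`, the `step_...` tag format, and the `global_step` key are hypothetical choices for illustration.

```python
def resume(engine, ckpt_dir):
    """Return the step to resume from, or 0 when no checkpoint is found."""
    # Every rank calls load_checkpoint; with no tag given, the "latest"
    # file in ckpt_dir decides which checkpoint is read.
    load_path, client_state = engine.load_checkpoint(ckpt_dir)
    if load_path is None:
        return 0
    return (client_state or {}).get("global_step", 0)


def checkpoint_if_due(engine, ckpt_dir, global_step, every=1000):
    """Save on every rank when global_step hits the interval."""
    if global_step == 0 or global_step % every != 0:
        return False
    # Deliberately no `if rank == 0` guard: each rank writes its own shard.
    engine.save_checkpoint(ckpt_dir, tag=f"step_{global_step}",
                           client_state={"global_step": global_step})
    return True
```

Because the helpers only touch `save_checkpoint` and `load_checkpoint`, they can be exercised with a stub engine in unit tests before running on a real cluster.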

Theoretical Basis

Distributed checkpointing must maintain consistency across ranks. The complexity depends on the ZeRO stage:

  • ZeRO Stage 0: Standard data-parallel checkpointing -- all ranks hold identical model and optimizer state, so a single rank can write both.
  • ZeRO Stage 1: Optimizer states are partitioned. Each rank saves its 1/N partition of the optimizer states; the model state is still replicated, so it needs to be written only once.
  • ZeRO Stage 2: Optimizer states and gradients are partitioned. Checkpointing is the same as in Stage 1 because gradients are transient and are not saved.
  • ZeRO Stage 3: Parameters are also partitioned. Each rank saves its 1/N partition of both parameters and optimizer states. Resuming on the same number of ranks needs only each rank's own shard; producing a full state dict requires gathering all shards.
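
The stage-by-stage differences can be turned into rough arithmetic. The sketch below assumes mixed-precision Adam with 2 bytes per parameter for fp16 weights and 12 bytes per parameter for optimizer state (fp32 master weights, momentum, and variance), which are the figures used in the ZeRO memory analysis. It ignores metadata and padding, and the function name `state_bytes_per_rank` is hypothetical.

```python
def state_bytes_per_rank(num_params, world_size, stage,
                         param_bytes=2, optim_bytes=12):
    """Approximate (model, optimizer) bytes of state held by one rank."""
    if stage not in (0, 1, 2, 3):
        raise ValueError(f"unknown ZeRO stage: {stage}")
    model = num_params * param_bytes
    optim = num_params * optim_bytes
    if stage >= 1:
        # Optimizer states are partitioned from Stage 1 onward.
        optim = -(-optim // world_size)  # ceiling division
    if stage == 3:
        # Parameters are partitioned only in Stage 3.
        model = -(-model // world_size)
    return model, optim


# Example: a 1B-parameter model on 8 ranks.
for stage in (0, 1, 2, 3):
    model, optim = state_bytes_per_rank(1_000_000_000, 8, stage)
    print(f"stage {stage}: model {model / 1e9:.2f} GB, optimizer {optim / 1e9:.2f} GB")
```

In Stages 0 to 2 the model figure describes replicated state, so only one copy of it needs to reach disk even though every rank holds it.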

Consolidation gathers the shards from all N checkpoint files to produce a single fp32 state dict. This requires:

  1. Reading the parameter partition metadata (shapes, offsets) from each rank's checkpoint
  2. Reconstructing full tensors by concatenating partitioned slices
  3. Converting from fp16/bf16 to fp32 if applicable
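
The three steps above can be sketched in plain Python. This assumes every parameter was flattened into one buffer in a known order and that the buffer was split into equal rank partitions with padding at the end. The function `consolidate` and its argument layout are hypothetical and do not mirror DeepSpeed's checkpoint files.

```python
from math import prod


def consolidate(rank_partitions, param_shapes):
    """Rebuild per-parameter values from flat, equal-sized rank partitions.

    rank_partitions: one flat list of numbers per rank, in rank order.
    param_shapes: dict of name -> shape tuple, in flattening order.
    Returns name -> (shape, flat list of values).
    """
    # Step 2: concatenate the partitioned slices back into one buffer.
    flat = [x for part in rank_partitions for x in part]
    out, offset = {}, 0
    # Step 1: the saved shapes give each parameter's offset and length.
    for name, shape in param_shapes.items():
        numel = prod(shape)
        # Step 3 stand-in: a tensor library would cast fp16/bf16 to fp32
        # here; plain Python only has one float type.
        out[name] = (shape, [float(x) for x in flat[offset:offset + numel]])
        offset += numel
    return out  # anything after the last parameter is padding


shapes = {"w": (2, 3), "b": (3,)}
partitions = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 0]]  # 9 values + 1 pad, 2 ranks
full = consolidate(partitions, shapes)
```

A parameter can straddle a partition boundary (here `w` ends inside rank 1's slice), which is why the slices are concatenated before any parameter is cut out.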

Pseudo-code:

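# Abstract distributed checkpoint save (illustrative file layout, not DeepSpeed's on-disk format)
import os
from collections import defaultdict

import torch
import torch.distributed as dist

def save_checkpoint(save_dir, tag, local_state, save_latest=True):
    # local_state: {"model": {name: partition}, "optimizer": ..., "scheduler": ..., "client_state": ...}
    rank = dist.get_rank()
    ckpt_dir = os.path.join(save_dir, tag)
    if rank == 0:
        os.makedirs(ckpt_dir, exist_ok=True)
    dist.barrier()  # the directory must exist before any rank writes

    # Each rank saves its local shard
    torch.save(local_state, os.path.join(ckpt_dir, f"rank_{rank}.pt"))

    dist.barrier()  # every shard is on disk before "latest" is updated
    if save_latest and rank == 0:
        with open(os.path.join(save_dir, "latest"), "w") as f:
            f.write(tag)

# Abstract checkpoint consolidation (assumes flat 1-D partitions; reshaping omitted)
def consolidate_to_fp32(checkpoint_dir, world_size):
    partitions = defaultdict(list)  # a plain dict would raise KeyError on first append
    for rank in range(world_size):
        shard = torch.load(os.path.join(checkpoint_dir, f"rank_{rank}.pt"), map_location="cpu")
        for param_name, partition in shard["model"].items():
            partitions[param_name].append(partition)
    return {k: torch.cat(v).float() for k, v in partitions.items()}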
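
The save routine above records the most recent tag in a `latest` file. A loader can use that file when the caller does not name a tag. The sketch below assumes the file contains only the tag text; the helper name `resolve_tag` is hypothetical.

```python
import os


def resolve_tag(save_dir, tag=None):
    """Return the explicit tag, else the tag in save_dir/latest, else None."""
    if tag is not None:
        return tag
    latest_path = os.path.join(save_dir, "latest")
    if not os.path.isfile(latest_path):
        return None  # nothing saved yet: the caller should start from scratch
    with open(latest_path) as f:
        return f.read().strip()
```

Returning `None` instead of raising lets a training script treat "no checkpoint yet" as a normal first-run condition.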
