Implementation:Deepspeedai DeepSpeed DeepSpeedEngine Checkpoint

From Leeroopedia
Revision as of 14:46, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/Deepspeedai_DeepSpeed_DeepSpeedEngine_Checkpoint.md)


Knowledge Sources
Domains: Distributed_Training, Checkpointing, Fault_Tolerance
Last Updated: 2026-02-09 00:00 GMT

Overview

A concrete tool, provided by the DeepSpeed library, for saving and loading distributed training checkpoints under ZeRO optimization.

Description

The DeepSpeed checkpoint system consists of three key functions:

DeepSpeedEngine.save_checkpoint() saves model state, optimizer state, scheduler state, and user-provided client state to disk. Each rank saves its local shard. The function:

  • Ensures save_dir exists (rank 0 creates it, barrier synchronizes)
  • Generates a tag from global_steps if not provided
  • Validates tag consistency across all ranks
  • Handles MoE (Mixture of Experts) layers with separate checkpoint logic
  • Writes a "latest" file pointing to the most recent checkpoint tag
  • Supports decoupled checkpoint commits via checkpoint engine
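The tag default and the "latest" pointer listed above can be pictured with a short sketch. This is a simplified illustration rather than DeepSpeed's implementation: the helper names default_tag and write_latest are hypothetical, and rank synchronization, tag validation across ranks, MoE handling, and the actual tensor serialization are all left out.

```python
import os
import tempfile


def default_tag(global_steps, tag=None):
    """Return the caller's tag, or fall back to a step-based one."""
    return tag if tag is not None else f"global_step{global_steps}"


def write_latest(save_dir, tag):
    """Create save_dir/tag/ and record the tag in a plain-text 'latest' file."""
    os.makedirs(os.path.join(save_dir, tag), exist_ok=True)
    with open(os.path.join(save_dir, "latest"), "w") as fd:
        fd.write(tag)


# Illustration only: lay out an empty checkpoint directory for step 1000.
with tempfile.TemporaryDirectory() as demo_dir:
    demo_tag = default_tag(global_steps=1000)
    write_latest(demo_dir, demo_tag)
    print(sorted(os.listdir(demo_dir)))
```

The resulting layout is one subdirectory per tag plus a single "latest" file, which is what allows a later load to find the most recent checkpoint without being told its tag.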

DeepSpeedEngine.load_checkpoint() restores all state from a checkpoint directory. It:

  • Resolves the tag from the "latest" file if not explicitly provided
  • Supports "latest_universal" for universal checkpoints
  • Loads module, optimizer, and LR scheduler states based on boolean flags
  • Supports strict and non-strict module loading
  • Supports custom load functions via custom_load_fn
  • Returns the load path and client state dictionary
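Tag resolution on the load side can be sketched in the same spirit. The helpers resolve_load_tag and checkpoint_path below are hypothetical names, not DeepSpeed functions. The sketch ignores "latest_universal" and assumes that a missing "latest" file simply means there is nothing to resume from.

```python
import os
import tempfile


def resolve_load_tag(load_dir, tag=None):
    """Return an explicit tag, else the contents of load_dir/latest, else None."""
    if tag is not None:
        return tag
    latest_path = os.path.join(load_dir, "latest")
    if not os.path.isfile(latest_path):
        return None  # assumption: no 'latest' file means nothing to resume from
    with open(latest_path) as fd:
        return fd.read().strip()


def checkpoint_path(load_dir, tag=None):
    """Return the directory a load would read from, or None if no tag resolves."""
    resolved = resolve_load_tag(load_dir, tag)
    return None if resolved is None else os.path.join(load_dir, resolved)
```

Returning None when no tag resolves mirrors the documented contract that load_path is None when loading fails, so callers can branch on the result.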

get_fp32_state_dict_from_zero_checkpoint() consolidates ZeRO-sharded checkpoints into a single fp32 state dict. It:

  • Reads all rank checkpoint files from the checkpoint directory
  • Reconstructs full parameters from partitioned slices
  • Converts to fp32 for deployment compatibility
  • Supports lazy mode for memory-efficient loading
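The idea of rebuilding full parameters from partitioned slices can be illustrated with plain Python lists. This is a toy model under stated assumptions, not the utility's code: it assumes every rank holds one contiguous slice of a single flattened buffer, that parameters were flattened in the order given by param_shapes, and that any trailing elements are alignment padding. The real function operates on torch tensors read from the checkpoint files, and merge_partitions is a hypothetical name.

```python
from math import prod


def merge_partitions(rank_partitions, param_shapes):
    """Rebuild per-parameter values from per-rank slices of one flat buffer.

    rank_partitions: list of flat slices, one per rank, in rank order.
    param_shapes: dict of name -> shape tuple, in flattening order.
    """
    # Concatenate the slices in rank order and coerce every value to float.
    flat = [float(x) for part in rank_partitions for x in part]
    state, offset = {}, 0
    for name, shape in param_shapes.items():
        numel = prod(shape)
        state[name] = {"shape": shape, "values": flat[offset:offset + numel]}
        offset += numel
    # Elements past the last parameter are treated as padding and dropped.
    return state


# Three ranks, three elements each; the ninth element is padding.
demo = merge_partitions(
    [[0, 1, 2], [3, 4, 5], [6, 7, 0]],
    {"weight": (2, 3), "bias": (2,)},
)
print(demo["bias"]["values"])
```

Because the slices are concatenated in rank order before being carved up by parameter size, a parameter may straddle two ranks without any special handling.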

Usage

Call engine.save_checkpoint() from all ranks, not only rank 0, at regular intervals during training. To resume, call engine.load_checkpoint() from all ranks before the first training step. Use get_fp32_state_dict_from_zero_checkpoint() offline or on a single process to export the model.
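A periodic-save loop might look like the following sketch. It assumes the wrapped model's forward pass returns the loss directly and that the engine exposes a global_steps counter. The function name train_with_checkpoints, the save_every parameter, and the "step" key in client_state are hypothetical choices made for this illustration.

```python
def train_with_checkpoints(engine, batches, save_dir, save_every=1000):
    """Run training steps and save a checkpoint every `save_every` global steps.

    Every rank must run this loop so that all ranks reach save_checkpoint().
    """
    saved_tags = []
    last_saved = 0
    for batch in batches:
        loss = engine(batch)    # assumption: forward returns the loss
        engine.backward(loss)   # engine-managed backward pass
        engine.step()           # optimizer step, may advance global_steps
        step = engine.global_steps
        # With gradient accumulation global_steps can stay unchanged across
        # micro-batches, so guard against saving the same step twice.
        if step % save_every == 0 and step != last_saved:
            tag = f"global_step{step}"
            engine.save_checkpoint(save_dir, tag=tag, client_state={"step": step})
            saved_tags.append(tag)
            last_saved = step
    return saved_tags
```

The save condition depends only on the global step, which is the same on every rank, so either all ranks call save_checkpoint() for a given step or none do.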

Code Reference

Source Location

  • Repository: DeepSpeed
  • File: deepspeed/runtime/engine.py
    • save_checkpoint: Lines 3695-3744
    • load_checkpoint: Lines 3347-3394
  • File: deepspeed/utils/zero_to_fp32.py
    • get_fp32_state_dict_from_zero_checkpoint: Lines 533-570

Signature

def save_checkpoint(self, save_dir, tag=None, client_state={},
                    save_latest=True, exclude_frozen_parameters=False):
    """Save training checkpoint

    Arguments:
        save_dir: Required. Directory for saving the checkpoint
        tag: Optional. Checkpoint tag used as a unique identifier;
             global step is used if not provided.
        client_state: Optional. State dictionary for client training states.
        save_latest: Optional. Save a file 'latest' pointing to this checkpoint.
        exclude_frozen_parameters: Optional. Exclude frozen parameters.

    Important: all processes must call this method and not just rank 0.
    """

def load_checkpoint(self, load_dir, tag=None, load_module_strict=True,
                    load_optimizer_states=True, load_lr_scheduler_states=True,
                    load_module_only=False, custom_load_fn=None):
    """Load training checkpoint

    Returns:
        A tuple of (load_path, client_state).
    """

def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None,
                                              exclude_frozen_parameters=False,
                                              lazy_mode=False):
    """Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated
    state_dict that can be loaded with load_state_dict().

    Returns:
        pytorch state_dict
    """

Import

# save/load via engine returned by deepspeed.initialize()
import deepspeed
engine, _, _, _ = deepspeed.initialize(model=model, config=config)
engine.save_checkpoint(...)
engine.load_checkpoint(...)

# FP32 consolidation utility
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

I/O Contract

Inputs (save_checkpoint)

Name | Type | Required | Description
save_dir | str | Yes | Directory for saving the checkpoint
tag | str | No | Unique identifier for checkpoint; defaults to "global_step{N}"
client_state | dict | No | User-defined state dictionary (e.g., epoch, custom metrics)
save_latest | bool | No | Write a "latest" file pointing to this checkpoint (default: True)
exclude_frozen_parameters | bool | No | Exclude frozen (non-trainable) parameters from checkpoint (default: False)

Inputs (load_checkpoint)

Name | Type | Required | Description
load_dir | str | Yes | Directory to load the checkpoint from
tag | str | No | Checkpoint tag; reads from "latest" file if not provided
load_module_strict | bool | No | Strictly enforce key matching between module and checkpoint (default: True)
load_optimizer_states | bool | No | Load optimizer states from checkpoint (default: True)
load_lr_scheduler_states | bool | No | Load LR scheduler states from checkpoint (default: True)
load_module_only | bool | No | Load only model weights, not optimizer or scheduler (default: False)
custom_load_fn | Callable | No | Custom function for model loading

Inputs (get_fp32_state_dict_from_zero_checkpoint)

Name | Type | Required | Description
checkpoint_dir | str | Yes | Path to the checkpoint folder, i.e. the directory that contains the tag subfolders (such as global_step{N}) and the "latest" file
tag | str | No | Checkpoint tag; reads from "latest" file if not provided
exclude_frozen_parameters | bool | No | Exclude frozen parameters from the consolidated state dict (default: False)
lazy_mode | bool | No | Return pseudo tensors for memory-efficient loading; call .contiguous() to materialize (default: False)

Outputs

Name | Type | Description
(save_checkpoint) | None | Checkpoint files are written to disk under save_dir/tag/
(load_checkpoint) | Tuple[str, dict] | Tuple of (load_path, client_state_dict); load_path is None if loading failed
(get_fp32_state_dict_from_zero_checkpoint) | dict | Python dictionary with the full fp32 model state, compatible with model.load_state_dict()
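Because load_path can be None, resume logic usually branches on it before touching the client state. The following is a minimal sketch: the helper name saved_epoch and the "epoch" key are hypothetical, and it assumes client_state may also be None or lack the key when nothing was loaded.

```python
def saved_epoch(engine, load_dir, tag=None):
    """Try to resume and return the epoch recorded in the client state.

    Returns 0 when no checkpoint could be loaded.
    """
    load_path, client_state = engine.load_checkpoint(load_dir, tag=tag)
    if load_path is None:
        return 0  # nothing was loaded: start from scratch
    # client_state holds whatever was passed to save_checkpoint();
    # "epoch" is an illustrative key, not something DeepSpeed defines.
    return (client_state or {}).get("epoch", 0)
```

As with the methods it wraps, a helper like this would need to be called from all ranks.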

Usage Examples

import deepspeed
import torch.nn as nn

model = nn.Linear(1024, 10)
engine, optimizer, _, _ = deepspeed.initialize(
    model=model, config="ds_config.json",
    model_parameters=model.parameters()
)

# === Save checkpoint during training ===
# All ranks must call this (not just rank 0)
engine.save_checkpoint(
    "checkpoints/",
    tag="step_1000",
    client_state={"epoch": 5, "best_loss": 0.42}
)

# === Resume training from checkpoint ===
load_path, client_state = engine.load_checkpoint(
    "checkpoints/",
    tag="step_1000"
)
if load_path:
    print(f"Resumed from {load_path}")
    print(f"Epoch: {client_state['epoch']}")

# === Load latest checkpoint (tag resolved automatically) ===
load_path, client_state = engine.load_checkpoint("checkpoints/")

# === Warm-start from checkpoint (model only, no optimizer) ===
engine.load_checkpoint(
    "checkpoints/",
    tag="step_1000",
    load_module_only=True
)

# === Convert ZeRO checkpoint for deployment ===
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

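# checkpoint_dir is the save directory that holds the tag subfolders;
# the tag is passed separately (or read from "latest" when omitted).
state_dict = get_fp32_state_dict_from_zero_checkpoint(
    "checkpoints/",
    tag="step_1000"
)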
standalone_model = nn.Linear(1024, 10)
standalone_model.load_state_dict(state_dict)
# standalone_model is now a standard PyTorch model for inference

Related Pages

Implements Principle
