
Principle:OpenRLHF OpenRLHF DeepSpeed Checkpoint Conversion

From Leeroopedia
Revision as of 18:21, 16 February 2026 by Admin (talk | contribs) (Auto-imported from principles/OpenRLHF_OpenRLHF_DeepSpeed_Checkpoint_Conversion.md)


Knowledge Sources
Domains: Checkpointing, DeepSpeed, Utilities
Last Updated: 2026-02-07 10:40 GMT

Overview

A checkpoint portability technique that converts DeepSpeed ZeRO-sharded checkpoints into a universal format that can be loaded under different parallelism configurations.

Description

DeepSpeed ZeRO checkpoints are sharded across multiple GPUs according to the ZeRO stage and parallelism configuration used during training. Because the shard layout depends on that configuration, these checkpoints cannot be loaded directly under a different one (for example, a different number of GPUs or a different ZeRO stage). Universal checkpoints consolidate the sharded state into a portable format that can be loaded regardless of the target parallelism setup. This conversion is essential for model deployment, evaluation on different hardware, or resuming training with a different configuration.
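To make the layout dependence concrete, here is a minimal sketch in plain Python. It assumes a simplified model in which the state is one flat list split into equal, zero-padded shards; the helper names `partition` and `consolidate` are hypothetical and are not DeepSpeed APIs.

```python
def partition(flat, world_size):
    """Split a flat list into world_size equal shards, zero-padding the tail."""
    shard_len = -(-len(flat) // world_size)  # ceiling division
    padded = flat + [0.0] * (shard_len * world_size - len(flat))
    return [padded[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]


def consolidate(shards, numel):
    """Concatenate shards in rank order and drop the padding."""
    flat = [value for shard in shards for value in shard]
    return flat[:numel]


state = [float(i) for i in range(10)]

# Shards written by a 4-rank run do not line up with what a 3-rank run expects.
shards_4 = partition(state, 4)

# Consolidating first gives a layout-independent state that can be re-split.
full = consolidate(shards_4, len(state))
shards_3 = partition(full, 3)
```

In this toy model the 4-rank shards hold 3 elements each and the 3-rank shards hold 4, so rank 0 of the new run cannot simply read rank 0's old file. Going through the consolidated state removes that dependence.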

Usage

Use checkpoint conversion after training with DeepSpeed ZeRO (stages 1-3) when you need to extract model weights for deployment, evaluation, or training resumption with a different GPU configuration. It is particularly important for PPO training, where the actor and critic models have separate checkpoint directories and each directory needs its own conversion.
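When several models each have their own checkpoint directory, it can help to list which directories still need converting. The sketch below assumes each checkpoint directory holds a `latest` tag file and gains a `latest_universal` file once converted, as in the pseudo-code under Theoretical Basis. The `actor` and `critic` directory names and the `find_unconverted` helper are hypothetical.

```python
import tempfile
from pathlib import Path


def find_unconverted(root):
    """Return directories under root that have a 'latest' file but no 'latest_universal'."""
    pending = []
    for latest in sorted(Path(root).rglob("latest")):
        ckpt_dir = latest.parent
        if latest.is_file() and not (ckpt_dir / "latest_universal").exists():
            pending.append(ckpt_dir)
    return pending


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    # Hypothetical PPO layout: one checkpoint directory per model.
    for name in ("actor", "critic"):
        (root / name).mkdir()
        (root / name / "latest").write_text("global_step10")
    # Pretend the critic has already been converted.
    (root / "critic" / "latest_universal").write_text("global_step10_uni")

    pending_names = [path.name for path in find_unconverted(root)]
```

Here only the `actor` directory is reported as pending, because the `critic` directory already has a `latest_universal` file.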

Theoretical Basis

ZeRO partitions model state across data-parallel ranks:

  • ZeRO Stage 1: Optimizer states are partitioned
  • ZeRO Stage 2: Optimizer states + gradients are partitioned
  • ZeRO Stage 3: Optimizer states + gradients + parameters are partitioned
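The effect of each stage on per-rank memory can be sketched with a small calculation. The byte counts are an assumption not stated on this page: they follow the usual mixed-precision Adam accounting of 2 bytes per parameter for fp16 weights, 2 for fp16 gradients, and 12 for fp32 optimizer state. The function name `bytes_per_rank` is hypothetical.

```python
def bytes_per_rank(num_params, world_size, stage):
    """Estimate model-state bytes held by one data-parallel rank.

    Assumes mixed-precision Adam: 2 B/param weights, 2 B/param gradients,
    12 B/param optimizer state. Stage 0 means no ZeRO partitioning.
    """
    params = 2 * num_params
    grads = 2 * num_params
    optim = 12 * num_params
    if stage >= 1:
        optim /= world_size   # Stage 1: optimizer states are partitioned
    if stage >= 2:
        grads /= world_size   # Stage 2: gradients are partitioned as well
    if stage >= 3:
        params /= world_size  # Stage 3: parameters are partitioned as well
    return params + grads + optim


# Illustrative sizes: a 7.5B-parameter model on 64 data-parallel ranks.
estimates_gb = {
    stage: bytes_per_rank(7.5e9, 64, stage) / 1e9 for stage in range(4)
}
```

The more state a stage partitions, the more of the checkpoint lives in rank-specific shards, which is why Stage 3 checkpoints in particular need consolidation before the weights can be used elsewhere.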

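The universal conversion reconstructs the full state in three steps: read the latest checkpoint tag, gather the sharded state from every rank, and write the consolidated universal checkpoint.

Pseudo-code Logic: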

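# Abstract conversion process (NOT the actual implementation).
# load, gather_state, and save are placeholders, and the "rank_{rank}" shard
# name is illustrative rather than DeepSpeed's real on-disk layout.
from pathlib import Path

checkpoint_path = Path(checkpoint_path)

# 1. Read the latest checkpoint tag
tag = (checkpoint_path / "latest").read_text().strip()

# 2. Gather sharded state from all ranks into one container
full_state = {}
for rank in range(num_ranks):
    shard = load(checkpoint_path / tag / f"rank_{rank}")
    gather_state(full_state, shard)

# 3. Write the consolidated universal checkpoint and record its tag
save(full_state, checkpoint_path / f"{tag}_uni")
(checkpoint_path / "latest_universal").write_text(f"{tag}_uni")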
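The three steps can also be run end to end on a toy checkpoint. This is a sketch under loud simplifications, not DeepSpeed's converter: shards are JSON files named `rank_<n>.json`, each rank owns whole named parameters rather than slices of flattened tensors, and both helper functions and the `state.json` file name are hypothetical.

```python
import json
import tempfile
from pathlib import Path


def write_toy_sharded_checkpoint(root, tag, state, num_ranks):
    """Write one JSON shard per rank plus a 'latest' file holding the tag."""
    names = sorted(state)
    tag_dir = root / tag
    tag_dir.mkdir(parents=True)
    for rank in range(num_ranks):
        # Toy partitioning: round-robin whole parameters across ranks.
        shard = {name: state[name] for name in names[rank::num_ranks]}
        (tag_dir / f"rank_{rank}.json").write_text(json.dumps(shard))
    (root / "latest").write_text(tag)


def convert_to_universal(root):
    """Read the tag, merge every rank's shard, write '<tag>_uni'."""
    tag = (root / "latest").read_text().strip()            # step 1
    full_state = {}
    for shard_file in sorted((root / tag).glob("rank_*.json")):
        full_state.update(json.loads(shard_file.read_text()))  # step 2
    uni_tag = f"{tag}_uni"
    uni_dir = root / uni_tag
    uni_dir.mkdir()
    (uni_dir / "state.json").write_text(json.dumps(full_state))  # step 3
    (root / "latest_universal").write_text(uni_tag)
    return uni_dir


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    state = {"layer0.weight": [1.0, 2.0], "layer0.bias": [0.5], "layer1.weight": [3.0]}
    write_toy_sharded_checkpoint(root, "global_step10", state, num_ranks=2)
    uni_dir = convert_to_universal(root)
    restored = json.loads((uni_dir / "state.json").read_text())
    latest_universal = (root / "latest_universal").read_text()
```

The consolidated directory no longer depends on the number of ranks that wrote the shards, so a loader with any world size can re-partition `restored` as it sees fit.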
