

Principle:Deepspeedai DeepSpeed RLHF Checkpointing

From Leeroopedia


Overview

Saving and restoring RLHF training state, including the actor model, its optimizer, and PPO iteration metadata, through the Hybrid Engine's checkpoint system.

Description

RLHF checkpointing must save not only the model and optimizer state but also the PPO iteration count, the KL penalty coefficient, and other RLHF-specific metadata, which are passed through the client_state parameter. The Hybrid Engine (DeepSpeedHybridEngine) inherits save_checkpoint() from DeepSpeedEngine, which handles saving ZeRO-partitioned state. Each RLHF iteration cycles through eval -> generate -> train -> backward -> step, followed by a save at the chosen checkpoint interval.
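The cycle above can be sketched as a plain Python loop over a duck-typed engine object. This is a minimal illustration under stated assumptions, not DeepSpeed-Chat's actual trainer: generate_fn, loss_fn, save_interval, the tag format, and the client_state keys are hypothetical, while eval(), train(), backward(), step(), and save_checkpoint(save_dir, tag=..., client_state=...) mirror methods that a DeepSpeed engine exposes.

```python
from typing import Any, Callable, Iterable


def run_rlhf_loop(
    engine: Any,
    prompts: Iterable[Any],
    generate_fn: Callable[[Any, Any], Any],  # hypothetical: rollout / experience generation
    loss_fn: Callable[[Any, Any], Any],      # hypothetical: PPO actor loss
    save_dir: str,
    save_interval: int = 100,                # hypothetical checkpoint interval
    start_iter: int = 0,
    kl_coef: float = 0.1,                    # held constant here for simplicity
) -> int:
    """Run eval -> generate -> train -> backward -> step, saving periodically."""
    ppo_iter = start_iter
    for batch in prompts:
        engine.eval()                        # inference mode for generation
        experience = generate_fn(engine, batch)
        engine.train()                       # training mode for the PPO update
        loss = loss_fn(engine, experience)
        engine.backward(loss)                # engine-managed backward pass
        engine.step()                        # optimizer + LR scheduler step
        ppo_iter += 1
        if ppo_iter % save_interval == 0:
            # Every rank calls this; RLHF metadata travels in client_state.
            engine.save_checkpoint(
                save_dir,
                tag=f"iter_{ppo_iter}",
                client_state={"ppo_iter": ppo_iter, "kl_coef": kl_coef},
            )
    return ppo_iter
```

In a real run, engine would be the object returned by deepspeed.initialize() with the hybrid engine enabled, and kl_coef would be read from whatever controller adapts it rather than held constant as it is here.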

Checkpointing is critical for RLHF training for several reasons:

  • Long training durations: RLHF training can take days or weeks. Checkpoints enable recovery from hardware failures, preemptions, or other interruptions, so that at most the work done since the last checkpoint is lost.
  • Multi-component state: Unlike standard fine-tuning, where only model weights and optimizer states need saving, RLHF requires preserving additional state such as the PPO iteration counter, the adaptive KL coefficient (which may change during training), running statistics for reward normalization, and any experience replay buffer state.
  • ZeRO-partitioned state: ZeRO Stage 1 partitions optimizer states across data-parallel ranks, Stage 2 additionally partitions gradients, and Stage 3 also partitions the model parameters. The checkpoint system must therefore coordinate across all ranks to save a consistent snapshot, and every process must participate in the save operation.
  • Atomic consistency: The checkpoint must capture a consistent state in which the model parameters, optimizer moments, learning rate scheduler position, and RLHF metadata all correspond to the same training iteration. The client_state dictionary makes this possible by bundling the RLHF-specific state with the engine checkpoint in a single save_checkpoint() call, under the same checkpoint tag.
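One way to keep the multi-component RLHF state consistent is to gather it into a single object and serialize it into client_state at save time. The sketch below assumes a hypothetical RLHFState dataclass and save_rlhf_checkpoint helper; only the engine.save_checkpoint(save_dir, tag=..., client_state=...) call reflects the DeepSpeed API. It deliberately has no rank-0 guard, because DeepSpeed's documentation states that every process must call save_checkpoint.

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional


@dataclass
class RLHFState:
    """Hypothetical container for RLHF metadata saved with the engine checkpoint."""
    ppo_iter: int = 0
    kl_coef: float = 0.1
    reward_mean: float = 0.0   # running reward-normalization statistics
    reward_var: float = 1.0
    reward_count: int = 0

    def to_client_state(self) -> Dict[str, Any]:
        # Nest under one key so RLHF fields cannot collide with engine entries.
        return {"rlhf_state": asdict(self)}

    @classmethod
    def from_client_state(cls, client_state: Optional[Dict[str, Any]]) -> "RLHFState":
        if not client_state or "rlhf_state" not in client_state:
            return cls()  # fresh run: fall back to defaults
        return cls(**client_state["rlhf_state"])


def save_rlhf_checkpoint(engine: Any, save_dir: str, state: RLHFState) -> None:
    # Called unconditionally on EVERY rank: with ZeRO each rank holds its own
    # partition, and a rank-0-only call would block waiting for the others.
    engine.save_checkpoint(
        save_dir,
        tag=f"iter_{state.ppo_iter}",
        client_state=state.to_client_state(),
    )
```

Nesting the fields under a single key is a defensive choice: the client_state entries are stored alongside the engine's own checkpoint entries, so a top-level key that happened to match an engine entry could collide with it.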

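The save_checkpoint() method writes a tagged subdirectory under the save directory containing the model weights (or, under ZeRO, per-rank partitions), the optimizer and learning rate scheduler state, and the client_state entries, which are stored alongside the engine's own checkpoint data; by default it also updates a file named latest that records the most recent tag. On resumption, load_checkpoint() restores this state and returns a (load_path, client_state) pair, from which the caller restores the RLHF-specific variables.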
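Resumption can be sketched as follows, assuming engine is a DeepSpeed engine or any object with the same load_checkpoint return convention. load_checkpoint(load_dir) returns a (load_path, client_state) pair, and load_path is None when no checkpoint is found; the resume_rlhf helper and its defaults argument are hypothetical.

```python
from typing import Any, Dict, Tuple


def resume_rlhf(engine: Any, load_dir: str,
                defaults: Dict[str, Any]) -> Tuple[bool, Dict[str, Any]]:
    """Restore engine state; return (resumed, rlhf_metadata)."""
    # Restores weights (or ZeRO partitions), optimizer and LR scheduler state,
    # and hands back whatever extra entries were saved via client_state.
    load_path, client_state = engine.load_checkpoint(load_dir)
    if load_path is None:
        # No checkpoint found: start from the default RLHF metadata.
        return False, dict(defaults)
    client_state = client_state or {}
    metadata = dict(defaults)
    # Only pick up the keys this run knows about; ignore anything else.
    metadata.update({k: client_state[k] for k in defaults if k in client_state})
    return True, metadata
```

The returned ppo_iter would then seed the training loop's iteration counter, and kl_coef would be loaded back into whatever component adapts the KL penalty.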

Theoretical Basis

RLHF training state that must be preserved for correct resumption includes:

  • Actor model parameters (theta): The current policy weights.
  • Optimizer state: Adam moments (first and second moment estimates) and step count, potentially partitioned across ranks by ZeRO.
  • Learning rate scheduler state: Current learning rate and scheduler step count.
  • PPO hyperparameters: KL coefficient beta (which may be adjusted adaptively during training), clipping range epsilon, and the current values of any scheduled hyperparameters.
  • Iteration count: The current PPO iteration number for correct resumption of the training loop.
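The page does not say how beta is adapted. As one concrete example, the proportional controller described by Ziegler et al. (2019), also found in libraries such as Hugging Face TRL, nudges beta toward a target KL after every batch. The sketch below (class and method names are hypothetical) shows why beta is training state rather than a fixed hyperparameter, and exposes a state_dict() whose output could be placed in client_state.

```python
from typing import Dict


class AdaptiveKLController:
    """Proportional KL controller in the style of Ziegler et al. (2019).

    beta drifts toward whatever value keeps the observed KL near the target,
    so restarting it from init_beta after a resume would change the reward
    shaping mid-run.
    """

    def __init__(self, init_beta: float, target_kl: float, horizon: int) -> None:
        self.beta = init_beta
        self.target_kl = target_kl
        self.horizon = horizon

    def update(self, current_kl: float, n_steps: int) -> float:
        # Clipped proportional error keeps each adjustment small.
        error = min(max(current_kl / self.target_kl - 1.0, -0.2), 0.2)
        self.beta *= 1.0 + error * n_steps / self.horizon
        return self.beta

    def state_dict(self) -> Dict[str, float]:
        return {"beta": self.beta}

    def load_state_dict(self, state: Dict[str, float]) -> None:
        self.beta = state["beta"]
```

Usage: with target_kl=6.0 and an observed KL of 12.0, the error clips to 0.2, so one update over 256 samples with horizon=10000 multiplies beta by 1 + 0.2 * 256 / 10000.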

All of these must be saved together, in a single checkpoint, so that on resumption training continues from exactly the same state. Saving only partial state (for example, model weights without the optimizer moments) resets the optimizer's moment estimates and step count, so the updates after resumption diverge from those of an uninterrupted run and training quality can degrade.
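To make the cost of a weights-only checkpoint concrete, the sketch below runs a textbook scalar Adam update (not DeepSpeed's fused or ZeRO-partitioned optimizer) for a few steps, then compares the next update when the moments are restored against one where they are reset to zero. The gradient values are arbitrary illustration inputs.

```python
import math
from typing import Tuple


def adam_step(param: float, grad: float, m: float, v: float, t: int,
              lr: float = 1e-3, b1: float = 0.9, b2: float = 0.999,
              eps: float = 1e-8) -> Tuple[float, float, float, int]:
    """One scalar Adam update; returns (param, m, v, t)."""
    t += 1
    m = b1 * m + (1 - b1) * grad          # first-moment estimate
    v = b2 * v + (1 - b2) * grad * grad   # second-moment estimate
    m_hat = m / (1 - b1 ** t)             # bias correction
    v_hat = v / (1 - b2 ** t)
    param -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v, t


# Train a few steps with a consistent gradient, then "checkpoint" here.
p, m, v, t = 1.0, 0.0, 0.0, 0
for g in [1.0, 1.0, 1.0, 1.0]:
    p, m, v, t = adam_step(p, g, m, v, t)

# The first gradient after resumption points the other way.
g_next = -0.5
full, *_ = adam_step(p, g_next, m, v, t)             # moments restored
weights_only, *_ = adam_step(p, g_next, 0.0, 0.0, 0)  # moments reset
```

With restored moments, the parameter keeps moving in the direction accumulated before the interruption, damped by the new gradient. After a reset, bias correction makes the first step roughly lr times the sign of the new gradient, so the update jumps the other way at full learning-rate size.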


Last updated: 2026-02-09 00:00 GMT
