Principle:Deepspeedai DeepSpeed RLHF Checkpointing
Overview
Saving and restoring RLHF training state including the actor model, optimizer, and PPO iteration metadata through the Hybrid Engine's checkpoint system.
Description
RLHF checkpointing must save not just the model and optimizer state, but also the PPO iteration count, KL penalty coefficient, and other RLHF-specific metadata via the client_state parameter. The Hybrid Engine inherits save_checkpoint() from DeepSpeedEngine, which handles ZeRO-partitioned state saving. The RLHF loop repeats the cycle of eval -> generate -> train -> backward -> step -> save.
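The cycle above can be sketched as a plain Python loop. The engine here is a hypothetical stand-in (StandInEngine) that only records which method was called. Its method names (eval, generate, train, backward, step, save_checkpoint) mirror the Hybrid Engine calls named in the text, but it does no real work. The client_state keys ppo_iteration and kl_coef are illustrative choices, not a required schema.

```python
import json
import os
import tempfile


class StandInEngine:
    """Hypothetical stand-in for a DeepSpeed Hybrid Engine.

    It only records which method was called, so the control flow can run
    without GPUs, a model, or DeepSpeed installed.
    """

    def __init__(self):
        self.calls = []

    def eval(self):
        self.calls.append("eval")

    def generate(self, prompts):
        self.calls.append("generate")
        return [p + " <response>" for p in prompts]

    def train(self):
        self.calls.append("train")

    def backward(self, loss):
        self.calls.append("backward")

    def step(self):
        self.calls.append("step")

    def save_checkpoint(self, save_dir, tag=None, client_state=None):
        self.calls.append("save")
        path = os.path.join(save_dir, str(tag))
        os.makedirs(path, exist_ok=True)
        with open(os.path.join(path, "client_state.json"), "w") as f:
            json.dump(client_state or {}, f)


def rlhf_loop(engine, prompts, num_iters, save_dir, kl_coef=0.1):
    """One PPO iteration per pass: eval -> generate -> train -> backward -> step -> save."""
    responses = []
    for ppo_iter in range(num_iters):
        engine.eval()                          # inference mode for the rollout
        responses = engine.generate(prompts)   # experience generation
        engine.train()                         # back to training mode
        loss = 0.0                             # placeholder for the PPO actor loss
        engine.backward(loss)
        engine.step()
        # RLHF metadata travels with the engine checkpoint via client_state;
        # ppo_iteration counts completed iterations, i.e. where to resume.
        engine.save_checkpoint(
            save_dir,
            tag=f"iter_{ppo_iter}",
            client_state={"ppo_iteration": ppo_iter + 1, "kl_coef": kl_coef},
        )
    return responses
```

With a real DeepSpeed engine, loss would be the PPO actor loss tensor and save_checkpoint() would be called on every rank. Saving on every iteration simply follows the cycle described above; a real run may save less often.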
Checkpointing is critical for RLHF training for several reasons:
- Long training durations: RLHF training can take days or weeks. Checkpoints enable recovery from hardware failures, preemptions, or other interruptions without losing progress.
- Multi-component state: Unlike standard fine-tuning where only model weights and optimizer states need saving, RLHF requires preserving additional state such as the PPO iteration counter, the adaptive KL coefficient (which may change during training), running statistics for reward normalization, and any experience replay buffer state.
- ZeRO-partitioned state: With ZeRO Stage 2 or 3, optimizer states and potentially model parameters are partitioned across data-parallel ranks. The checkpoint system must coordinate across all ranks to save a consistent snapshot, and all processes must participate in the save operation.
- Atomic consistency: The checkpoint must capture a consistent state where the model parameters, optimizer moments, learning rate scheduler position, and RLHF metadata all correspond to the same training iteration. The client_state dictionary mechanism enables this by bundling RLHF-specific state with the engine checkpoint in a single save_checkpoint() call.
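To make the requirement that all ranks participate concrete, the following toy simulation in plain Python splits a flat optimizer-state vector into contiguous per-rank shards, has each simulated rank write only its own shard, and refuses to reassemble the state if any shard is missing. It is a simplified picture of ZeRO-style partitioning, not DeepSpeed's implementation. The file names (rank_{r}.json) and the JSON format are assumptions made for illustration and do not match DeepSpeed's on-disk layout.

```python
import json
import os
import tempfile


def partition(values, world_size):
    """Split a flat list of optimizer-state values into contiguous per-rank
    shards (a simplified picture of ZeRO-style partitioning)."""
    shard = -(-len(values) // world_size)  # ceiling division
    return [values[r * shard:(r + 1) * shard] for r in range(world_size)]


def save_shard(ckpt_dir, rank, shard):
    # Each rank writes only the shard it owns; no rank holds the full state.
    with open(os.path.join(ckpt_dir, f"rank_{rank}.json"), "w") as f:
        json.dump(shard, f)


def load_all(ckpt_dir, world_size):
    """Reassemble the full state; fails if any rank skipped the save."""
    full = []
    for rank in range(world_size):
        path = os.path.join(ckpt_dir, f"rank_{rank}.json")
        if not os.path.exists(path):
            raise FileNotFoundError(f"missing shard from rank {rank}")
        with open(path) as f:
            full.extend(json.load(f))
    return full
```

If one rank skips its save, the checkpoint cannot be reassembled, which is why DeepSpeed's save_checkpoint() is called on every process rather than only on rank 0.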
The save_checkpoint() method writes a directory for the given tag containing the model weights (or their ZeRO partitions) and the optimizer state, and stores the client_state dictionary alongside the engine's own state. On resumption, load_checkpoint() restores all of this state and returns a (load_path, client_state) tuple; the caller then reads client_state to restore the RLHF-specific variables.
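A single-process sketch of the save/restore round trip follows. StandInEngine is hypothetical: it copies only the call shape of DeepSpeedEngine.save_checkpoint(save_dir, tag, client_state=...) and load_checkpoint(load_dir, tag=None), which returns a (load_path, client_state) tuple, and it writes JSON instead of DeepSpeed's real checkpoint files. As in DeepSpeed's default behaviour, the client entries are stored next to the engine's own state, and a latest file records the most recent tag so that a load with tag=None can find it.

```python
import json
import os
import tempfile


class StandInEngine:
    """Hypothetical single-process stand-in with the same call shape as
    DeepSpeedEngine.save_checkpoint / load_checkpoint. Not DeepSpeed code."""

    def __init__(self, weights, optimizer_state):
        self.weights = weights
        self.optimizer_state = optimizer_state

    def save_checkpoint(self, save_dir, tag, client_state=None):
        ckpt_dir = os.path.join(save_dir, tag)
        os.makedirs(ckpt_dir, exist_ok=True)
        state = {"module": self.weights, "optimizer": self.optimizer_state}
        state.update(client_state or {})   # client entries stored beside engine state
        with open(os.path.join(ckpt_dir, "state.json"), "w") as f:
            json.dump(state, f)
        with open(os.path.join(save_dir, "latest"), "w") as f:
            f.write(tag)                   # newest tag, used when loading with tag=None

    def load_checkpoint(self, load_dir, tag=None):
        if tag is None:
            with open(os.path.join(load_dir, "latest")) as f:
                tag = f.read().strip()
        load_path = os.path.join(load_dir, tag)
        with open(os.path.join(load_path, "state.json")) as f:
            state = json.load(f)
        self.weights = state.pop("module")
        self.optimizer_state = state.pop("optimizer")
        return load_path, state            # the remaining keys are the client_state


def save_then_resume(ckpt_dir):
    engine = StandInEngine([0.5, -1.2], {"step": 40})
    engine.save_checkpoint(
        ckpt_dir, tag="iter_40",
        client_state={"ppo_iteration": 40, "kl_coef": 0.08},
    )
    # A restarted job builds a fresh engine, then restores everything from disk.
    resumed = StandInEngine([0.0, 0.0], {"step": 0})
    load_path, client_state = resumed.load_checkpoint(ckpt_dir)
    return resumed, load_path, client_state
```

After loading, the caller reads client_state["ppo_iteration"] and client_state["kl_coef"] to resume the PPO loop where it stopped.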
Theoretical Basis
RLHF training state that must be preserved for correct resumption includes:
- Actor model parameters ($\theta$): The current policy weights.
- Optimizer state: Adam moments (first and second moment estimates) and step count, potentially partitioned across ranks by ZeRO.
- Learning rate scheduler state: Current learning rate and scheduler step count.
- PPO hyperparameters: KL coefficient $\beta$ (which may be adaptively adjusted during training), clipping parameter $\epsilon$, and any scheduled hyperparameter values.
- Iteration count: The current PPO iteration number for correct resumption of the training loop.
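Because $\beta$ may be adjusted during training, its current value is part of the state to checkpoint. The sketch below uses one common adaptive scheme, the proportional controller from Ziegler et al. (2019), as an assumption; the text does not say which update rule a given trainer uses. The state_dict/load_state_dict method names and the kl_coef key are hypothetical, chosen to mirror PyTorch's naming convention so the value can be placed in client_state.

```python
class AdaptiveKLController:
    """Adaptive KL coefficient in the style of Ziegler et al. (2019).
    An assumed scheme for illustration, not necessarily what a given trainer uses."""

    def __init__(self, init_kl_coef, target_kl, horizon):
        self.value = init_kl_coef   # current beta
        self.target = target_kl
        self.horizon = horizon

    def update(self, current_kl, n_steps):
        # Proportional error, clipped to +/-20% so beta changes smoothly.
        error = max(min(current_kl / self.target - 1.0, 0.2), -0.2)
        self.value *= 1.0 + error * n_steps / self.horizon

    def state_dict(self):
        # Hypothetical key; this dict would be merged into client_state.
        return {"kl_coef": self.value}

    def load_state_dict(self, state):
        self.value = state["kl_coef"]


def run_updates(ctl, kls, batch_size=256):
    """Apply one controller update per observed KL value."""
    for kl in kls:
        ctl.update(kl, batch_size)
    return ctl.value
```

Restoring the saved coefficient reproduces the uninterrupted trajectory exactly, whereas restarting from the initial coefficient leaves $\beta$ at a different value; when the measured KL stays above target, as in the test values, the restarted $\beta$ ends up too low.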
All of the state listed above must be saved atomically so that, on resumption, training continues from exactly the same state. Saving only partial state (for example, model weights without the optimizer moments) would cause the optimizer to restart from scratch, degrading training quality.
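The cost of partial restoration can be seen on a toy problem. The sketch below minimizes $f(x) = (x - 3)^2$ with a minimal scalar Adam optimizer (a stand-in for the real, possibly ZeRO-partitioned optimizer) and compares three runs: an uninterrupted run, a resume that restores the weight together with the Adam moments and step count, and a resume that restores only the weight.

```python
import math


class Adam:
    """Minimal scalar Adam (Kingma & Ba) with bias correction."""

    def __init__(self, lr=0.1, b1=0.9, b2=0.999, eps=1e-8):
        self.lr, self.b1, self.b2, self.eps = lr, b1, b2, eps
        self.m = 0.0   # first-moment estimate
        self.v = 0.0   # second-moment estimate
        self.t = 0     # step count

    def step(self, x, grad):
        self.t += 1
        self.m = self.b1 * self.m + (1 - self.b1) * grad
        self.v = self.b2 * self.v + (1 - self.b2) * grad * grad
        m_hat = self.m / (1 - self.b1 ** self.t)
        v_hat = self.v / (1 - self.b2 ** self.t)
        return x - self.lr * m_hat / (math.sqrt(v_hat) + self.eps)

    def state(self):
        return {"m": self.m, "v": self.v, "t": self.t}

    def load(self, s):
        self.m, self.v, self.t = s["m"], s["v"], s["t"]


def grad(x):
    return 2.0 * (x - 3.0)  # gradient of f(x) = (x - 3)^2


def train(x, opt, steps):
    for _ in range(steps):
        x = opt.step(x, grad(x))
    return x


def compare(total=20, save_at=10):
    x_cont = train(0.0, Adam(), total)              # uninterrupted run

    opt = Adam()
    x_mid = train(0.0, opt, save_at)                # run up to the checkpoint
    ckpt = {"x": x_mid, "opt": opt.state()}

    full = Adam()
    full.load(ckpt["opt"])                          # full restore: weight + moments + step
    x_full = train(ckpt["x"], full, total - save_at)

    x_partial = train(ckpt["x"], Adam(), total - save_at)  # weight only: moments reset
    return x_cont, x_full, x_partial
```

The full restore matches the uninterrupted run exactly, while the weights-only resume diverges: with fresh moments, Adam's bias-corrected first step has magnitude of roughly lr regardless of the gradient history.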
Last updated: 2026-02-09 00:00 GMT