Implementation:Deepspeedai DeepSpeed DeepSpeedEngine Save For RLHF
Overview
Concrete tool for saving RLHF training checkpoints through the Hybrid Engine provided by the DeepSpeed library.
Description
DeepSpeedHybridEngine.save_checkpoint() is inherited from DeepSpeedEngine. It saves the actor model state, optimizer state, and LR scheduler state. RLHF-specific state (PPO iteration, KL coefficient) is passed via the client_state parameter and saved alongside the engine state.
The save_checkpoint() method (L3695-3789 in engine.py) performs the following operations:
- Directory creation: Rank 0 creates the save directory, followed by a distributed barrier to ensure all ranks see the directory.
- Tag assignment: If no tag is provided, the global step count is used as the checkpoint tag (for example, global_step42).
- Tag validation: Ensures the checkpoint tag is consistent across all ranks to prevent mismatched checkpoint states.
- Checkpoint file creation: Creates the checkpoint directory structure under save_dir/tag/.
- State saving: Saves the model state dict, optimizer state dict, LR scheduler state, and client_state dictionary. For ZeRO Stage 2/3, optimizer states are saved in their partitioned form to avoid memory spikes from gathering.
- ZeRO checkpoint saving: If ZeRO is active, saves additional ZeRO-specific checkpoint files containing partitioned optimizer states.
- NVMe offload handling: If NVMe offloading is used, copies the offloaded tensor files to the checkpoint directory.
- Latest pointer: Writes a latest file pointing to the most recent checkpoint tag.
- Barrier synchronization: All ranks synchronize after saving to ensure consistency.
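The tag-assignment, directory-creation, and latest-pointer steps above can be illustrated with a minimal single-process sketch. This is not DeepSpeed's implementation: resolve_tag, write_checkpoint_skeleton, and read_latest_tag are hypothetical helpers, and the distributed barriers, tag validation, and tensor serialization are omitted. Only the global_step{N} default, the save_dir/tag/ layout, and the latest file are taken from the description above.

```python
import os
import tempfile


def resolve_tag(tag, global_steps):
    """Return the caller's tag, or fall back to the global-step default."""
    return tag if tag is not None else f"global_step{global_steps}"


def write_checkpoint_skeleton(save_dir, global_steps, tag=None, save_latest=True):
    """Create save_dir/tag/ and optionally update the 'latest' pointer file."""
    tag = resolve_tag(tag, global_steps)
    ckpt_dir = os.path.join(save_dir, str(tag))
    os.makedirs(ckpt_dir, exist_ok=True)
    if save_latest:
        # The pointer file holds nothing but the tag of the newest checkpoint.
        with open(os.path.join(save_dir, "latest"), "w") as f:
            f.write(str(tag))
    return ckpt_dir


def read_latest_tag(save_dir):
    """Read the tag recorded in the 'latest' pointer file."""
    with open(os.path.join(save_dir, "latest")) as f:
        return f.read().strip()


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        write_checkpoint_skeleton(root, global_steps=42)                # default tag
        write_checkpoint_skeleton(root, global_steps=50, tag="iter_3")  # explicit tag
        print(sorted(os.listdir(root)))
        print(read_latest_tag(root))
```

Passing save_latest=False in this sketch leaves the pointer on the previous tag, which is one way to write an auxiliary checkpoint without changing what a later resume picks up.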
Important: All processes must call save_checkpoint(), not just rank 0. This is because each process holds its own partition of the optimizer state (with ZeRO) and must save its portion independently.
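To see why a rank-0-only save is insufficient when state is partitioned, consider a toy simulation in which each rank owns one slice of the optimizer state. The shard file names and the partition, save_partition, and load_full_state helpers are hypothetical and chosen for illustration; they do not reflect DeepSpeed's on-disk format, and the ranks are simulated in a single process.

```python
import json
import os
import tempfile


def partition(values, world_size):
    """Split a flat list into world_size contiguous slices, one per rank."""
    chunk = (len(values) + world_size - 1) // world_size
    return [values[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def save_partition(ckpt_dir, rank, shard):
    """Each rank writes only the slice it owns (hypothetical file name)."""
    path = os.path.join(ckpt_dir, f"optim_shard_rank{rank}.json")
    with open(path, "w") as f:
        json.dump(shard, f)


def load_full_state(ckpt_dir, world_size):
    """Reassemble the full state; raises FileNotFoundError if a shard is missing."""
    full = []
    for rank in range(world_size):
        path = os.path.join(ckpt_dir, f"optim_shard_rank{rank}.json")
        with open(path) as f:
            full.extend(json.load(f))
    return full


if __name__ == "__main__":
    state = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
    world_size = 4
    shards = partition(state, world_size)

    with tempfile.TemporaryDirectory() as ckpt_dir:
        # Incorrect pattern: only rank 0 saves, so the other slices never reach disk.
        save_partition(ckpt_dir, 0, shards[0])
        try:
            load_full_state(ckpt_dir, world_size)
        except FileNotFoundError as err:
            print("incomplete checkpoint, missing:", os.path.basename(err.filename))

        # Correct pattern: every rank saves its own slice.
        for rank in range(world_size):
            save_partition(ckpt_dir, rank, shards[rank])
        print("restored matches original:", load_full_state(ckpt_dir, world_size) == state)
```

In a real training script the equivalent mistake is wrapping engine.save_checkpoint(...) in a rank-0 guard; the call should sit outside any such condition.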
Code Reference
| Property | Value |
|---|---|
| Repository | https://github.com/deepspeedai/DeepSpeed |
| File | deepspeed/runtime/engine.py (L3695-3789) |
| Signature | def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True, exclude_frozen_parameters=False) -> bool |
| Import | Accessed via engine returned by deepspeed.initialize() |
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| save_dir | str | Yes | Directory path for saving the checkpoint |
| tag | str | No | Unique identifier for the checkpoint; defaults to global_step{N} |
| client_state | dict | No | RLHF-specific state (iteration count, KL coefficient, etc.) |
| save_latest | bool | No | Write a latest file pointing to this checkpoint (default True) |
| exclude_frozen_parameters | bool | No | Exclude frozen parameters from the saved state (default False) |
Outputs
| Name | Type | Description |
|---|---|---|
| success | bool | Returns True upon successful checkpoint save |
| (side effect) | files on disk | Checkpoint directory with model, optimizer, and client state files |
Usage Example
# Assumes engine, prompts, rewards, kl_coeff, num_rlhf_iterations, and
# compute_ppo_loss are defined earlier in the training script.
for rlhf_iter in range(num_rlhf_iterations):
    # Step 4: Generate experience
    engine.eval()
    sequences = engine.generate(input_ids=prompts, max_new_tokens=256)

    # Step 5: PPO update
    engine.train()
    ppo_loss = compute_ppo_loss(engine, sequences, rewards)
    engine.backward(ppo_loss)
    engine.step()

    # Step 6: Checkpoint (called on every rank, not only rank 0)
    engine.save_checkpoint(
        "rlhf_checkpoints/",
        tag=f"iter_{rlhf_iter}",
        client_state={
            "rlhf_iteration": rlhf_iter,
            "kl_coefficient": kl_coeff,
        },
    )
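The counterpart to the loop above is restoring the RLHF state on restart. The sketch below assumes that the engine's load_checkpoint() returns a (load_path, client_state) pair and returns (None, None) when no checkpoint is found; verify this against the installed DeepSpeed version. StubEngine is a hypothetical in-memory stand-in so the example runs without DeepSpeed, and resume_rlhf_state and default_kl are illustrative names, not part of the library.

```python
import os


class StubEngine:
    """Hypothetical stand-in for a DeepSpeed engine; keeps client_state in memory."""

    def __init__(self):
        self._store = {}
        self._latest = None

    def save_checkpoint(self, save_dir, tag=None, client_state=None, save_latest=True):
        self._store[(save_dir, tag)] = dict(client_state or {})
        if save_latest:
            self._latest = tag
        return True

    def load_checkpoint(self, load_dir, tag=None):
        # With no explicit tag, fall back to the most recently saved one.
        tag = tag if tag is not None else self._latest
        if (load_dir, tag) not in self._store:
            return None, None
        return os.path.join(load_dir, str(tag)), self._store[(load_dir, tag)]


def resume_rlhf_state(engine, load_dir, default_kl=0.1):
    """Return (next_iteration, kl_coefficient), or defaults if nothing was saved."""
    load_path, client_state = engine.load_checkpoint(load_dir)
    if load_path is None:
        return 0, default_kl
    # Resume from the iteration after the one that was checkpointed.
    return client_state["rlhf_iteration"] + 1, client_state["kl_coefficient"]


if __name__ == "__main__":
    engine = StubEngine()
    print(resume_rlhf_state(engine, "rlhf_checkpoints/"))  # fresh run, defaults

    engine.save_checkpoint(
        "rlhf_checkpoints/",
        tag="iter_2",
        client_state={"rlhf_iteration": 2, "kl_coefficient": 0.05},
    )
    print(resume_rlhf_state(engine, "rlhf_checkpoints/"))  # resumes after iter 2
```

With a real engine, the start of the training loop would become range(start_iter, num_rlhf_iterations), using the first value returned by resume_rlhf_state.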
Last updated: 2026-02-09 00:00 GMT