
Implementation:Deepspeedai DeepSpeed DeepSpeedEngine Save For RLHF

From Leeroopedia
Revision as of 14:46, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/Deepspeedai_DeepSpeed_DeepSpeedEngine_Save_For_RLHF.md)


Overview

Concrete tool for saving RLHF training checkpoints through the DeepSpeed Hybrid Engine.

Description

DeepSpeedHybridEngine.save_checkpoint() is inherited from DeepSpeedEngine. It saves the actor model state, optimizer state, and LR scheduler state. RLHF-specific state (PPO iteration, KL coefficient) is passed via the client_state parameter and saved alongside the engine state.
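
As a minimal sketch of how RLHF-specific state might be packed into, and later recovered from, the client_state dictionary: the names RLHFState, to_client_state, and from_client_state below are hypothetical helpers for illustration and are not part of DeepSpeed. Only the idea of passing a plain dict through client_state comes from the description above.

```python
from dataclasses import dataclass, asdict


@dataclass
class RLHFState:
    """Hypothetical container for RLHF bookkeeping (not a DeepSpeed class)."""
    rlhf_iteration: int = 0
    kl_coefficient: float = 0.1


def to_client_state(state: RLHFState) -> dict:
    """Flatten the state into a plain dict suitable for client_state."""
    return asdict(state)


def from_client_state(client_state: dict) -> RLHFState:
    """Rebuild the state, falling back to defaults for missing keys."""
    defaults = RLHFState()
    return RLHFState(
        rlhf_iteration=int(client_state.get("rlhf_iteration", defaults.rlhf_iteration)),
        kl_coefficient=float(client_state.get("kl_coefficient", defaults.kl_coefficient)),
    )
```

Under these assumptions, the dict produced by to_client_state(state) would be passed as the client_state argument of save_checkpoint().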

The save_checkpoint() method (L3695-3789 in engine.py) performs the following operations:

  1. Directory creation: Rank 0 creates the save directory, followed by a distributed barrier to ensure all ranks see the directory.
  2. Tag assignment: If no tag is provided, the global step count is used as the checkpoint tag (for example, global_step42).
  3. Tag validation: Ensures the checkpoint tag is consistent across all ranks to prevent mismatched checkpoint states.
  4. Checkpoint file creation: Creates the checkpoint directory structure under save_dir/tag/.
  5. State saving: Saves the model state dict, optimizer state dict, LR scheduler state, and client_state dictionary. For ZeRO Stage 2/3, optimizer states are saved in their partitioned form to avoid memory spikes from gathering.
  6. ZeRO checkpoint saving: If ZeRO is active, saves additional ZeRO-specific checkpoint files containing partitioned optimizer states.
  7. NVMe offload handling: If NVMe offloading is used, copies the offloaded tensor files to the checkpoint directory.
  8. Latest pointer: Writes a latest file pointing to the most recent checkpoint tag.
  9. Barrier synchronization: All ranks synchronize after saving to ensure consistency.
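
The tag and directory steps above can be illustrated with a simplified, single-process simulation. This is not DeepSpeed's code: the function names are hypothetical, the per-rank tags are passed in as a list instead of being compared with a collective operation, and no model or optimizer state is written.

```python
import hashlib
import os


def resolve_tag(tag, global_steps):
    """Step 2: fall back to a step-based tag when none is given."""
    return tag if tag is not None else f"global_step{global_steps}"


def tags_consistent(tags_per_rank):
    """Step 3 (simulated): every rank must hold the same tag.

    A real run would compare digests across processes; here the per-rank
    tags are simply supplied as a list.
    """
    digests = {hashlib.sha1(t.encode()).hexdigest() for t in tags_per_rank}
    return len(digests) == 1


def write_checkpoint_layout(save_dir, tag, save_latest=True):
    """Steps 1, 4 and 8: create save_dir/tag/ and the 'latest' pointer."""
    ckpt_dir = os.path.join(save_dir, tag)
    os.makedirs(ckpt_dir, exist_ok=True)
    if save_latest:
        with open(os.path.join(save_dir, "latest"), "w") as f:
            f.write(tag)
    return ckpt_dir
```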

Important: All processes must call save_checkpoint(), not just rank 0. This is because each process holds its own partition of the optimizer state (with ZeRO) and must save its portion independently.
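
One way to respect this rule is to base the save decision only on values that are identical on every rank, such as the step counter, and to avoid any rank guard around the call. The sketch below uses a hypothetical maybe_save helper and a FakeEngine stand-in so it can run without DeepSpeed; neither name exists in the library.

```python
def maybe_save(engine, step, save_dir, every=10):
    """Call save_checkpoint on the calling rank whenever a save is due.

    There is deliberately no `if rank == 0` guard: the decision depends only
    on `step`, which is the same on all ranks, so every rank enters the call.
    """
    if step % every != 0:
        return False
    engine.save_checkpoint(save_dir, tag=f"step_{step}")
    return True


class FakeEngine:
    """Stand-in used only to exercise the helper without DeepSpeed."""

    def __init__(self, rank):
        self.rank = rank
        self.saved_tags = []

    def save_checkpoint(self, save_dir, tag=None, client_state=None, save_latest=True):
        # Record the tag instead of writing files.
        self.saved_tags.append(tag)
        return True
```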

Code Reference

| Property | Value |
|---|---|
| Repository | https://github.com/deepspeedai/DeepSpeed |
| File | deepspeed/runtime/engine.py (L3695-3789) |
| Signature | def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True, exclude_frozen_parameters=False) -> bool |
| Import | Accessed via engine returned by deepspeed.initialize() |

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| save_dir | str | Yes | Directory path for saving the checkpoint |
| tag | str | No | Unique identifier for the checkpoint; defaults to global_step{N} |
| client_state | dict | No | RLHF-specific state (iteration count, KL coefficient, etc.) |
| save_latest | bool | No | Write a latest file pointing to this checkpoint (default True) |
| exclude_frozen_parameters | bool | No | Exclude frozen parameters from the saved state (default False) |

Outputs

| Name | Type | Description |
|---|---|---|
| success | bool | Returns True upon successful checkpoint save |
| (side effect) | files on disk | Checkpoint directory with model, optimizer, and client state files |
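
Because the latest file is plain text holding the most recent tag, it can be read back to locate the newest checkpoint directory. The helper below is a minimal sketch with a hypothetical name (read_latest_tag); it assumes the file sits directly under save_dir and contains only the tag.

```python
import os
from typing import Optional


def read_latest_tag(save_dir: str) -> Optional[str]:
    """Return the tag stored in save_dir/latest, or None if it is absent."""
    latest_path = os.path.join(save_dir, "latest")
    if not os.path.isfile(latest_path):
        return None
    with open(latest_path) as f:
        return f.read().strip()
```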

Usage Example

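# Assumes `engine` is the Hybrid Engine returned by deepspeed.initialize(), and that
# `prompts`, `rewards`, `kl_coeff`, `num_rlhf_iterations`, and `compute_ppo_loss`
# are defined by the surrounding training script.
for rlhf_iter in range(num_rlhf_iterations):
    # Step 4: Generate experience
    engine.eval()
    sequences = engine.generate(input_ids=prompts, max_new_tokens=256)

    # Step 5: PPO update
    engine.train()
    ppo_loss = compute_ppo_loss(engine, sequences, rewards)
    engine.backward(ppo_loss)
    engine.step()

    # Step 6: Checkpoint (called on every rank, not only rank 0)
    engine.save_checkpoint(
        "rlhf_checkpoints/",
        tag=f"iter_{rlhf_iter}",
        client_state={
            "rlhf_iteration": rlhf_iter,
            "kl_coefficient": kl_coeff,
        },
    )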
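
To resume from such a checkpoint, the saved client_state has to be read back. The sketch below assumes the engine's load_checkpoint() returns a (load_path, client_state) pair, with load_path set to None when nothing was loaded, as DeepSpeed's documentation describes. The resume_rlhf helper itself is hypothetical, and it relies on the same client_state keys as the loop above.

```python
def resume_rlhf(engine, load_dir, tag=None):
    """Load a checkpoint and return (start_iteration, kl_coefficient).

    Returns (0, None) when no checkpoint is found, so training starts fresh.
    """
    load_path, client_state = engine.load_checkpoint(load_dir, tag=tag)
    if load_path is None:
        return 0, None
    client_state = client_state or {}
    # Resume from the iteration after the one that was saved.
    start_iter = client_state.get("rlhf_iteration", -1) + 1
    return start_iter, client_state.get("kl_coefficient")
```

Like save_checkpoint(), this helper would be called on every rank, with the returned iteration used as the starting value of the training loop.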

Related Pages

Knowledge Sources

Last updated: 2026-02-09 00:00 GMT
