Implementation:Microsoft DeepSpeedExamples DeepSpeedRLHFEngine

From Leeroopedia


Overview

A concrete tool, provided by the DeepSpeed-Chat library, for initializing and managing the four-model RLHF architecture.

Description

DeepSpeedRLHFEngine is the central orchestration class that creates and wraps four models using deepspeed.initialize():

  • Actor — Initialized from an SFT checkpoint with full optimizer (FusedAdam or DeepSpeedCPUAdam), learning rate scheduler, optional LoRA adaptation, and optional hybrid engine for accelerated inference. Uses the actor-specific ZeRO stage.
  • Reference — Initialized from the same SFT checkpoint as the actor but with an eval-only DeepSpeed configuration (no optimizer). Uses ZeRO-3 if the actor uses ZeRO-3, otherwise defaults to ZeRO-0.
  • Critic — Initialized from a reward model checkpoint using create_critic_model(). Equipped with its own optimizer and learning rate scheduler. Supports LoRA and gradient checkpointing. Uses the critic-specific ZeRO stage.
  • Reward — Initialized from the same reward model checkpoint as the critic but with an eval-only configuration (no optimizer). Uses ZeRO-3 if the critic uses ZeRO-3, otherwise defaults to ZeRO-0.

An optional actor EMA (Exponential Moving Average) model can also be initialized when args.enable_ema is set to True.
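The EMA model is maintained with a standard exponential moving average of the actor's weights. A minimal plain-Python sketch of the update rule follows; the decay coefficient 0.992 is illustrative (DeepSpeed-Chat applies the same rule tensor-wise to model parameters):

```python
def ema_update(ema_params, params, beta=0.992):
    """Exponential moving average: ema <- beta * ema + (1 - beta) * param.
    Plain-Python sketch of the rule; beta=0.992 is illustrative."""
    return [beta * e + (1.0 - beta) * p for e, p in zip(ema_params, params)]

# After one update, the EMA weights move slightly toward the new weights.
ema = [0.0, 1.0]
new = ema_update(ema, [1.0, 1.0])
```

Because beta is close to 1, the EMA copy changes slowly and acts as a smoothed snapshot of the actor, which is what makes it attractive as the final checkpoint to export.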

Code Reference

  • File: applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py
  • Lines: 40-296

Signature

class DeepSpeedRLHFEngine:
    def __init__(
        self,
        actor_model_name_or_path,
        critic_model_name_or_path,
        tokenizer,
        args,
        num_total_iters
    ):
        ...

Import

from dschat.rlhf.rlhf_engine import DeepSpeedRLHFEngine

Inputs / Outputs

Inputs

  • actor_model_name_or_path (str) — Path to the SFT checkpoint or Hugging Face model name used for both the actor and the reference model
  • critic_model_name_or_path (str) — Path to the reward model checkpoint used for both the critic and the reward model
  • tokenizer (PreTrainedTokenizer) — Hugging Face tokenizer instance shared by all models
  • args (Namespace) — Configuration namespace containing ZeRO stages, LoRA settings, offloading flags, learning rates, dropout, hybrid-engine settings, and more
  • num_total_iters (int) — Total number of training iterations, used to configure the learning rate schedulers

Outputs

The engine exposes the following attributes after initialization:

  • .actor (DeepSpeedEngine) — DeepSpeed-wrapped actor model with optimizer and LR scheduler
  • .ref (DeepSpeedEngine) — DeepSpeed-wrapped reference model (frozen, no optimizer)
  • .critic (DeepSpeedEngine) — DeepSpeed-wrapped critic model with optimizer and LR scheduler
  • .reward (DeepSpeedEngine) — DeepSpeed-wrapped reward model (frozen, no optimizer)
  • .actor_ema (DeepSpeedEngine or None) — Optional EMA copy of the actor, present only when args.enable_ema is True

Key Configuration Parameters (via args)

  • actor_zero_stage — ZeRO stage for the actor model (0, 1, 2, or 3)
  • critic_zero_stage — ZeRO stage for the critic model (0, 1, 2, or 3)
  • actor_lora_dim — LoRA rank for the actor; set > 0 to enable LoRA
  • critic_lora_dim — LoRA rank for the critic; set > 0 to enable LoRA
  • offload — Enable CPU offloading of optimizer states
  • offload_reference_model — Enable CPU offloading for the reference model
  • offload_reward_model — Enable CPU offloading for the reward model
  • enable_hybrid_engine — Enable the DeepSpeed hybrid engine for accelerated actor inference
  • enable_ema — Maintain an Exponential Moving Average copy of the actor weights
  • actor_learning_rate — Learning rate for the actor optimizer
  • critic_learning_rate — Learning rate for the critic optimizer
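To make the shape of args concrete, here is an illustrative construction with types.SimpleNamespace. This is not how DeepSpeed-Chat builds it (the training scripts use argparse and define many more flags); the field names follow the list above, and every value shown is a placeholder:

```python
from types import SimpleNamespace

# Illustrative only: real runs build `args` via argparse in the training
# script. Field names match the configuration parameters listed above.
args = SimpleNamespace(
    actor_zero_stage=3,           # ZeRO-3 partitions params, grads, optimizer states
    critic_zero_stage=3,
    actor_lora_dim=128,           # > 0 enables LoRA on the actor
    critic_lora_dim=0,            # 0 disables LoRA on the critic
    offload=False,                # keep optimizer states on GPU
    offload_reference_model=True,
    offload_reward_model=True,
    enable_hybrid_engine=True,
    enable_ema=False,
    actor_learning_rate=1e-5,     # placeholder values
    critic_learning_rate=5e-6,
)

lora_enabled = args.actor_lora_dim > 0
```

A namespace like this is what DeepSpeedRLHFEngine reads its per-model configuration from; any field it expects but does not find will raise an AttributeError at initialization time.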

Internal Initialization Methods

  • _init_actor() — Actor; training config; FusedAdam / DeepSpeedCPUAdam optimizer; ZeRO stage from args.actor_zero_stage
  • _init_ref() — Reference; eval-only config; no optimizer; ZeRO-3 if the actor uses ZeRO-3, else ZeRO-0
  • _init_ema() — Actor EMA; eval-only config; no optimizer; ZeRO-3 if the actor uses ZeRO-3, else ZeRO-0
  • _init_critic() — Critic; training config; FusedAdam / DeepSpeedCPUAdam optimizer; ZeRO stage from args.critic_zero_stage
  • _init_reward() — Reward; eval-only config; no optimizer; ZeRO-3 if the critic uses ZeRO-3, else ZeRO-0
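The ZeRO-stage rule for the frozen eval models can be expressed as a one-line function. This is a sketch of the rule described above, not DeepSpeed-Chat's actual code:

```python
def eval_zero_stage(train_zero_stage: int) -> int:
    """ZeRO stage for a frozen eval model (reference, reward, or EMA):
    match ZeRO-3 when the paired training model uses it, so parameter
    partitioning stays compatible; otherwise fall back to ZeRO-0, since
    a model with no optimizer gains nothing from stages 1-2.
    Sketch of the rule, not the library's implementation."""
    return 3 if train_zero_stage == 3 else 0
```

For example, an actor at ZeRO-2 pairs with a reference model at ZeRO-0, while an actor at ZeRO-3 forces the reference model to ZeRO-3 as well.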

Example Usage

from dschat.rlhf.rlhf_engine import DeepSpeedRLHFEngine

engine = DeepSpeedRLHFEngine(
    actor_model_name_or_path="/path/to/sft_checkpoint",
    critic_model_name_or_path="/path/to/reward_model_checkpoint",
    tokenizer=tokenizer,
    args=args,
    num_total_iters=1000
)

# Access the four wrapped models
actor = engine.actor        # trainable actor
ref = engine.ref            # frozen reference
critic = engine.critic      # trainable critic
reward = engine.reward      # frozen reward
