Implementation:Microsoft DeepSpeedExamples DeepSpeedRLHFEngine
Overview
Concrete tool for initializing and managing the four-model RLHF architecture provided by the DeepSpeed-Chat library.
Description
DeepSpeedRLHFEngine is the central orchestration class that creates and wraps four models using deepspeed.initialize():
- Actor — Initialized from an SFT checkpoint with full optimizer (FusedAdam or DeepSpeedCPUAdam), learning rate scheduler, optional LoRA adaptation, and optional hybrid engine for accelerated inference. Uses the actor-specific ZeRO stage.
- Reference — Initialized from the same SFT checkpoint as the actor but with an eval-only DeepSpeed configuration (no optimizer). Uses ZeRO-3 if the actor uses ZeRO-3, otherwise defaults to ZeRO-0.
- Critic — Initialized from a reward model checkpoint using create_critic_model(). Equipped with its own optimizer and learning rate scheduler. Supports LoRA and gradient checkpointing. Uses the critic-specific ZeRO stage.
- Reward — Initialized from the same reward model checkpoint as the critic but with an eval-only configuration (no optimizer). Uses ZeRO-3 if the critic uses ZeRO-3, otherwise defaults to ZeRO-0.
An optional actor EMA (Exponential Moving Average) model can also be initialized when args.enable_ema is set to True.
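The ZeRO-stage fallback described above for the eval-only models (reference, reward, EMA) can be summarized as a small sketch. This is illustrative pseudologic, not the library's exact code:

```python
# Illustrative sketch: how the frozen (eval-only) models pick their
# ZeRO stage from the trainable model they mirror. With ZeRO-3 the
# trainable model's parameters are partitioned across ranks, so the
# frozen copy must also use ZeRO-3 to share that layout; otherwise
# plain data parallelism (ZeRO-0) is enough.
def eval_zero_stage(trainable_zero_stage: int) -> int:
    return 3 if trainable_zero_stage == 3 else 0

print(eval_zero_stage(3))  # reference/reward/EMA mirror ZeRO-3
print(eval_zero_stage(2))  # falls back to ZeRO-0 for stages 0-2
```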
Code Reference
- File: applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py
- Lines: 40-296
Signature
```python
class DeepSpeedRLHFEngine:
    def __init__(
        self,
        actor_model_name_or_path,
        critic_model_name_or_path,
        tokenizer,
        args,
        num_total_iters
    ):
        ...
```
Import
```python
from dschat.rlhf.rlhf_engine import DeepSpeedRLHFEngine
```
Inputs / Outputs
Inputs
| Parameter | Type | Description |
|---|---|---|
| `actor_model_name_or_path` | `str` | Path to the SFT checkpoint or HuggingFace model name for the actor and reference models |
| `critic_model_name_or_path` | `str` | Path to the reward model checkpoint for the critic and reward models |
| `tokenizer` | `PreTrainedTokenizer` | HuggingFace tokenizer instance shared by all models |
| `args` | `Namespace` | Configuration namespace containing ZeRO stages, LoRA settings, offloading flags, learning rates, dropout, hybrid engine settings, and more |
| `num_total_iters` | `int` | Total number of training iterations (used by the learning rate scheduler) |
Outputs
The engine exposes the following attributes after initialization:
| Attribute | Type | Description |
|---|---|---|
| `.actor` | `DeepSpeedEngine` | DeepSpeed-wrapped actor model with optimizer and LR scheduler |
| `.ref` | `DeepSpeedEngine` | DeepSpeed-wrapped reference model (frozen, no optimizer) |
| `.critic` | `DeepSpeedEngine` | DeepSpeed-wrapped critic model with optimizer and LR scheduler |
| `.reward` | `DeepSpeedEngine` | DeepSpeed-wrapped reward model (frozen, no optimizer) |
| `.actor_ema` | `DeepSpeedEngine` or `None` | Optional EMA copy of the actor (enabled via `args.enable_ema`) |
Key Configuration Parameters (via args)
| Parameter | Effect |
|---|---|
| `actor_zero_stage` | ZeRO stage for the actor model (0, 1, 2, or 3) |
| `critic_zero_stage` | ZeRO stage for the critic model (0, 1, 2, or 3) |
| `actor_lora_dim` | LoRA rank for the actor; set >0 to enable LoRA |
| `critic_lora_dim` | LoRA rank for the critic; set >0 to enable LoRA |
| `offload` | Enable CPU offloading for optimizer states |
| `offload_reference_model` | Enable CPU offloading for the reference model |
| `offload_reward_model` | Enable CPU offloading for the reward model |
| `enable_hybrid_engine` | Enable the DeepSpeed hybrid engine for actor inference |
| `enable_ema` | Enable an exponential moving average of the actor weights |
| `actor_learning_rate` | Learning rate for the actor optimizer |
| `critic_learning_rate` | Learning rate for the critic optimizer |
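Since the engine reads these settings from a plain namespace, a minimal `args` object can be assembled directly. The field names below follow the table above, but the values are illustrative and the real DeepSpeed-Chat training scripts define many more options:

```python
from argparse import Namespace

# Hypothetical minimal configuration; values are examples only.
args = Namespace(
    actor_zero_stage=3,
    critic_zero_stage=3,
    actor_lora_dim=64,            # >0 enables LoRA on the actor
    critic_lora_dim=0,            # 0 disables LoRA on the critic
    offload=False,                # keep optimizer states on GPU
    offload_reference_model=True, # frozen models can live on CPU
    offload_reward_model=True,
    enable_hybrid_engine=True,    # faster actor generation
    enable_ema=False,             # skip the optional EMA copy
    actor_learning_rate=1e-5,
    critic_learning_rate=5e-6,
)
```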
Internal Initialization Methods
| Method | Model | Config Type | Optimizer | ZeRO Stage Logic |
|---|---|---|---|---|
| `_init_actor()` | Actor | Training | FusedAdam / CPUAdam | `args.actor_zero_stage` |
| `_init_ref()` | Reference | Eval-only | None | ZeRO-3 if actor is ZeRO-3, else ZeRO-0 |
| `_init_ema()` | Actor EMA | Eval-only | None | ZeRO-3 if actor is ZeRO-3, else ZeRO-0 |
| `_init_critic()` | Critic | Training | FusedAdam / CPUAdam | `args.critic_zero_stage` |
| `_init_reward()` | Reward | Eval-only | None | ZeRO-3 if critic is ZeRO-3, else ZeRO-0 |
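The constructor's overall flow can be paraphrased as a sketch: each `_init_*` method builds a DeepSpeed config and calls `deepspeed.initialize()`. The class below is a simplified stand-in, not the real implementation (which lives in `dschat/rlhf/rlhf_engine.py`):

```python
from argparse import Namespace

class RLHFEngineSketch:
    """Simplified sketch of DeepSpeedRLHFEngine's constructor flow."""

    def __init__(self, actor_path, critic_path, args):
        # Actor and reference share the SFT checkpoint path.
        self.actor = self._init("actor", actor_path, trainable=True)
        self.ref = self._init("ref", actor_path, trainable=False)
        # The EMA copy is built only when explicitly enabled.
        self.actor_ema = (
            self._init("actor_ema", actor_path, trainable=False)
            if args.enable_ema else None
        )
        # Critic and reward share the reward-model checkpoint path.
        self.critic = self._init("critic", critic_path, trainable=True)
        self.reward = self._init("reward", critic_path, trainable=False)

    def _init(self, role, path, trainable):
        # The real code builds a training or eval-only DeepSpeed config
        # here and returns a DeepSpeedEngine; we return a placeholder.
        return {"role": role, "path": path, "trainable": trainable}

engine = RLHFEngineSketch("sft_ckpt", "rm_ckpt", Namespace(enable_ema=False))
```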
Example Usage
```python
from dschat.rlhf.rlhf_engine import DeepSpeedRLHFEngine

engine = DeepSpeedRLHFEngine(
    actor_model_name_or_path="/path/to/sft_checkpoint",
    critic_model_name_or_path="/path/to/reward_model_checkpoint",
    tokenizer=tokenizer,
    args=args,
    num_total_iters=1000
)

# Access the four wrapped models
actor = engine.actor    # trainable actor
ref = engine.ref        # frozen reference
critic = engine.critic  # trainable critic
reward = engine.reward  # frozen reward
```