Implementation: Microsoft DeepSpeedExamples Create Critic Model
Overview
A concrete utility, provided by the DeepSpeed-Chat library, for creating reward/critic models with a linear value head.
Description
`create_critic_model` wraps a base language model with a `RewardModel` class that adds a `v_head` linear layer (hidden_size -> 1) for scalar reward prediction. It handles OPT model padding quirks, loads from reward checkpoints when `rlhf_training=True`, and supports ZeRO-3 parameter gathering.
The function operates in two modes:
- Training mode (`rlhf_training=False`): creates a fresh `RewardModel` by loading a pre-trained base model via `create_hf_model` with `AutoModel`, then wrapping it with the `RewardModel` class. The value head (`v_head`) is initialized with random weights. This mode is used in Step 2 (Reward Model Training) to train the reward model from scratch on human preference data.
- Checkpoint loading mode (`rlhf_training=True`): creates the model architecture without loading pre-trained weights (using `no_init_weights`), then loads the full state dictionary from a previously saved reward model checkpoint (`pytorch_model.bin`). This mode is used in Step 3 (RLHF with PPO) to load a trained reward model for the critic and reward components of the RLHF engine.
The underlying `RewardModel` class:
- Adds a `v_head = nn.Linear(hidden_size, 1, bias=False)` on top of the base transformer.
- Handles OPT models specially by using `word_embed_proj_dim` instead of `hidden_size`.
- Implements a `forward` method for training that computes pairwise ranking loss.
- Implements a `forward_value` method for inference that returns per-token value estimates.
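The pairwise ranking loss computed by `forward` follows the standard Bradley-Terry style objective, `-log(sigmoid(score_chosen - score_rejected))`, where the scores come from `v_head`. A minimal pure-Python sketch of the objective (illustrative only; the library's version operates on per-token tensors with OPT padding handling):

```python
import math

def pairwise_ranking_loss(chosen_scores, rejected_scores):
    """Mean of -log(sigmoid(chosen - rejected)) over a batch of score pairs.

    Illustrative sketch of the reward-model ranking objective; the real
    RewardModel.forward works on per-token tensors, not Python floats.
    """
    losses = [
        -math.log(1.0 / (1.0 + math.exp(-(c - r))))
        for c, r in zip(chosen_scores, rejected_scores)
    ]
    return sum(losses) / len(losses)

# When chosen and rejected score equally, the loss is log(2) ~= 0.693
print(round(pairwise_ranking_loss([0.0], [0.0]), 3))
```

Minimizing this loss pushes the chosen response's score above the rejected one's, which is exactly what Step 2 trains the model to do on preference pairs.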
Usage
Use in Step 2 (Reward Model Training) and Step 3 (RLHF Engine) where critic and reward models are needed.
- In Step 2, call with `rlhf_training=False` to create a fresh reward model for training on preference data.
- In Step 3, call with `rlhf_training=True` to load a trained checkpoint for both the critic model (which is further fine-tuned) and the frozen reward model (which provides the reward signal).
Code Reference
Source: Repository: DeepSpeedExamples, File: applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py
Signature:

```python
def create_critic_model(
    model_name_or_path,
    tokenizer,
    ds_config,
    num_padding_at_beginning=0,
    rlhf_training=False,
    dropout=None,
    zero_stage=0,
    compute_fp32_loss=False,
) -> RewardModel:
    ...
```

Import:

```python
from dschat.utils.model.model_utils import create_critic_model
```
I/O Contract
Inputs
| Parameter | Type | Required | Description |
|---|---|---|---|
| `model_name_or_path` | `str` | Yes | Base model name/path (HuggingFace hub ID or local path) or reward checkpoint path |
| `tokenizer` | `AutoTokenizer` | Yes | Tokenizer instance matching the base model |
| `ds_config` | `dict` | Yes | DeepSpeed configuration dictionary (controls ZeRO stage, offloading, etc.) |
| `num_padding_at_beginning` | `int` | No | Number of padding tokens at the beginning of the sequence (default: 0). Set to 1 for the OPT model family. |
| `rlhf_training` | `bool` | No | If `True`, loads weights from a previously saved reward model checkpoint (default: `False`) |
| `dropout` | `float` | No | Override dropout rate for the model (default: `None`, which keeps the model config's value) |
| `zero_stage` | `int` | No | ZeRO optimization stage, for checkpoint-loading compatibility (default: 0) |
| `compute_fp32_loss` | `bool` | No | Whether to compute the ranking loss in FP32 precision (default: `False`) |
Outputs
| Name | Type | Description |
|---|---|---|
| `model` | `RewardModel` | Base transformer model wrapped with a `v_head` linear layer (hidden_size -> 1) for scalar reward prediction |
Usage Examples
Example 1: Creating a Reward Model for Training (Step 2)
```python
from transformers import AutoTokenizer

from dschat.utils.model.model_utils import create_critic_model

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
tokenizer.pad_token = tokenizer.eos_token

ds_config = {
    "zero_optimization": {"stage": 0},
    "train_micro_batch_size_per_gpu": 4,
    "train_batch_size": 16,
}

# Create a fresh reward model for training on preference data
reward_model = create_critic_model(
    model_name_or_path="facebook/opt-350m",
    tokenizer=tokenizer,
    ds_config=ds_config,
    num_padding_at_beginning=1,  # OPT models pad the first token
    rlhf_training=False,
    compute_fp32_loss=False,
)

# reward_model is a RewardModel with a randomly initialized v_head.
# Train it on chosen/rejected pairs using the pairwise ranking loss.
```
Example 2: Loading a Reward Checkpoint for RLHF Engine (Step 3)
```python
from transformers import AutoTokenizer

from dschat.utils.model.model_utils import create_critic_model

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
tokenizer.pad_token = tokenizer.eos_token

ds_config = {
    "zero_optimization": {"stage": 0},
    "train_micro_batch_size_per_gpu": 4,
    "train_batch_size": 16,
}

# Load a previously trained reward model checkpoint for the RLHF engine
critic_model = create_critic_model(
    model_name_or_path="/path/to/step2_reward_model_checkpoint",
    tokenizer=tokenizer,
    ds_config=ds_config,
    num_padding_at_beginning=1,
    rlhf_training=True,  # Load from checkpoint
    zero_stage=0,        # Match the ZeRO stage used during saving
)

# critic_model has its v_head and transformer weights loaded from the
# checkpoint, ready to serve as the critic or reward model in
# DeepSpeedRLHFEngine.
```
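For intuition about what the reward model returns at inference time: `forward_value` produces one value estimate per token, and downstream code reduces that sequence to a single scalar reward taken at the end of the response. A hypothetical pure-Python sketch of that reduction (the helper name and list-based types are illustrative, not part of dschat, which does this on tensors with its own padding conventions):

```python
def reward_at_last_token(values, input_ids, pad_token_id):
    """Pick the value at the last non-padding position.

    Hypothetical helper mirroring how per-token outputs from a
    forward_value-style call are reduced to one scalar reward.
    """
    last = max(i for i, tok in enumerate(input_ids) if tok != pad_token_id)
    return values[last]

# Tail padding (token id 0 here) is ignored; the reward is the value
# at the last real token.
print(reward_at_last_token([0.1, 0.2, 0.7, 0.0], [5, 6, 7, 0], pad_token_id=0))
```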