Implementation:Microsoft DeepSpeedExamples DeepSpeedPPOTrainer
Overview
A concrete tool from the DeepSpeed-Chat library for executing PPO training, covering both experience generation and policy optimization.
Description
DeepSpeedPPOTrainer implements the full PPO training loop for RLHF. It wraps the four models from a DeepSpeedRLHFEngine and provides methods for:
- Experience generation — The actor generates text, and all four models compute log-probabilities, values, and rewards.
- Policy optimization — Actor and critic losses are computed with clipped objectives, gradients are backpropagated, and optimizer steps are taken.
- Reward computation — Per-token rewards are computed by combining KL divergence penalties with the reward model score.
The trainer also handles invalid generations (sequences too short) by falling back to the last valid experience, and supports overflow alignment to synchronize gradient overflow between the actor and critic optimizers.
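The reward combination described above can be sketched as follows. This is a minimal standalone approximation, not the library's exact code: the function name is hypothetical, the defaults mirror the instance attributes listed later (kl_ctl, clip_reward_value), and it assumes the tensors cover only the generated (action) tokens with no padding.

```python
import torch

def compute_rewards_sketch(log_probs, ref_log_probs, reward_score,
                           kl_ctl=0.1, clip_reward_value=5.0):
    """Per-token reward: a KL penalty against the reference model at every
    token, plus the clipped reward-model score added at the final token."""
    # KL penalty term: -beta * (log pi(a|s) - log pi_ref(a|s))
    rewards = -kl_ctl * (log_probs - ref_log_probs)
    # Clip the scalar reward-model score to a bounded range
    clipped_score = torch.clamp(reward_score, -clip_reward_value, clip_reward_value)
    # Add the score at the last generated token of each sequence
    rewards[:, -1] += clipped_score
    return rewards
```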
Code Reference
- File: applications/DeepSpeed-Chat/dschat/rlhf/ppo_trainer.py
- Lines: 44-371
Signature
```python
class DeepSpeedPPOTrainer:

    def __init__(self, rlhf_engine, args):
        ...

    def generate_experience(self, prompts, mask, step):
        """Generate text and compute logprobs/values/rewards from all 4 models."""
        ...

    def train_rlhf(self, inputs):
        """Compute actor and critic losses and perform optimization steps."""
        ...

    def compute_rewards(self, prompts, log_probs, ref_log_probs, reward_score, action_mask):
        """Compute per-token rewards with KL penalty and clipped reward score."""
        ...

    def get_advantages_and_returns(self, values, rewards, start):
        """Compute GAE advantages and returns."""
        ...

    def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):
        """Clipped surrogate policy gradient loss."""
        ...

    def critic_loss_fn(self, values, old_values, returns, mask):
        """Clipped value function loss."""
        ...
```
Import
```python
from dschat.rlhf.ppo_trainer import DeepSpeedPPOTrainer
```
Inputs / Outputs
Constructor Inputs
| Parameter | Type | Description |
|---|---|---|
| rlhf_engine | DeepSpeedRLHFEngine | Engine containing the four initialized models (.actor, .ref, .critic, .reward) |
| args | Namespace | Configuration namespace containing max_answer_seq_len, end_of_conversation_token, actor_zero_stage, compute_fp32_loss, print_answers, align_overflow, and more |
Key Method: generate_experience
| Parameter | Type | Description |
|---|---|---|
| prompts | torch.Tensor | Tokenized prompt batch (batch_size x prompt_length) |
| mask | torch.Tensor | Attention mask for prompts |
| step | int | Current training step (used for logging) |
Returns a dictionary with keys:
| Key | Type | Description |
|---|---|---|
| prompts | torch.Tensor | Original prompt tensor |
| logprobs | torch.Tensor | Actor log-probabilities for generated tokens |
| ref_logprobs | torch.Tensor | Reference model log-probabilities for generated tokens |
| value | torch.Tensor | Critic value estimates per token |
| rewards | torch.Tensor | Reward model scores per sequence |
| input_ids | torch.Tensor | Full sequence (prompt + generated tokens) |
| attention_mask | torch.Tensor | Attention mask for full sequence |
Key Method: train_rlhf
| Parameter | Type | Description |
|---|---|---|
| inputs | dict | Experience dictionary returned by generate_experience() |
Returns: (actor_loss, critic_loss) — a tuple of scalar tensors.
Internal Hyperparameters
These are set as instance attributes in the constructor:
| Attribute | Default | Description |
|---|---|---|
| self.kl_ctl | 0.1 | KL penalty coefficient (beta) |
| self.clip_reward_value | 5 | Maximum absolute reward value |
| self.cliprange | 0.2 | Actor loss clipping parameter (epsilon) |
| self.cliprange_value | 0.2 | Critic value loss clipping parameter |
| self.gamma | 1.0 | Discount factor |
| self.lam | 0.95 | GAE lambda parameter |
Training Flow
The following sequence describes a single PPO training iteration:
```python
# 1. Generate experience (all models in eval mode)
exp = trainer.generate_experience(prompts, mask, step)

# 2. Split experience into mini-batches
mini_dataset = MiniDataset(max_mini_dataset_size, mini_batch_size)
for mini_batch in mini_dataset.add(exp):
    # 3. Compute losses and update actor + critic
    actor_loss, critic_loss = trainer.train_rlhf(mini_batch)
```
Detailed Breakdown
- generate_experience() switches all models to eval mode, generates sequences via actor.module.generate(), then computes:
  - Actor log-probabilities via a forward pass on the actor
  - Reference log-probabilities via a forward pass on the reference model
  - Reward scores via reward_model.forward_value()
  - Value estimates via critic_model.forward_value()
- train_rlhf() switches the actor and critic to train mode and:
  - Computes per-token rewards with KL penalty via compute_rewards()
  - Computes GAE advantages and returns via get_advantages_and_returns()
  - Runs an actor forward pass to get updated log-probabilities
  - Computes the clipped surrogate actor loss via actor_loss_fn()
  - Calls actor_model.backward() and actor_model.step()
  - Runs a critic forward pass to get updated value estimates
  - Computes the clipped value critic loss via critic_loss_fn()
  - Calls critic_model.backward() and critic_model.step()
- Optional overflow alignment synchronizes skip decisions between the actor and critic when one experiences gradient overflow.
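The two clipped objectives named in the breakdown can be sketched as standalone approximations of actor_loss_fn() and critic_loss_fn(). The function names are hypothetical, the clipping defaults come from the hyperparameter table, and mask is assumed to mark the valid action tokens.

```python
import torch

def actor_loss_sketch(logprobs, old_logprobs, advantages, mask, cliprange=0.2):
    """PPO clipped surrogate loss (to be minimized)."""
    # Importance ratio between the updated and experience-time policies
    ratio = torch.exp(logprobs - old_logprobs)
    loss1 = -advantages * ratio
    loss2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
    # Pessimistic (max) combination, averaged over valid tokens
    return (torch.max(loss1, loss2) * mask).sum() / mask.sum()

def critic_loss_sketch(values, old_values, returns, mask, cliprange_value=0.2):
    """Clipped value-function loss: worst case of clipped vs. unclipped error."""
    values_clipped = torch.clamp(values,
                                 old_values - cliprange_value,
                                 old_values + cliprange_value)
    vf_loss1 = (values - returns) ** 2
    vf_loss2 = (values_clipped - returns) ** 2
    return 0.5 * (torch.max(vf_loss1, vf_loss2) * mask).sum() / mask.sum()
```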
Subclass: DeepSpeedPPOTrainerUnsupervised
A subclass DeepSpeedPPOTrainerUnsupervised (lines 357-371) extends the base trainer with a train_unsupervised() method that applies an additional language-modeling loss on unsupervised data, scaled by a coefficient unsup_coef.
```python
class DeepSpeedPPOTrainerUnsupervised(DeepSpeedPPOTrainer):

    def train_unsupervised(self, inputs, unsup_coef):
        outputs = self.actor_model(**inputs, use_cache=False)
        loss = outputs.loss
        self.actor_model.backward(unsup_coef * loss)
        self.actor_model.step()
        return loss
```
Related
- Principle:Microsoft_DeepSpeedExamples_PPO_Training
- Environment:Microsoft_DeepSpeedExamples_RLHF_Training_Environment
- Heuristic:Microsoft_DeepSpeedExamples_RLHF_Stability_Constraints
- Heuristic:Microsoft_DeepSpeedExamples_RLHF_Hyperparameter_Guide
- Heuristic:Microsoft_DeepSpeedExamples_Gradient_Checkpointing_Tradeoff