
Implementation:Microsoft DeepSpeedExamples DeepSpeedPPOTrainer

From Leeroopedia


Overview

A concrete tool, provided by the DeepSpeed-Chat library, for executing the PPO training loop: generating experience and optimizing the policy.

Description

DeepSpeedPPOTrainer implements the full PPO training loop for RLHF. It wraps the four models from a DeepSpeedRLHFEngine and provides methods for:

  • Experience generation — The actor generates text, and all four models compute log-probabilities, values, and rewards.
  • Policy optimization — Actor and critic losses are computed with clipped objectives, gradients are backpropagated, and optimizer steps are taken.
  • Reward computation — Per-token rewards are computed by combining KL divergence penalties with the reward model score.

The trainer also handles invalid generations (sequences too short) by falling back to the last valid experience, and supports overflow alignment to synchronize gradient overflow between the actor and critic optimizers.
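The per-token reward combination described above can be sketched in plain Python (an illustrative simplification of the library's tensor implementation; the function name and list-based inputs are hypothetical):

```python
def compute_rewards_sketch(log_probs, ref_log_probs, reward_score,
                           kl_ctl=0.1, clip_reward_value=5.0):
    """Per-token reward: KL penalty everywhere, clipped reward score at the last token."""
    # KL penalty per generated token: -beta * (log pi(a) - log pi_ref(a))
    rewards = [-kl_ctl * (lp - rlp) for lp, rlp in zip(log_probs, ref_log_probs)]
    # Clip the sequence-level reward model score to [-clip_reward_value, +clip_reward_value]
    clipped_score = max(-clip_reward_value, min(clip_reward_value, reward_score))
    # The score is added only at the final generated token
    rewards[-1] += clipped_score
    return rewards
```

Note that the reward model produces one score per sequence, so only the last generated token receives it; every other token is shaped purely by the KL penalty.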

Code Reference

  • File: applications/DeepSpeed-Chat/dschat/rlhf/ppo_trainer.py
  • Lines: 44-371

Signature

class DeepSpeedPPOTrainer:
    def __init__(self, rlhf_engine, args):
        ...

    def generate_experience(self, prompts, mask, step):
        """Generate text and compute logprobs/values/rewards from all 4 models."""
        ...

    def train_rlhf(self, inputs):
        """Compute actor and critic losses and perform optimization steps."""
        ...

    def compute_rewards(self, prompts, log_probs, ref_log_probs, reward_score, action_mask):
        """Compute per-token rewards with KL penalty and clipped reward score."""
        ...

    def get_advantages_and_returns(self, values, rewards, start):
        """Compute GAE advantages and returns."""
        ...

    def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):
        """Clipped surrogate policy gradient loss."""
        ...

    def critic_loss_fn(self, values, old_values, returns, mask):
        """Clipped value function loss."""
        ...

Import

from dschat.rlhf.ppo_trainer import DeepSpeedPPOTrainer

Inputs / Outputs

Constructor Inputs

  • rlhf_engine (DeepSpeedRLHFEngine): Engine containing the four initialized models (.actor, .ref, .critic, .reward)
  • args (Namespace): Configuration namespace containing max_answer_seq_len, end_of_conversation_token, actor_zero_stage, compute_fp32_loss, print_answers, align_overflow, and more

Key Method: generate_experience

  • prompts (torch.Tensor): Tokenized prompt batch (batch_size x prompt_length)
  • mask (torch.Tensor): Attention mask for the prompts
  • step (int): Current training step (used for logging)

Returns a dictionary with keys:

  • prompts (torch.Tensor): Original prompt tensor
  • logprobs (torch.Tensor): Actor log-probabilities for generated tokens
  • ref_logprobs (torch.Tensor): Reference model log-probabilities for generated tokens
  • value (torch.Tensor): Critic value estimates per token
  • rewards (torch.Tensor): Reward model scores per sequence
  • input_ids (torch.Tensor): Full sequence (prompt + generated tokens)
  • attention_mask (torch.Tensor): Attention mask for the full sequence

Key Method: train_rlhf

  • inputs (dict): Experience dictionary returned by generate_experience()

Returns: (actor_loss, critic_loss) — a tuple of scalar tensors.

Internal Hyperparameters

These are set as instance attributes in the constructor:

  • self.kl_ctl = 0.1: KL penalty coefficient (beta)
  • self.clip_reward_value = 5: Maximum absolute reward-model score
  • self.cliprange = 0.2: Actor loss clipping parameter (epsilon)
  • self.cliprange_value = 0.2: Critic value loss clipping parameter
  • self.gamma = 1.0: Discount factor
  • self.lam = 0.95: GAE lambda parameter
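Under these defaults, the GAE recursion used by get_advantages_and_returns() can be sketched in plain Python (a scalar illustration under the stated defaults, not the library's batched tensor code; the function name is hypothetical):

```python
def gae_sketch(values, rewards, gamma=1.0, lam=0.95):
    """Compute GAE advantages and returns over one sequence of per-token values/rewards."""
    advantages = []
    last_gae = 0.0
    # Walk backwards: A_t = delta_t + gamma * lam * A_{t+1},
    # where delta_t = r_t + gamma * V_{t+1} - V_t
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(values) else 0.0
        delta = rewards[t] + gamma * next_value - values[t]
        last_gae = delta + gamma * lam * last_gae
        advantages.append(last_gae)
    advantages.reverse()
    # Returns are the critic's regression targets: advantage + value
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns
```

With gamma = 1.0, the recursion does no temporal discounting; lam = 0.95 still trades off bias against variance in the advantage estimate.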

Training Flow

The following sequence describes a single PPO training iteration:

# 1. Generate experience (all models in eval mode)
exp = trainer.generate_experience(prompts, mask, step)

# 2. Split experience into mini-batches
mini_dataset = MiniDataset(max_mini_dataset_size, mini_batch_size)
for mini_batch in mini_dataset.add(exp):
    # 3. Compute losses and update actor + critic
    actor_loss, critic_loss = trainer.train_rlhf(mini_batch)

Detailed Breakdown

  1. generate_experience() switches all models to eval mode, generates sequences via actor.module.generate(), then computes:
    • Actor log-probabilities via forward pass on actor
    • Reference log-probabilities via forward pass on reference
    • Reward scores via reward_model.forward_value()
    • Value estimates via critic_model.forward_value()
  2. train_rlhf() switches actor and critic to train mode and:
    • Computes per-token rewards with KL penalty via compute_rewards()
    • Computes GAE advantages and returns via get_advantages_and_returns()
    • Runs actor forward pass to get updated log-probabilities
    • Computes clipped surrogate actor loss via actor_loss_fn()
    • Calls actor_model.backward() and actor_model.step()
    • Runs critic forward pass to get updated value estimates
    • Computes clipped value critic loss via critic_loss_fn()
    • Calls critic_model.backward() and critic_model.step()
  3. Optional overflow alignment synchronizes skip decisions between actor and critic when one experiences gradient overflow.
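The two clipped losses in step 2 can be sketched per token in plain Python (illustrative scalar versions of the clipped objectives named above; the function names and scalar signatures are hypothetical, the real methods operate on masked tensors):

```python
import math

def actor_loss_sketch(logprob, old_logprob, advantage, cliprange=0.2):
    """Clipped surrogate policy gradient loss for a single token."""
    ratio = math.exp(logprob - old_logprob)  # pi_new / pi_old
    loss_unclipped = -advantage * ratio
    loss_clipped = -advantage * max(min(ratio, 1 + cliprange), 1 - cliprange)
    # Taking the max of the two negated terms is the pessimistic (PPO) choice
    return max(loss_unclipped, loss_clipped)

def critic_loss_sketch(value, old_value, ret, cliprange_value=0.2)	:
    """Clipped value function loss for a single token."""
    # Keep the new value estimate within cliprange_value of the old one
    clipped_value = max(min(value, old_value + cliprange_value),
                        old_value - cliprange_value)
    return 0.5 * max((value - ret) ** 2, (clipped_value - ret) ** 2)
```

For example, with a positive advantage and a ratio of 2.0, the clipped branch caps the objective at 1 + cliprange = 1.2, so the update cannot profit from moving the policy further than the clip range allows.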

Subclass: DeepSpeedPPOTrainerUnsupervised

A subclass DeepSpeedPPOTrainerUnsupervised (lines 357-371) extends the base trainer with a train_unsupervised() method that computes a standard language modeling loss on unsupervised data, scales it by a coefficient unsup_coef, and takes an optimizer step on the actor.

class DeepSpeedPPOTrainerUnsupervised(DeepSpeedPPOTrainer):
    def train_unsupervised(self, inputs, unsup_coef):
        outputs = self.actor_model(**inputs, use_cache=False)
        loss = outputs.loss
        self.actor_model.backward(unsup_coef * loss)
        self.actor_model.step()
        return loss
