Overview
Description
The PPOTrainer.train method implements the complete PPO training loop for RLHF. Each iteration performs: (1) rollout generation via the policy model, (2) KL-penalized reward computation, (3) GAE advantage estimation, and (4) multiple PPO optimization epochs with clipped policy and value losses. The method manages memory aggressively by clearing caches between phases.
Usage
Called on an initialized PPOTrainer instance. The method runs for num_total_batches iterations, logging metrics and optionally generating sample completions at regular intervals.
Code Reference
Source Location
trl/experimental/ppo/ppo_trainer.py lines 596-940
Signature
def train(self) -> None:
"""
Main PPO training loop.
Phase structure per iteration:
- Lines 664-745: Rollout generation (response generation, logprob computation,
reward scoring, value estimation)
- Lines 765-771: KL-penalized reward computation
- Lines 779-791: GAE advantage estimation
- Lines 794-872: PPO optimization epochs (clipped policy + value loss)
- Lines 873-910: Metrics logging and checkpoint saving
"""
Key Internal Functions
def generate(lm_backbone, queries, pad_token_id, generation_config):
"""Generate responses from the policy model."""
# Returns (query_responses, logits)
def batch_generation(model, queries, batch_size, pad_token_id, generation_config):
"""Generate in sub-batches to manage memory."""
# Returns (padded_query_responses, padded_logitss)
def forward(model, query_responses, pad_token_id):
"""Forward pass with proper attention mask and position IDs."""
# Returns ModelOutput with hidden states
def truncate_response(stop_token_id, pad_token_id, responses):
"""Truncate responses at stop token, fill rest with pad."""
# Returns truncated response tensor
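As a rough illustration of the truncation step, a version of `truncate_response` could be sketched as below, assuming batched responses of shape `(B, response_len)`; the actual TRL helper differs in implementation detail but keeps the stop token and pads everything after it:

```python
import torch

def truncate_response(stop_token_id, pad_token_id, responses):
    """Keep each response up to (and including) its first stop token,
    replacing everything after it with padding. Illustrative sketch."""
    truncated = responses.clone()
    is_stop = (responses == stop_token_id).int()
    # cumsum > 0 from the first stop token onward; subtracting is_stop
    # excludes the stop token itself, so it is kept in the output
    after_stop = (is_stop.cumsum(dim=1) - is_stop) > 0
    truncated[after_stop] = pad_token_id
    return truncated
```

Responses with no stop token pass through unchanged.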
Import
# train() is a method on PPOTrainer, invoked as:
from trl.experimental.ppo import PPOTrainer
trainer = PPOTrainer(...)
trainer.train()
I/O Contract
Phase 1: Rollout Generation
| Input |
Type |
Description
|
| queries |
torch.Tensor |
Batch of tokenized prompts from the dataloader
|
| generation_config |
GenerationConfig |
Sampling config (temperature, max_new_tokens, top_k=0, top_p=1)
|
| Output |
Type |
Shape |
Description
|
| query_responses |
torch.Tensor |
(B, context_len + response_len) |
Concatenated prompt + generated response
|
| responses |
torch.Tensor |
(B, response_len) |
Generated response tokens only
|
| logprobs |
torch.Tensor |
(B, response_len) |
Log-probabilities under current policy
|
| ref_logprobs |
torch.Tensor |
(B, response_len) |
Log-probabilities under reference policy
|
| scores |
torch.Tensor |
(B,) |
Reward model scores for postprocessed responses
|
| values |
torch.Tensor |
(B, response_len) |
Value model estimates at each token position
|
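The logprobs and ref_logprobs outputs above come from gathering the sampled tokens' log-probabilities out of each model's logits. A minimal sketch of that gather step (the function name and temperature handling here are illustrative assumptions, not the exact TRL code):

```python
import torch

def response_logprobs(logits, responses, temperature=1.0):
    """Per-token log-probabilities of the sampled response tokens under a
    model's logits of shape (B, response_len, vocab_size). Illustrative sketch."""
    logits = logits / temperature
    all_logprobs = torch.log_softmax(logits, dim=-1)
    # pick out the log-probability of each actually-generated token
    return torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1)
```

Running the same function over the reference model's logits yields `ref_logprobs`.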
Phase 2: KL-Penalized Reward
| Computation | Formula |
| --- | --- |
| KL divergence (k1) | kl = -(ref_logprobs - logprobs) |
| KL divergence (k3) | logr = ref_logprobs - logprobs; kl = (exp(logr) - 1) - logr |
| Non-score reward | non_score_reward = -kl_coef * kl |
| Final reward | rewards = non_score_reward; rewards[last_token] += score |
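The formulas above can be sketched as a single function. Note the hedges: TRL adds the score at the last non-padded token of each response, while this sketch simply uses the final position, and the function name and defaults are illustrative:

```python
import torch

def kl_penalized_rewards(logprobs, ref_logprobs, scores, kl_coef=0.05, kl_estimator="k1"):
    """Per-token rewards: a KL penalty at every token plus the sequence-level
    reward-model score added at the last token. Illustrative sketch."""
    if kl_estimator == "k1":
        kl = -(ref_logprobs - logprobs)
    else:  # "k3": lower-variance, always non-negative estimator
        logr = ref_logprobs - logprobs
        kl = (torch.exp(logr) - 1) - logr
    rewards = -kl_coef * kl          # non-score reward at every token
    rewards[:, -1] += scores         # add the score at the final position
    return rewards
```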
Phase 3: GAE
| Computation | Formula |
| --- | --- |
| TD error | delta[t] = rewards[t] + gamma * values[t+1] - values[t] |
| GAE advantage | A[t] = delta[t] + gamma * lam * A[t+1] |
| Returns | returns = advantages + values |
| Whitened advantages | advantages = whiten(advantages, mask=~padding_mask) |
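The backward recursion above can be sketched as follows, assuming no padding mask and a value of zero beyond the last token (TRL additionally whitens the advantages afterwards; this sketch omits that step):

```python
import torch

def compute_gae(rewards, values, gamma=1.0, lam=0.95):
    """Generalized Advantage Estimation over the response dimension,
    iterating backwards from the last token. Illustrative sketch."""
    B, T = rewards.shape
    advantages = torch.zeros_like(rewards)
    last_gae = torch.zeros(B, device=rewards.device)
    for t in reversed(range(T)):
        # bootstrap value is zero past the final token
        next_values = values[:, t + 1] if t < T - 1 else torch.zeros(B, device=rewards.device)
        delta = rewards[:, t] + gamma * next_values - values[:, t]   # TD error
        last_gae = delta + gamma * lam * last_gae                    # GAE recursion
        advantages[:, t] = last_gae
    returns = advantages + values
    return advantages, returns
```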
Phase 4: PPO Optimization
| Loss Component | Formula |
| --- | --- |
| Importance ratio | ratio = exp(new_logprob - old_logprob) |
| Clipped policy loss | max(-adv * ratio, -adv * clip(ratio, 1 - eps, 1 + eps)) |
| Clipped value loss | 0.5 * max((vpred - ret)^2, (vpred_clipped - ret)^2), where vpred_clipped clamps vpred to old_value ± cliprange_value |
| Total loss | policy_loss + vf_coef * value_loss |
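Putting the loss table together, a minimal sketch of the per-minibatch loss computation (function name and default hyperparameters are illustrative assumptions, not TRL's exact code path):

```python
import torch

def ppo_losses(new_logprobs, old_logprobs, advantages, vpreds, old_values, returns,
               cliprange=0.2, cliprange_value=0.2, vf_coef=0.1):
    """Clipped PPO policy loss + clipped value loss. Illustrative sketch."""
    ratio = torch.exp(new_logprobs - old_logprobs)
    # pessimistic (max) of unclipped and clipped policy objectives
    pg_losses1 = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1 - cliprange, 1 + cliprange)
    policy_loss = torch.max(pg_losses1, pg_losses2).mean()
    # value predictions clamped to a trust region around the old values
    vpred_clipped = torch.clamp(vpreds, old_values - cliprange_value, old_values + cliprange_value)
    value_loss = 0.5 * torch.max((vpreds - returns) ** 2,
                                 (vpred_clipped - returns) ** 2).mean()
    return policy_loss + vf_coef * value_loss, policy_loss, value_loss
```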
Logged Metrics
| Metric Key | Description |
| --- | --- |
| objective/kl | Mean KL divergence from the reference policy |
| objective/entropy | Mean policy entropy |
| objective/non_score_reward | Mean KL penalty |
| objective/rlhf_reward | Total reward (non_score + score) |
| objective/scores | Mean reward model scores |
| policy/approxkl_avg | Approximate KL between old and new policy |
| policy/clipfrac_avg | Fraction of clipped policy ratios |
| loss/policy_avg | Mean policy loss |
| loss/value_avg | Mean value loss |
| val/clipfrac_avg | Fraction of clipped value predictions |
| val/ratio | Mean importance sampling ratio |
| val/ratio_var | Variance of the importance sampling ratio |
| val/num_eos_tokens | Count of generated EOS tokens |
| lr | Current learning rate |
| episode | Current episode count |
Usage Examples
Basic Training
from trl.experimental.ppo import PPOTrainer, PPOConfig
# Initialize the trainer with policy, reference, reward, and value models
trainer = PPOTrainer(
args=config,
processing_class=tokenizer,
model=policy,
ref_model=ref_policy,
reward_model=reward_model,
value_model=value_model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
# Run the full PPO training loop
trainer.train()
# Save the trained policy
trainer.save_model("ppo-trained-policy")
With DeepSpeed ZeRO-3
accelerate launch --config_file deepspeed_zero3.yaml \
examples/scripts/ppo/ppo.py \
--num_ppo_epochs 1 \
--num_mini_batches 1 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--total_episodes 10000 \
--local_rollout_forward_batch_size 1
Related Pages