Principle:Huggingface Trl PPO Training Loop
| Property | Value |
|---|---|
| Principle Name | PPO Training Loop |
| Technology | Huggingface TRL |
| Category | Training Algorithm |
| Workflow | PPO RLHF Training |
| Paper | PPO (https://arxiv.org/abs/1707.06347), GAE (https://arxiv.org/abs/1506.02438) |
| Implementation | Implementation:Huggingface_Trl_PPOTrainer_Train |
Overview
Description
The PPO training loop implements the complete Proximal Policy Optimization algorithm for RLHF. Each iteration consists of four phases: (1) rollout generation where the policy produces responses, (2) reward computation with KL penalties, (3) advantage estimation using GAE, and (4) multiple epochs of policy and value function optimization with clipped objectives.
This is the most computationally intensive part of the RLHF pipeline, requiring careful memory management with explicit cache clearing between phases.
Usage
The training loop is invoked by calling trainer.train() on an initialized PPOTrainer instance. The loop runs for num_total_batches iterations, with each iteration processing a full batch of prompts through the complete PPO pipeline.
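As an illustration, the hyperparameters referenced throughout this page map onto fields of TRL's PPOConfig. The sketch below is a hedged configuration fragment, not an authoritative recipe: the field names follow recent TRL releases and may differ between versions, and the values shown are placeholders, not recommended defaults.

```python
from trl import PPOConfig

# Hypothetical configuration sketch; field names follow recent TRL releases
# and may vary between versions. Values are illustrative placeholders.
config = PPOConfig(
    num_ppo_epochs=4,                      # optimization passes per rollout batch (Phase 4)
    local_rollout_forward_batch_size=64,   # sub-batch size for generation (Phase 1)
    kl_coef=0.05,                          # per-token KL penalty coefficient (Phase 2)
    gamma=1.0,                             # discount factor for GAE (Phase 3)
    lam=0.95,                              # GAE lambda (Phase 3)
    cliprange=0.2,                         # policy clipping epsilon (Phase 4)
    cliprange_value=0.2,                   # value clipping epsilon_v (Phase 4)
    vf_coef=0.1,                           # weight of the value loss in the total loss
    missing_eos_penalty=1.0,               # penalty for responses lacking an EOS token
)
```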
Theoretical Basis
Phase 1: Rollout Generation
In the rollout phase, the policy model generates responses for a batch of prompts. Generation uses sampling (not greedy) with the configured temperature and top-k/top-p settings. Key steps:
- Batch generation: Queries are processed in sub-batches of size local_rollout_forward_batch_size to manage memory.
- Log-probability computation: For each generated token, the log-probability under the current policy and reference policy is computed.
- Response truncation: Responses are truncated at the first occurrence of the stop token, with remaining positions filled with pad tokens.
- Reward scoring: The reward model scores the truncated responses.
- Value estimation: The value model estimates state values at each token position.
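The truncation step above can be sketched in plain Python. This is an illustrative helper (the function name and single-sequence interface are assumptions; TRL's actual implementation operates on batched tensors):

```python
def truncate_response(tokens, stop_token_id, pad_token_id):
    """Truncate at the first stop token (keeping it) and pad to the
    original length. Illustrative single-sequence sketch."""
    out = list(tokens)
    for i, tok in enumerate(out):
        if tok == stop_token_id:
            # Keep the stop token; replace everything after it with padding.
            return out[: i + 1] + [pad_token_id] * (len(out) - i - 1)
    return out  # no stop token generated: response left untruncated


# Example: stop token id 2, pad token id 0
truncated = truncate_response([5, 7, 2, 9, 4], stop_token_id=2, pad_token_id=0)
# truncated == [5, 7, 2, 0, 0]
```

Keeping the sequence length fixed (padding instead of shortening) lets downstream tensors stay rectangular across the batch.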
Phase 2: KL-Penalized Rewards
The raw reward from the reward model is augmented with a per-token KL divergence penalty:
reward_total[t] = -kl_coef * KL[t] (for all tokens)
reward_total[last_token] += score (add the sequence-level reward at the final position)
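The two lines above can be sketched for a single response (pure Python; the function name is an assumption, and TRL computes this on batched, masked tensors with per-token KL values already in hand):

```python
def kl_penalized_rewards(kl, score, kl_coef):
    """Build the per-token reward: -kl_coef * KL at every position, plus
    the sequence-level reward-model score at the final response token.
    Illustrative single-sequence sketch."""
    rewards = [-kl_coef * k for k in kl]
    rewards[-1] += score  # sequence-level score lands on the last token
    return rewards


# Example: three response tokens, reward-model score 1.0, kl_coef 0.5
rewards = kl_penalized_rewards([0.1, 0.2, 0.3], score=1.0, kl_coef=0.5)
```

Every token is pushed toward the reference policy by the KL term, while only the final token carries the task reward; GAE (Phase 3) then propagates that terminal signal backward.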
The KL divergence is computed per token using the selected estimator:
- k1 estimator:
KL = -logr = policy_logprob - ref_logprob, where logr = ref_logprob - policy_logprob (unbiased but high variance)
- k3 estimator:
KL = (exp(logr) - 1) - logr (lower variance, always non-negative)
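The two estimators can be compared numerically. A small sketch (function names k1/k3 follow Schulman's notation for KL approximation; logr = ref_logprob - policy_logprob):

```python
import math

def k1(logr):
    # k1: unbiased but high-variance per-token estimator of KL(policy || ref)
    return -logr

def k3(logr):
    # k3: lower-variance estimator; non-negative because exp(x) - 1 >= x
    return (math.exp(logr) - 1.0) - logr

# Both estimators evaluate to zero when the policies agree (logr = 0),
# and k3 stays non-negative for any logr, unlike k1.
```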
An optional missing_eos_penalty is subtracted from the score of any response that fails to generate an EOS token, encouraging the model to produce complete responses.
Phase 3: GAE Advantage Estimation
Generalized Advantage Estimation computes the advantage function using a recursive formula:
delta[t] = reward[t] + gamma * V[t+1] - V[t]
A[t] = delta[t] + gamma * lambda * A[t+1]
where:
- delta[t] is the temporal difference error at step t.
- gamma is the discount factor (default 1.0 for text generation).
- lambda is the GAE parameter (default 0.95) controlling the bias-variance tradeoff.
The returns (targets for the value function) are computed as:
returns = advantages + values
Advantages are whitened (normalized to zero mean and unit variance) across the non-padded positions to stabilize training.
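The backward recursion and the whitening step can be sketched in pure Python (function names are assumptions; TRL vectorizes both over batched tensors with padding masks):

```python
def gae_advantages(rewards, values, gamma=1.0, lam=0.95):
    """Compute GAE advantages and value-function targets for one response.

    values[t] is V at position t; V past the final token is taken as 0.
    Illustrative single-sequence sketch.
    """
    T = len(rewards)
    advantages = [0.0] * T
    last_adv = 0.0
    for t in reversed(range(T)):
        next_value = values[t + 1] if t + 1 < T else 0.0
        delta = rewards[t] + gamma * next_value - values[t]   # TD error
        last_adv = delta + gamma * lam * last_adv              # GAE recursion
        advantages[t] = last_adv
    returns = [a + v for a, v in zip(advantages, values)]      # value targets
    return advantages, returns


def whiten(xs, eps=1e-8):
    """Normalize to zero mean and unit variance (eps is an illustrative
    stabilizer; TRL's whitening also handles masking)."""
    mean = sum(xs) / len(xs)
    var = sum((x - mean) ** 2 for x in xs) / len(xs)
    return [(x - mean) / (var ** 0.5 + eps) for x in xs]
```

With gamma = lam = 1 the recursion reduces to advantages[t] = sum(rewards[t:]) - values[t], which makes the terminal-reward propagation easy to verify by hand.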
Phase 4: Clipped Policy and Value Optimization
For each batch of rollout data, num_ppo_epochs optimization passes are performed. In each pass, the data is shuffled and split into minibatches. For each minibatch:
Clipped Policy Loss:
ratio = exp(new_logprob - old_logprob)
pg_loss1 = -advantage * ratio
pg_loss2 = -advantage * clip(ratio, 1-epsilon, 1+epsilon)
policy_loss = max(pg_loss1, pg_loss2)
The clipping prevents the policy from moving too far from the behavior policy that generated the rollouts.
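The pseudocode above translates directly into a per-token sketch (the function name is an assumption; TRL computes this over masked tensors):

```python
import math

def clipped_policy_loss(new_logprob, old_logprob, advantage, epsilon=0.2):
    """Per-token PPO clipped policy loss. Illustrative sketch."""
    ratio = math.exp(new_logprob - old_logprob)
    pg_loss1 = -advantage * ratio
    pg_loss2 = -advantage * max(1.0 - epsilon, min(1.0 + epsilon, ratio))
    # Taking the max of the two (negated) terms is the pessimistic bound:
    # it removes any incentive to push the ratio outside [1-eps, 1+eps].
    return max(pg_loss1, pg_loss2)
```

For example, with a positive advantage and a ratio above 1 + epsilon, the clipped term dominates and the gradient through the ratio vanishes, capping the update.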
Clipped Value Loss:
vpred_clipped = clip(vpred, old_values - epsilon_v, old_values + epsilon_v)
vf_loss = 0.5 * max((vpred - returns)^2, (vpred_clipped - returns)^2)
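The same clipping idea applied to the value head, as a per-token sketch (function name assumed; old_value is the value prediction recorded at rollout time):

```python
def clipped_value_loss(vpred, old_value, ret, epsilon_v=0.2):
    """Per-token clipped value loss. Illustrative sketch."""
    vpred_clipped = max(old_value - epsilon_v, min(old_value + epsilon_v, vpred))
    vf_loss1 = (vpred - ret) ** 2
    vf_loss2 = (vpred_clipped - ret) ** 2
    # The max keeps the loss from shrinking just because the new value
    # prediction jumped far from its rollout-time estimate.
    return 0.5 * max(vf_loss1, vf_loss2)
```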
Total Loss:
total_loss = policy_loss + vf_coef * vf_loss
Training Metrics
| Category | Metric | Description |
|---|---|---|
| Objective | kl | Mean KL divergence from reference policy |
| Objective | entropy | Mean entropy of policy distribution |
| Objective | non_score_reward | Mean per-token KL penalty |
| Objective | rlhf_reward | Combined reward (non_score_reward + scores) |
| Objective | scores | Mean reward model scores |
| Policy | approxkl_avg | Approximate KL between old and new policy (within PPO epochs) |
| Policy | clipfrac_avg | Fraction of policy ratios clipped |
| Loss | policy_avg | Mean policy gradient loss |
| Loss | value_avg | Mean value function loss |
| Value | clipfrac_avg | Fraction of value predictions clipped |
| Value | ratio | Mean importance sampling ratio |
| Value | num_eos_tokens | Number of responses containing EOS |