
Implementation:Huggingface Trl PPOTrainer Train

From Leeroopedia


Property: Value
Implementation Name: PPOTrainer Train
Technology: Huggingface TRL
Type: API Doc
Workflow: PPO RLHF Training
Papers: PPO (https://arxiv.org/abs/1707.06347), GAE (https://arxiv.org/abs/1506.02438)
Principle: Principle:Huggingface_Trl_PPO_Training_Loop

Overview

Description

The PPOTrainer.train method implements the complete PPO training loop for RLHF. Each iteration performs: (1) rollout generation via the policy model, (2) KL-penalized reward computation, (3) GAE advantage estimation, and (4) multiple PPO optimization epochs with clipped policy and value losses. The method manages memory aggressively by clearing caches between phases.

Usage

Called on an initialized PPOTrainer instance. The method runs for num_total_batches iterations, logging metrics and optionally generating sample completions at regular intervals.

Code Reference

Source Location

trl/experimental/ppo/ppo_trainer.py lines 596-940

Signature

def train(self) -> None:
    """
    Main PPO training loop.

    Phase structure per iteration:
    - Lines 664-745: Rollout generation (response generation, logprob computation,
      reward scoring, value estimation)
    - Lines 765-771: KL-penalized reward computation
    - Lines 779-791: GAE advantage estimation
    - Lines 794-872: PPO optimization epochs (clipped policy + value loss)
    - Lines 873-910: Metrics logging and checkpoint saving
    """

Key Internal Functions

def generate(lm_backbone, queries, pad_token_id, generation_config):
    """Generate responses from the policy model."""
    # Returns (query_responses, logits)

def batch_generation(model, queries, batch_size, pad_token_id, generation_config):
    """Generate in sub-batches to manage memory."""
    # Returns (padded_query_responses, padded_logitss)

def forward(model, query_responses, pad_token_id):
    """Forward pass with proper attention mask and position IDs."""
    # Returns ModelOutput with hidden states

def truncate_response(stop_token_id, pad_token_id, responses):
    """Truncate responses at stop token, fill rest with pad."""
    # Returns truncated response tensor
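The truncation helper can be sketched as follows. This is a simplified reconstruction, not the exact TRL source: everything after each row's first stop token is replaced with padding, and rows without a stop token pass through unchanged.

```python
import torch

def first_true_indices(bools):
    # Index of the first True in each row; rows with no True map to row_len.
    row_len = bools.size(-1)
    idxs = torch.arange(row_len, device=bools.device)
    masked = torch.where(bools, idxs, torch.full_like(idxs, row_len))
    return masked.min(dim=-1).values

def truncate_response_sketch(stop_token_id, pad_token_id, responses):
    # Keep tokens up to and including the first stop token; pad the rest.
    trunc_idxs = first_true_indices(responses == stop_token_id).unsqueeze(-1)
    idxs = torch.arange(responses.shape[1], device=responses.device)
    return torch.masked_fill(responses, idxs > trunc_idxs, pad_token_id)
```

The trainer applies this postprocessing before reward scoring whenever a stop token is configured, so the reward model never sees tokens generated past the stop token.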

Import

# train() is a method on PPOTrainer, invoked as:
from trl.experimental.ppo import PPOTrainer
trainer = PPOTrainer(...)
trainer.train()

I/O Contract

Phase 1: Rollout Generation

Input Type Description
queries torch.Tensor Batch of tokenized prompts from the dataloader
generation_config GenerationConfig Sampling config (temperature, max_new_tokens, top_k=0, top_p=1)
Output Type Shape Description
query_responses torch.Tensor (B, context_len + response_len) Concatenated prompt + generated response
responses torch.Tensor (B, response_len) Generated response tokens only
logprobs torch.Tensor (B, response_len) Log-probabilities under current policy
ref_logprobs torch.Tensor (B, response_len) Log-probabilities under reference policy
scores torch.Tensor (B,) Reward model scores for postprocessed responses
values torch.Tensor (B, response_len) Value model estimates at each token position
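The logprobs and ref_logprobs tensors above are obtained by gathering each sampled token's log-probability from the model logits. A minimal sketch, assuming the logits are already temperature-scaled and aligned so that logits[:, t] predicts responses[:, t]:

```python
import torch
import torch.nn.functional as F

def gather_response_logprobs(logits, responses):
    # logits: (B, response_len, vocab_size); responses: (B, response_len)
    all_logprobs = F.log_softmax(logits, dim=-1)
    # Select the log-probability of each actually-sampled token
    return torch.gather(all_logprobs, 2, responses.unsqueeze(-1)).squeeze(-1)
```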

Phase 2: KL-Penalized Reward

Computation Formula
KL divergence (k1) kl = -(ref_logprobs - logprobs)
KL divergence (k3) logr = ref_logprobs - logprobs; kl = (exp(logr) - 1) - logr
Non-score reward non_score_reward = -kl_coef * kl
Final reward rewards = non_score_reward; rewards[last_token] += score
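The two estimators and the reward assembly can be sketched as below. For simplicity the score is added at the last position of every row; the trainer actually adds it at each sequence's final non-padded token, and the kl_coef default shown here is illustrative.

```python
import torch

def kl_penalized_rewards(logprobs, ref_logprobs, scores, kl_coef=0.05, estimator="k1"):
    if estimator == "k1":
        kl = -(ref_logprobs - logprobs)   # plain log-ratio estimator
    else:                                 # "k3": lower variance, always >= 0
        logr = ref_logprobs - logprobs
        kl = (torch.exp(logr) - 1) - logr
    non_score_reward = -kl_coef * kl      # per-token KL penalty
    rewards = non_score_reward.clone()
    rewards[:, -1] += scores              # reward model score at the end
    return rewards, kl
```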

Phase 3: GAE

Computation Formula
TD error delta[t] = rewards[t] + gamma * values[t+1] - values[t]
GAE advantage A[t] = delta[t] + gamma * lam * A[t+1]
Returns returns = advantages + values
Whitened advantages advantages = whiten(advantages, mask=~padding_mask)
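The recurrence above is computed as a reverse scan over time steps. A sketch for unmasked sequences, with a zero bootstrap value after the last token (the trainer additionally masks padding and whitens the advantages):

```python
import torch

def compute_gae(rewards, values, gamma=1.0, lam=0.95):
    # rewards, values: (B, T); value after the final token is taken as 0
    lastgaelam = torch.zeros(rewards.size(0))
    advantages_reversed = []
    T = rewards.size(1)
    for t in reversed(range(T)):
        nextvalues = values[:, t + 1] if t < T - 1 else 0.0
        delta = rewards[:, t] + gamma * nextvalues - values[:, t]   # TD error
        lastgaelam = delta + gamma * lam * lastgaelam               # GAE recursion
        advantages_reversed.append(lastgaelam)
    advantages = torch.stack(advantages_reversed[::-1], dim=1)
    returns = advantages + values
    return advantages, returns
```

With gamma = lam = 1 this reduces to Monte Carlo returns minus the value baseline; lam < 1 trades variance for bias.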

Phase 4: PPO Optimization

Loss Component Formula
Importance ratio ratio = exp(new_logprob - old_logprob)
Clipped policy loss max(-adv * ratio, -adv * clip(ratio, 1-eps, 1+eps))
Clipped value loss 0.5 * max((vpred - ret)^2, (clip(vpred, v_old - eps_v, v_old + eps_v) - ret)^2)
Total loss policy_loss + vf_coef * value_loss
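Putting the four rows together as a sketch, with values holding the rollout-time value estimates used as the clipping center; the default coefficients are illustrative, not necessarily TRL's:

```python
import torch

def ppo_losses(new_logprobs, old_logprobs, advantages, vpred, values, returns,
               cliprange=0.2, cliprange_value=0.2, vf_coef=0.1):
    # Importance ratio between current and rollout-time policy
    ratio = torch.exp(new_logprobs - old_logprobs)
    pg_losses1 = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
    pg_loss = torch.max(pg_losses1, pg_losses2).mean()

    # Value predictions are clipped around the old (rollout-time) estimates
    vpredclipped = torch.clamp(vpred, values - cliprange_value, values + cliprange_value)
    vf_loss = 0.5 * torch.max((vpred - returns) ** 2,
                              (vpredclipped - returns) ** 2).mean()
    return pg_loss + vf_coef * vf_loss, pg_loss, vf_loss
```

Taking the elementwise max of the unclipped and clipped terms makes both losses pessimistic: updates that move the ratio or value prediction outside the trust region receive no extra gradient benefit.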

Logged Metrics

Metric Key Description
objective/kl Mean KL divergence from reference
objective/entropy Mean policy entropy
objective/non_score_reward Mean KL penalty
objective/rlhf_reward Total reward (non_score + score)
objective/scores Mean reward model scores
policy/approxkl_avg Approximate KL between old and new policy
policy/clipfrac_avg Fraction of clipped policy ratios
loss/policy_avg Mean policy loss
loss/value_avg Mean value loss
val/clipfrac_avg Fraction of clipped value predictions
val/ratio Mean importance sampling ratio
val/ratio_var Variance of importance sampling ratio
val/num_eos_tokens Count of generated EOS tokens
lr Current learning rate
episode Current episode count

Usage Examples

Basic Training

from trl.experimental.ppo import PPOTrainer, PPOConfig

# Initialize the trainer (config, tokenizer, models, and datasets prepared beforehand)
trainer = PPOTrainer(
    args=config,
    processing_class=tokenizer,
    model=policy,
    ref_model=ref_policy,
    reward_model=reward_model,
    value_model=value_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# Run the full PPO training loop
trainer.train()

# Save the trained policy
trainer.save_model("ppo-trained-policy")

With DeepSpeed ZeRO-3

accelerate launch --config_file deepspeed_zero3.yaml \
    examples/scripts/ppo/ppo.py \
    --num_ppo_epochs 1 \
    --num_mini_batches 1 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --total_episodes 10000 \
    --local_rollout_forward_batch_size 1
