Principle: Hugging Face TRL GRPO Training Loop
| Property | Value |
|---|---|
| Principle Name | GRPO Generation, Scoring, and Training Loop |
| Library | Hugging Face TRL |
| Category | Online RL Training Loop |
| Paper | GRPO (DeepSeekMath) |
| Related Papers | DAPO, Dr. GRPO, CISPO (MiniMax-M1), SAPO |
Overview
Description
The GRPO training loop is the core of the online RL pipeline. Each training step consists of three phases: (1) generation -- the policy model generates multiple completions per prompt; (2) scoring -- reward functions evaluate completions and advantages are computed using group-relative normalization; (3) policy update -- the model is updated using a clipped surrogate objective that maximizes the advantage-weighted log-probability of good completions.
This loop repeats for each batch, with completions being regenerated periodically (controlled by `steps_per_generation` and `num_iterations`). Between regenerations, the same set of completions is reused across multiple gradient accumulation steps, amortizing the cost of generation.
Usage
Once the `GRPOTrainer` is initialized, the training loop is launched with a single call:

```python
trainer.train()
```

The `train()` method is inherited from the Hugging Face `Trainer`. The GRPO-specific logic is implemented in the overridden methods `_prepare_inputs`, `_generate_and_score_completions`, and `_compute_loss`.
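Reward functions passed to the trainer receive the batch's completions and return one scalar per completion. A hedged toy example (the exact-match logic and the `answer` keyword are illustrative assumptions, not a fixed TRL signature; extra dataset columns arrive via keyword arguments):

```python
# Toy reward function in the shape TRL expects: one float per completion.
# "answer" stands in for a dataset column; any extras arrive via **kwargs.
def reward_exact_match(completions, answer=None, **kwargs):
    answer = answer or [""] * len(completions)
    return [1.0 if c.strip() == a.strip() else 0.0
            for c, a in zip(completions, answer)]

rewards = reward_exact_match(["42", " 42 ", "7"], answer=["42", "42", "42"])
# rewards -> [1.0, 1.0, 0.0]
```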
Theoretical Basis
Phase 1: Generation
For each prompt in the batch, the model generates G completions (controlled by num_generations). Generation can happen via three backends:
- Transformers `generate`: The standard `model.generate()` path using left-padded inputs and sampling with the configured `temperature`, `top_p`, `top_k`, and `min_p` parameters.
- Transformers paged: The continuous-batching `model.generate_batch()` API for memory-efficient generation.
- vLLM: Offloaded generation via the vLLM backend (server or colocated mode) for significantly faster throughput.
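To make the sampling parameters concrete, here is a toy sketch of nucleus (`top_p`) filtering over a categorical distribution; this illustrates the idea only and is not TRL's or vLLM's implementation:

```python
# Toy nucleus (top_p) filtering: keep the smallest set of highest-probability
# tokens whose cumulative probability reaches top_p, then renormalize.
def nucleus_filter(probs, top_p=0.9):
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cum = [], 0.0
    for i in order:
        kept.append(i)
        cum += probs[i]
        if cum >= top_p:
            break
    mass = sum(probs[i] for i in kept)
    return {i: probs[i] / mass for i in kept}  # renormalized distribution
```

With `probs = [0.5, 0.3, 0.2]` and `top_p = 0.7`, the first two tokens (cumulative mass 0.8) are kept and renormalized; the tail token is dropped.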
When using vLLM, model weights are synchronized from the training process to the vLLM engine before each generation batch. The vLLM backend also returns per-token log-probabilities from the sampling process, which are used for importance sampling correction.
For agentic training with tools, generation enters a multi-turn loop: after the model generates a response with tool calls, the tools are executed, results are appended to the prompt, and the model generates again. This continues until no tool calls are made or the iteration limit is reached.
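The multi-turn loop above can be sketched as follows; the function names and flat-string transcript handling are illustrative assumptions, not TRL's actual API:

```python
# Sketch of the agentic multi-turn generation loop: generate, execute any
# tool calls, append results, and repeat until no calls remain or the
# iteration limit is hit.
def multi_turn_generate(generate_fn, execute_tools, prompt, max_turns=4):
    transcript = prompt
    for _ in range(max_turns):
        response, tool_calls = generate_fn(transcript)
        transcript += response
        if not tool_calls:            # model made no tool calls: stop
            break
        results = execute_tools(tool_calls)
        transcript += results         # append tool output, generate again
    return transcript
```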
Phase 2: Scoring and Advantage Computation
After generation, completions are scored by all configured reward functions. The rewards are gathered across distributed processes (since completions for the same group may be split across GPUs) and then aggregated:
- Per-function rewards: Each reward function produces a scalar per completion; `None` values indicate non-applicable rewards.
- Weighted aggregation: Rewards are combined using `reward_weights` and `multi_objective_aggregation`.
- Group-relative normalization: Advantages are computed by subtracting the group mean and (optionally) dividing by the group standard deviation.
The aggregation strategy determines the normalization order:
- `sum_then_normalize`: Sum the weighted rewards first, then normalize within groups
- `normalize_then_sum`: Normalize each reward function independently within groups, then sum (the GDPO approach)
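The group-relative normalization described above can be sketched as follows (a minimal plain-Python version; parameter names and the `eps` guard are illustrative, not TRL's exact code):

```python
# Group-relative advantages: for each group of num_generations completions
# sharing a prompt, subtract the group mean and optionally divide by the
# group standard deviation.
def group_advantages(rewards, num_generations, scale_by_std=True, eps=1e-4):
    advantages = []
    for start in range(0, len(rewards), num_generations):
        group = rewards[start:start + num_generations]
        mean = sum(group) / len(group)
        centered = [r - mean for r in group]
        if scale_by_std:
            std = (sum(c * c for c in centered) / len(group)) ** 0.5
            centered = [c / (std + eps) for c in centered]
        advantages.extend(centered)
    return advantages
```

For example, rewards `[1.0, 3.0]` in one group of two center to `[-1.0, 1.0]` before std scaling: the better completion gets a positive advantage, the worse one a negative advantage of equal magnitude.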
Phase 3: Policy Update (Loss Computation)
The policy gradient loss uses a clipped surrogate objective inspired by PPO but adapted for the group-relative setting. The loss computation involves:
- Log-probability computation: Forward pass through the model to get per-token log-probabilities for the generated completions.
- Importance sampling ratio: The ratio between current and old (at generation time) log-probabilities: `r_t = exp(log_pi_new - log_pi_old)`.
- Clipped objective: The standard PPO-style clipping: `min(r_t * A, clip(r_t, 1 - eps, 1 + eps) * A)`.
- Loss aggregation: Depends on `loss_type`:
  - `grpo`: Per-sequence mean, then batch mean (has length bias)
  - `dr_grpo`: Global sum divided by `batch_size * max_completion_length`
  - `dapo`: Global sum divided by total active tokens across the accumulated batch (default, eliminates length bias)
  - `bnpo`: Local sum divided by local active tokens
  - `cispo`: Clips the importance weights directly rather than the advantage-scaled product
  - `sapo`: Replaces hard clipping with sigmoid-based soft gating
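The per-token clipped surrogate term described above reduces to a few lines; this is a sketch of the objective for a single token (the loss is its negative, aggregated according to `loss_type`):

```python
import math

# PPO-style clipped surrogate term for one token:
#   min(r * A, clip(r, 1 - eps, 1 + eps) * A)
# where r = exp(logp_new - logp_old) and A is the token's advantage.
def clipped_term(logp_new, logp_old, advantage, eps=0.2):
    ratio = math.exp(logp_new - logp_old)
    clipped = max(1.0 - eps, min(ratio, 1.0 + eps))
    return min(ratio * advantage, clipped * advantage)
```

When the ratio drifts outside `[1 - eps, 1 + eps]` in the direction favored by the advantage, the clipped branch wins the `min` and the gradient through the ratio vanishes, bounding the size of each policy update.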
When `beta > 0`, a KL divergence penalty term is added to the loss: `KL = exp(log_ref - log_pi) - (log_ref - log_pi) - 1`.
Importance Sampling Correction
When using vLLM, the log-probabilities from the vLLM engine may differ from those computed by the training model (due to numerical differences, different attention implementations, etc.). TRL applies importance sampling correction to account for this mismatch, with configurable modes:
- `token_truncate` / `token_mask`: Per-token IS ratios, either capped at or masked out above the threshold
- `sequence_truncate` / `sequence_mask`: Sequence-level IS ratios applied uniformly to all tokens
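A minimal sketch of the two token-level modes (the function name, threshold value, and exact cap/mask behavior are illustrative assumptions, not TRL's implementation):

```python
import math

# Token-level importance-sampling correction between the training model's
# log-probs and the vLLM sampler's log-probs.
def token_is_weights(train_logps, vllm_logps, threshold=2.0,
                     mode="token_truncate"):
    weights = []
    for lt, lv in zip(train_logps, vllm_logps):
        r = math.exp(lt - lv)
        if mode == "token_truncate":
            weights.append(min(r, threshold))       # cap oversized ratios
        else:  # "token_mask": drop tokens whose ratio exceeds the threshold
            weights.append(r if r <= threshold else 0.0)
    return weights
```

Truncation keeps every token but bounds its contribution; masking discards tokens whose ratio suggests the two backends disagree too strongly.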
Batch Buffering and Multi-Iteration
To amortize generation cost, the trainer generates a large batch of completions once every steps_per_generation * num_iterations steps. This batch is split into smaller sub-batches that are consumed across multiple gradient accumulation steps. The num_iterations parameter (denoted mu in the GRPO paper) controls how many times the same batch of completions is reused for policy updates.