Principle: AllenAI open-instruct GRPO Policy Update
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement Learning Optimization |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
The GRPO policy update executes one optimization step on the policy model using packed training data; it covers mini-batch iteration, gradient accumulation, loss computation, and distributed metric collection.
Description
Each GRPO training step consists of the following sub-steps:
- Data preparation: The data preparation actor generates completions, computes rewards, normalizes advantages, packs sequences, and distributes collated micro-batches to each learner GPU.
- Reference model inference: If a reference policy is loaded, compute log-probabilities of the packed sequences under the reference model (no gradients).
- Mini-batch training: The packed sequences are divided into `num_mini_batches` groups. Within each mini-batch, gradient accumulation is performed over `accumulation_steps` micro-batches.
- Loss computation: For each micro-batch, the forward pass computes new log-probabilities. The importance sampling ratio is computed against the old (vLLM) log-probabilities, and the GRPO loss function produces per-token losses and KL penalties.
- Token-count normalization: The loss denominator is computed by all-reducing token counts across data-parallel ranks, ensuring consistent loss scaling regardless of packing efficiency variations across ranks.
- Gradient clipping: After accumulation, gradients are clipped to `max_grad_norm`.
- Optimizer step: The DeepSpeed engine performs the optimizer step together with the learning rate scheduler.
- Weight synchronization: Updated weights are broadcast to vLLM inference engines.
- Reference policy update: If configured, the reference policy is updated via Polyak averaging (see the sketch after this list).
- Metric aggregation: Training metrics (loss, KL, clip fraction, ratio, entropy) are token-weighted averaged across all workers.
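The Polyak update of the reference policy is an exponential moving average over parameters. Below is a minimal PyTorch sketch, with `alpha` standing in for whatever mixing coefficient the configuration exposes; the actual open-instruct flag name and update schedule may differ.

```python
import torch
from torch import nn

@torch.no_grad()
def polyak_update(ref_policy: nn.Module, policy: nn.Module, alpha: float) -> None:
    """Update the reference policy in place: ref <- alpha * ref + (1 - alpha) * policy."""
    for ref_param, param in zip(ref_policy.parameters(), policy.parameters()):
        ref_param.mul_(alpha).add_(param, alpha=1.0 - alpha)
```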
Usage
The policy update is called once per training step by the main training loop. It is the most compute-intensive component of the training step (excluding generation). The number of mini-batches and gradient accumulation steps control the memory-compute tradeoff.
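As a rough illustration of that tradeoff (the variable names below are illustrative, not the exact open-instruct arguments): for a fixed number of packed micro-batches per rank, raising the number of mini-batches yields more optimizer steps per rollout but a smaller effective batch per step, while the per-device micro-batch size bounds activation memory.

```python
# Illustrative arithmetic only; names are assumptions, not the exact open-instruct flags.
micro_batches_per_rank = 16   # packed micro-batches handed to this learner GPU per rollout
num_mini_batches = 4          # optimizer steps taken per rollout batch

accumulation_steps = micro_batches_per_rank // num_mini_batches  # 4 micro-batches per step

# Fewer mini-batches -> more accumulation, larger effective batch, fewer updates per rollout.
# More mini-batches  -> less accumulation, smaller effective batch, more updates per rollout.
print(f"{accumulation_steps} gradient-accumulation steps per optimizer step")
```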
Theoretical Basis
Mini-Batch Training
GRPO supports multiple epochs over the same rollout data via `num_epochs` and `num_mini_batches`:
For each epoch (typically 1):
    For each mini-batch (group of accumulation_steps micro-batches):
        accumulated_loss = 0
        For each micro-batch in the group:
            new_logprobs = policy.forward(micro_batch)
            old_logprobs = vllm_logprobs[micro_batch]  # from generation
            ratio = exp(new_logprobs - old_logprobs)
            loss, kl = compute_grpo_loss(new_logprobs, ratio, advantages, ref_logprobs)
            (loss + beta * kl).backward()
            accumulated_loss += loss
        clip_grad_norm(policy, max_grad_norm)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
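The loop above calls `compute_grpo_loss` without defining it. The following is a minimal sketch of one plausible form, assuming a PPO-style clipped surrogate with clip range `eps` and the k3 KL estimator against the reference policy; the actual open-instruct signature, masking details, and loss options may differ.

```python
import torch

def compute_grpo_loss(new_logprobs, ratio, advantages, ref_logprobs, response_mask, eps=0.2):
    """Per-token clipped policy-gradient loss and KL penalty, summed over response tokens.

    new_logprobs / ref_logprobs: per-token log-probs under the current / reference policy.
    ratio: exp(new_logprobs - old_logprobs), the importance sampling ratio.
    advantages: group-normalized advantages, broadcast to every response token.
    response_mask: 1 for response tokens, 0 for prompt and padding tokens.
    """
    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1.0 - eps, 1.0 + eps)
    pg_loss = (torch.maximum(unclipped, clipped) * response_mask).sum()  # pessimistic surrogate

    # k3 KL estimator against the reference policy: exp(d) - d - 1 with d = ref_logprobs - new_logprobs
    delta = ref_logprobs - new_logprobs
    kl = ((torch.exp(delta) - delta - 1.0) * response_mask).sum()

    # Both terms are per-token sums; the caller divides by the all-reduced token count
    # described under "Token-Count Normalization" before calling backward().
    return pg_loss, kl
```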
Token-Count Normalization
The loss is normalized by total token count across all ranks to ensure consistent gradient magnitudes:
For accumulation group g:
    local_tokens = sum(response_mask.sum() for micro_batch in group_g)
    global_tokens = all_reduce(local_tokens, op=SUM)
    loss_denominator = global_tokens  (or a fixed constant per Dr GRPO)
This all-reduce ensures that if one rank has fewer tokens due to packing variations, the loss is still normalized consistently.
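A minimal sketch of this normalization with `torch.distributed`, assuming the learner ranks form the default process group (the helper name is illustrative; with the NCCL backend the tensor must live on the rank's GPU):

```python
import torch
import torch.distributed as dist

def accumulation_group_token_count(response_masks: list[torch.Tensor]) -> torch.Tensor:
    """Global response-token count for one accumulation group, shared by every data-parallel rank."""
    local_tokens = torch.stack([mask.sum() for mask in response_masks]).sum().float()
    dist.all_reduce(local_tokens, op=dist.ReduceOp.SUM)  # in-place sum across ranks
    return local_tokens

# Usage: divide each micro-batch's summed per-token loss by this shared denominator,
# so a rank with poorer packing (fewer tokens) does not receive inflated gradients.
```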
Gradient Clipping
Gradient clipping prevents exploding gradients from high-advantage outlier tokens:
total_norm = sqrt(sum(param.grad.norm()^2 for all params))
if total_norm > max_grad_norm:
    scale = max_grad_norm / total_norm
    for param in all_params:
        param.grad *= scale
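In plain PyTorch this is a single call to `torch.nn.utils.clip_grad_norm_` after the last accumulation step and before the optimizer step; under DeepSpeed the engine typically applies the same clipping internally when gradient clipping is set in its config. A minimal sketch (the wrapper name is illustrative):

```python
import torch
from torch import nn
from torch.nn.utils import clip_grad_norm_

def clip_policy_gradients(policy: nn.Module, max_grad_norm: float) -> torch.Tensor:
    """Clip all policy gradients to a global norm; returns the pre-clipping norm (useful to log)."""
    return clip_grad_norm_(policy.parameters(), max_norm=max_grad_norm)
```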