Principle:Allenai Open instruct GRPO Policy Update



Knowledge Sources
Domains: Reinforcement Learning, Optimization
Last Updated: 2026-02-07 00:00 GMT

Overview

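The GRPO policy update is the optimization phase of one training step. It iterates over mini-batches of the packed training data, accumulates gradients over micro-batches, computes the loss, steps the optimizer, and collects metrics across distributed workers. Depending on num_epochs and num_mini_batches, a single training step can contain several optimizer steps.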

Description

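Each GRPO training step consists of the following sub-steps. Steps 4 through 7 repeat inside the mini-batch loop of step 3; the remaining steps run at most once per training step: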

  1. Data preparation: The data preparation actor generates completions, computes rewards, normalizes advantages, packs sequences, and distributes collated micro-batches to each learner GPU.
  2. Reference model inference: If a reference policy is loaded, compute log-probabilities of the packed sequences under the reference model (no gradients).
  3. Mini-batch training: The packed sequences are divided into num_mini_batches groups. Within each mini-batch, gradient accumulation is performed over accumulation_steps micro-batches.
  4. Loss computation: For each micro-batch, the forward pass computes new log-probabilities. The importance sampling ratio is computed against the old (vLLM) log-probabilities. The GRPO loss function produces per-token losses and KL penalties.
  5. Token-count normalization: The loss denominator is computed by all-reducing token counts across data-parallel ranks, ensuring consistent loss scaling regardless of packing efficiency variations across ranks.
  6. Gradient clipping: After a mini-batch's gradients have been accumulated, they are clipped so that their global L2 norm does not exceed max_grad_norm.
  7. Optimizer step: The DeepSpeed engine applies the optimizer update and advances the learning-rate scheduler.
  8. Weight synchronization: Updated weights are broadcast to vLLM inference engines.
  9. Reference policy update: If configured, the reference policy is updated via Polyak averaging.
  10. Metric aggregation: Training metrics (loss, KL, clip fraction, ratio, entropy) are averaged across all workers, weighted by each worker's token count.
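Steps 9 and 10 each come down to a short formula. The plain-Python sketch below assumes the Polyak update has the form ref = alpha * policy + (1 - alpha) * ref with a mixing coefficient `alpha` (a hypothetical name here), and that each worker reports every metric as a (per-token mean, token count) pair. The function names are illustrative and are not open-instruct's API.

```python
from typing import Dict, List, Sequence, Tuple


def polyak_update(ref_params: List[float], policy_params: Sequence[float], alpha: float) -> None:
    """In-place Polyak average (assumed form): ref <- alpha * policy + (1 - alpha) * ref."""
    for i, p in enumerate(policy_params):
        ref_params[i] = alpha * p + (1.0 - alpha) * ref_params[i]


def token_weighted_mean(per_worker: Sequence[Tuple[float, int]]) -> float:
    """Combine (per-token mean, token count) pairs from all workers into one global mean."""
    total_tokens = sum(count for _, count in per_worker)
    if total_tokens == 0:
        return 0.0
    return sum(mean * count for mean, count in per_worker) / total_tokens


def aggregate_metrics(per_worker: Sequence[Dict[str, Tuple[float, int]]]) -> Dict[str, float]:
    """Token-weighted average of every metric name reported by the workers."""
    names = per_worker[0].keys()
    return {name: token_weighted_mean([w[name] for w in per_worker]) for name in names}


# Two workers that processed different numbers of response tokens.
workers = [
    {"loss": (0.50, 300), "kl": (0.010, 300)},
    {"loss": (0.20, 100), "kl": (0.030, 100)},
]
metrics = aggregate_metrics(workers)
# loss = (0.50 * 300 + 0.20 * 100) / 400 = 0.425; an unweighted mean would give 0.35

ref = [0.0, 1.0]
polyak_update(ref, policy_params=[1.0, 3.0], alpha=0.1)
# ref is now approximately [0.1, 1.2]
```

Weighting by token count keeps a worker that happened to receive short sequences from pulling the reported averages as hard as a worker that processed many more tokens.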

Usage

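The main training loop calls the policy update once per training step. Apart from generation, it is the most compute-intensive part of the step. The micro-batch size bounds peak activation memory: accumulating gradients over more, smaller micro-batches lowers memory at some cost in throughput. num_mini_batches and num_epochs, in turn, determine how many optimizer steps are taken on each batch of rollouts.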
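The interaction between num_mini_batches and gradient accumulation can be made concrete with a small schedule. The hypothetical helper below lists the forward/backward and optimizer events one rank would execute. It assumes the rank's packed micro-batches divide evenly into num_mini_batches groups; the real trainer may handle uneven splits differently.

```python
from typing import List


def update_schedule(num_micro_batches: int, num_mini_batches: int, num_epochs: int = 1) -> List[str]:
    """List the events of one training step on one rank (hypothetical helper).

    Assumes the rank's micro-batches split evenly into num_mini_batches groups,
    each of accumulation_steps micro-batches.
    """
    if num_micro_batches % num_mini_batches != 0:
        raise ValueError("micro-batches must divide evenly into mini-batches")
    accumulation_steps = num_micro_batches // num_mini_batches
    events: List[str] = []
    for epoch in range(num_epochs):
        for mb in range(num_micro_batches):
            events.append(f"epoch{epoch}:fwd_bwd{mb}")
            # An optimizer step closes every group of accumulation_steps micro-batches.
            if (mb + 1) % accumulation_steps == 0:
                events.append(f"epoch{epoch}:optimizer_step")
    return events


# 8 micro-batches on this rank: a single mini-batch accumulates all 8 before stepping,
# while four mini-batches step after every 2.
one_group = update_schedule(8, num_mini_batches=1)
four_groups = update_schedule(8, num_mini_batches=4)
print(one_group.count("epoch0:optimizer_step"), four_groups.count("epoch0:optimizer_step"))  # 1 4
```

In both schedules every micro-batch has the same size, so peak activation memory is unchanged. What differs is the number of optimizer updates per rollout batch. With four mini-batches, the later ones are trained after the policy has already moved, which is why the loss uses an importance ratio against the generation-time log-probabilities.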

Theoretical Basis

Mini-Batch Training

GRPO can take more than one optimizer step on the same rollout data: num_epochs sets how many passes are made over it, and num_mini_batches sets how many optimizer steps are taken per pass:

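For each epoch in range(num_epochs):  # typically 1
    For each mini-batch (a group of accumulation_steps micro-batches):
        accumulated_loss = 0
        loss_denominator = all-reduced response-token count of this group  # see Token-Count Normalization
        For each micro-batch in the group:
            new_logprobs = policy.forward(micro_batch)   # with gradients
            old_logprobs = vllm_logprobs[micro_batch]    # recorded during generation
            ratio = exp(new_logprobs - old_logprobs)
            loss, kl = compute_grpo_loss(new_logprobs, ratio, advantages, ref_logprobs)  # per-token
            step_loss = sum((loss + beta * kl) * response_mask) / loss_denominator
            step_loss.backward()                         # gradients accumulate across micro-batches
            accumulated_loss += step_loss                # for logging

        clip_grad_norm(policy, max_grad_norm)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()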
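The page does not spell out compute_grpo_loss. The value-only sketch below fills it in under explicit assumptions: a PPO-style clipped surrogate with clip range `eps` (consistent with the clip-fraction metric reported in step 10), the "k3" KL estimator exp(ref - new) - (ref - new) - 1, and a loaded reference model. The function name and default values are hypothetical. It works on plain lists for one micro-batch and returns numbers rather than gradients; in training, the same expression is evaluated on tensors so that autograd can backpropagate through it.

```python
import math
from typing import Sequence, Tuple


def grpo_microbatch_loss(
    new_logprobs: Sequence[float],  # current-policy log-probs of the sampled tokens
    old_logprobs: Sequence[float],  # log-probs recorded by vLLM at generation time
    ref_logprobs: Sequence[float],  # reference-model log-probs (no gradients)
    advantages: Sequence[float],    # sequence advantage broadcast to each of its tokens
    response_mask: Sequence[int],   # 1 for response tokens, 0 for prompt/padding
    loss_denominator: float,        # all-reduced token count of the accumulation group
    beta: float = 0.05,             # KL coefficient (illustrative value)
    eps: float = 0.2,               # clip range (illustrative value)
) -> Tuple[float, float]:
    """Return (normalized loss, clip fraction) for one micro-batch; values only, no gradients."""
    total, n_clipped, n_tokens = 0.0, 0, 0
    for new, old, ref, adv, m in zip(new_logprobs, old_logprobs, ref_logprobs, advantages, response_mask):
        if not m:
            continue  # prompt and padding tokens contribute nothing
        ratio = math.exp(new - old)  # importance ratio against the generation-time policy
        unclipped = ratio * adv
        clipped = min(max(ratio, 1.0 - eps), 1.0 + eps) * adv
        # Pessimistic PPO-style surrogate, negated because the optimizer minimizes.
        pg_loss = -min(unclipped, clipped)
        if clipped < unclipped:
            n_clipped += 1  # the clipped term was chosen, so this token has zero gradient
        # "k3" estimator of KL(policy || reference); zero when new == ref.
        log_ratio = ref - new
        kl = math.exp(log_ratio) - log_ratio - 1.0
        total += pg_loss + beta * kl
        n_tokens += 1
    clip_frac = n_clipped / n_tokens if n_tokens else 0.0
    return total / loss_denominator, clip_frac


# Identical log-probs: ratio = 1 and KL = 0, so the loss is -sum(advantages) / denominator.
loss, frac = grpo_microbatch_loss(
    new_logprobs=[-1.0, -2.0, -0.5], old_logprobs=[-1.0, -2.0, -0.5],
    ref_logprobs=[-1.0, -2.0, -0.5], advantages=[1.0, 1.0, -1.0],
    response_mask=[1, 1, 0], loss_denominator=2.0,
)
# loss == -1.0 and frac == 0.0 (the third token is masked out)
```

Dividing the summed token losses by a group-wide denominator, instead of taking a per-micro-batch mean, is what lets the micro-batch gradients add up to a single consistent mini-batch gradient.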

Token-Count Normalization

The loss is normalized by total token count across all ranks to ensure consistent gradient magnitudes:

For accumulation group g:
    local_tokens = sum(micro_batch.response_mask.sum() for micro_batch in group_g)
    global_tokens = all_reduce(local_tokens, op=SUM)
    loss_denominator = global_tokens  # Dr GRPO uses a fixed constant instead

Because every rank divides by the same all-reduced denominator, each response token carries the same weight in the gradient, even when packing leaves one rank with fewer tokens than another.
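The effect of the shared denominator shows up in a toy example with made-up per-token losses on two ranks. The all-reduce is simulated with a plain sum, and `all_reduce_sum` is a hypothetical stand-in, not a real collective. Any constant world-size factor introduced by data-parallel gradient averaging scales every token equally, so it does not change the relative weighting shown here.

```python
from typing import List

# Hypothetical per-token losses for one accumulation group on two data-parallel
# ranks; each inner list is one micro-batch (response tokens only).
rank_micro_batches: List[List[List[float]]] = [
    [[1.0, 1.0, 1.0, 1.0], [2.0, 2.0]],  # rank 0: 6 tokens in two micro-batches
    [[4.0]],                              # rank 1: 1 token
]


def all_reduce_sum(values: List[int]) -> int:
    """Stand-in for a SUM all-reduce: every rank would receive this same total."""
    return sum(values)


local_tokens = [sum(len(mb) for mb in rank) for rank in rank_micro_batches]
global_tokens = all_reduce_sum(local_tokens)  # 7

# Token-count normalization: each micro-batch contributes sum(loss) / global_tokens,
# so the contributions add up to the mean loss over every token in the group.
normalized = sum(sum(mb) / global_tokens for rank in rank_micro_batches for mb in rank)

# Naive alternative: average within each micro-batch, then average those means.
micro_means = [sum(mb) / len(mb) for rank in rank_micro_batches for mb in rank]
naive = sum(micro_means) / len(micro_means)

print(normalized, naive)  # about 1.714 (12 / 7) versus about 2.333 (7 / 3)
```

The naive scheme gives the single token on rank 1 the same weight as the four-token micro-batch on rank 0, which is exactly the packing-dependent distortion the all-reduced denominator avoids.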

Gradient Clipping

Gradient clipping limits the size of each update when a few outlier tokens, such as those with large advantages, produce unusually large gradients:

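import torch

def clip_grad_norm(params, max_grad_norm):
    grads = [p.grad for p in params if p.grad is not None]
    total_norm = torch.sqrt(sum(g.norm() ** 2 for g in grads))  # global L2 norm over all params
    if total_norm > max_grad_norm:
        scale = max_grad_norm / total_norm
        for g in grads:
            g.mul_(scale)  # one factor for all grads, so the update direction is preserved
    return total_norm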
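The global norm still works when no single rank holds all the gradients. The sketch below assumes the flattened gradient is partitioned across data-parallel ranks, as in DeepSpeed ZeRO stages 2 and 3, and simulates the required SUM all-reduce with a plain `sum`. The helper name is hypothetical. The small `eps` and the clamp at 1.0 mirror what PyTorch's `clip_grad_norm_` does.

```python
import math
from typing import List, Tuple


def clip_sharded_grads(
    shards: List[List[float]], max_grad_norm: float, eps: float = 1e-6
) -> Tuple[float, List[List[float]]]:
    """Clip a gradient that is split into per-rank shards (hypothetical helper)."""
    # Each rank computes the squared L2 norm of its own slice only...
    local_sq = [sum(g * g for g in shard) for shard in shards]
    # ...and a SUM all-reduce (simulated by sum) gives every rank the global value.
    total_norm = math.sqrt(sum(local_sq))
    # Every rank applies the same factor; it is 1.0 when the norm is already small enough.
    scale = min(1.0, max_grad_norm / (total_norm + eps))
    return total_norm, [[g * scale for g in shard] for shard in shards]


# Full gradient [3, 0, 0, 4] split across two ranks; its global norm is 5.
total_norm, clipped = clip_sharded_grads([[3.0, 0.0], [0.0, 4.0]], max_grad_norm=1.0)
# total_norm == 5.0 and the clipped shards are roughly [[0.6, 0.0], [0.0, 0.8]]
```

Clipping each shard by its own norm would instead change the gradient's direction, so the squared norms must be combined across ranks before any scaling is applied.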

Related Pages

Implemented By
