
Implementation:Allenai Open instruct Compute GRPO Loss

From Leeroopedia


Type Function
Source open_instruct/grpo_utils.py:L235-270
Dependencies torch, open_instruct.model_utils
Last Updated 2026-02-07 00:00 GMT

Overview

Concrete function for computing the GRPO clipped policy gradient loss with optional KL penalty, provided by the Open Instruct library.

Description

compute_grpo_loss() computes the per-token policy gradient loss using either the DAPO or CISPO variant, depending on the experiment configuration. It takes the new policy's log-probabilities, the importance sampling ratio, per-token advantages, and optionally the reference model's log-probabilities, and returns the individual loss components.

The function supports:

  • DAPO loss: Asymmetric PPO-style clipping with configurable lower and upper clip bounds.
  • CISPO loss: Ratio-clipped REINFORCE-style loss with detached importance weights.
  • Truncated importance sampling: Optional weighting by truncated importance sampling ratios.
  • KL penalty: When reference log-probabilities are provided, computes KL divergence using a configurable estimator (0-3).
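The two loss variants above can be sketched as follows. This is an illustrative reimplementation under stated assumptions, not the library's actual code; the function names and the exact CISPO formulation are hypothetical:

```python
import torch

def dapo_per_token_loss(ratio, advantages, clip_lower=0.2, clip_higher=0.28):
    # Unclipped PPO surrogate: -A * ratio
    pg_losses = -advantages * ratio
    # Asymmetric clipping: separate lower and upper bounds (DAPO-style)
    clipped_ratio = torch.clamp(ratio, 1.0 - clip_lower, 1.0 + clip_higher)
    pg_losses2 = -advantages * clipped_ratio
    # Pessimistic element-wise max of the two surrogates, as in PPO
    return torch.max(pg_losses, pg_losses2)

def cispo_per_token_loss(ratio, new_logprobs, advantages, clip_higher=0.28):
    # REINFORCE-style term weighted by a clipped, detached importance ratio,
    # so gradients flow only through new_logprobs
    weight = torch.clamp(ratio, max=1.0 + clip_higher).detach()
    return -weight * advantages * new_logprobs
```

Note how the DAPO variant clips the ratio itself before taking the pessimistic max, while the CISPO variant detaches the (upper-clipped) ratio and uses it purely as a weight on the log-probability gradient.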

Usage

Called once per mini-batch during the training step, inside the gradient computation loop. The returned loss components are aggregated across tokens and micro-batches before the optimizer step.
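A minimal sketch of that aggregation step, assuming a 0/1 response mask over valid completion tokens (the function and argument names are illustrative, not part of the library API):

```python
import torch

def aggregate_loss(pg_loss_max, kl, response_mask, beta=0.05):
    # Combine the per-token policy gradient loss with the KL penalty
    per_token = pg_loss_max + beta * kl
    # Zero out prompt/padding positions, then average over valid tokens only
    masked = per_token * response_mask
    return masked.sum() / response_mask.sum().clamp(min=1)
```

Averaging over the mask sum rather than the full tensor keeps padding tokens from diluting the gradient signal.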

Code Reference

Source Location

open_instruct/grpo_utils.py, lines 235-270

Signature

def compute_grpo_loss(
    new_logprobs: torch.Tensor,
    ratio: torch.Tensor,
    advantages: torch.Tensor,
    ref_logprobs: torch.Tensor | None,
    config: ExperimentConfig,
    tis_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

Import

from open_instruct.grpo_utils import compute_grpo_loss

I/O Contract

Inputs

Name Type Description
new_logprobs torch.Tensor Per-token log-probabilities from the current policy's forward pass. Shape: (batch, seq_len).
ratio torch.Tensor Importance sampling ratio exp(new_logprobs - old_logprobs). Shape: (batch, seq_len).
advantages torch.Tensor Per-token advantages (broadcast from per-completion group-relative advantages). Shape: (batch, seq_len).
ref_logprobs torch.Tensor | None Per-token log-probabilities from the reference model. Shape: (batch, seq_len). None if no reference model is used.
config ExperimentConfig Experiment configuration containing loss_fn, clip_lower, clip_higher, kl_estimator, and beta.
tis_weights torch.Tensor | None Optional truncated importance sampling weights for per-token loss scaling. Shape: (batch, seq_len).

Outputs

Name Type Description
pg_losses torch.Tensor Unclipped policy gradient losses. Shape: (batch, seq_len).
pg_losses2 torch.Tensor Clipped policy gradient losses. Shape: (batch, seq_len).
pg_loss_max torch.Tensor Element-wise max of clipped and unclipped losses (the final PPO-style loss). Shape: (batch, seq_len).
kl torch.Tensor Per-token KL divergence estimate (zeros if no reference model). Shape: (batch, seq_len).
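The configurable estimators (0-3) plausibly follow the common k1/k2/k3 family of per-token KL approximations. A hedged sketch of what such estimators typically look like; the numbering here is an assumption and may not match the library's exact mapping:

```python
import torch

def kl_estimate(new_logprobs, ref_logprobs, estimator=2):
    # Per-token log-ratio of policy vs. reference
    log_ratio = new_logprobs - ref_logprobs
    if estimator == 0:
        # k1: plain log-ratio (unbiased, high variance, can be negative)
        return log_ratio
    if estimator == 1:
        # k2: half squared log-ratio (low variance, biased, nonnegative)
        return 0.5 * log_ratio ** 2
    if estimator == 2:
        # k3: exp(-x) - 1 + x, unbiased and nonnegative
        return torch.expm1(-log_ratio) + log_ratio
    # fallback variant (assumed): squared log-ratio
    return log_ratio ** 2
```

All variants return zeros when the policy matches the reference, which is consistent with the contract above that kl is zero when no divergence exists.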

Usage Examples

import torch
from open_instruct.grpo_utils import compute_grpo_loss, ExperimentConfig

config = ExperimentConfig(
    loss_fn="dapo",
    clip_lower=0.2,
    clip_higher=0.28,
    beta=0.05,
    kl_estimator=2,
)

# Simulated tensors for a batch of 2 sequences, 10 tokens each
new_logprobs = torch.randn(2, 10)
old_logprobs = torch.randn(2, 10)
ratio = torch.exp(new_logprobs - old_logprobs)
advantages = torch.randn(2, 10)
ref_logprobs = torch.randn(2, 10)

pg_losses, pg_losses2, pg_loss_max, kl = compute_grpo_loss(
    new_logprobs=new_logprobs,
    ratio=ratio,
    advantages=advantages,
    ref_logprobs=ref_logprobs,
    config=config,
)

# Final loss per token = pg_loss_max + beta * kl
total_loss = pg_loss_max + config.beta * kl
mean_loss = total_loss.mean()
mean_loss.backward()

Related Pages

Implements Principle
