# Implementation: AllenAI Open Instruct compute_grpo_loss
| Type | Function |
|---|---|
| Source | open_instruct/grpo_utils.py:L235-270 |
| Dependencies | torch, open_instruct.model_utils |
| Last Updated | 2026-02-07 00:00 GMT |
## Overview

Concrete function for computing the GRPO clipped policy gradient loss with optional KL penalty, provided by the Open Instruct library.
## Description

compute_grpo_loss() computes the per-token policy gradient loss using either the DAPO or CISPO variant, depending on the experiment configuration. It takes the new policy's log-probabilities, the importance sampling ratio, per-token advantages, and optionally the reference model's log-probabilities, and returns the individual loss components.
The function supports:
- DAPO loss: Asymmetric PPO-style clipping with configurable lower and upper clip bounds.
- CISPO loss: Ratio-clipped REINFORCE-style loss with detached importance weights.
- Truncated importance sampling: Optional weighting by truncated importance sampling ratios.
- KL penalty: When reference log-probabilities are provided, computes KL divergence using a configurable estimator (0-3).
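The two loss variants can be sketched as follows. This is a hedged illustration, not Open Instruct's actual implementation: the exact clipping formulas, the function names `dapo_style_loss` and `cispo_style_loss`, and the default bound values are assumptions based on the descriptions above.

```python
import torch

def dapo_style_loss(ratio, advantages, clip_lower=0.2, clip_higher=0.28):
    """Asymmetric PPO-style clipping (sketch; the exact form in grpo_utils may differ)."""
    pg_losses = -advantages * ratio                       # unclipped surrogate loss
    clipped = torch.clamp(ratio, 1.0 - clip_lower, 1.0 + clip_higher)
    pg_losses2 = -advantages * clipped                    # clipped surrogate loss
    pg_loss_max = torch.max(pg_losses, pg_losses2)        # pessimistic element-wise max
    return pg_losses, pg_losses2, pg_loss_max

def cispo_style_loss(ratio, new_logprobs, advantages, clip_higher=0.28):
    """REINFORCE-style loss with a detached, upper-clipped importance weight (sketch)."""
    clipped_w = torch.clamp(ratio, max=1.0 + clip_higher).detach()
    return -clipped_w * advantages * new_logprobs
```

Note the structural difference: DAPO keeps the gradient flowing through the (clipped) ratio, while CISPO detaches the importance weight so gradients flow only through the log-probabilities.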
## Usage

Called once per mini-batch during the training step, inside the gradient computation loop. The returned loss components are aggregated across tokens and micro-batches before the optimizer step.
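The aggregation step might look like the following sketch. The `response_mask` tensor (1 for response tokens, 0 for prompt/padding) is a hypothetical input, not part of compute_grpo_loss itself, and the token-level mean shown here is one of several reduction choices.

```python
import torch

def aggregate_loss(per_token_loss, response_mask):
    """Masked token-level mean over a micro-batch (sketch, assumed reduction)."""
    # Sum masked per-token losses, divide by the number of response tokens.
    return (per_token_loss * response_mask).sum() / response_mask.sum().clamp(min=1)
```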
## Code Reference

### Source Location

- Repository: Open Instruct
- File: `open_instruct/grpo_utils.py`
### Signature

```python
def compute_grpo_loss(
    new_logprobs: torch.Tensor,
    ratio: torch.Tensor,
    advantages: torch.Tensor,
    ref_logprobs: torch.Tensor | None,
    config: ExperimentConfig,
    tis_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
```
### Import

```python
from open_instruct.grpo_utils import compute_grpo_loss
```
## I/O Contract

### Inputs

| Name | Type | Description |
|---|---|---|
| new_logprobs | torch.Tensor | Per-token log-probabilities from the current policy's forward pass. Shape: (batch, seq_len). |
| ratio | torch.Tensor | Importance sampling ratio exp(new_logprobs - old_logprobs). Shape: (batch, seq_len). |
| advantages | torch.Tensor | Per-token advantages (broadcast from per-completion group-relative advantages). Shape: (batch, seq_len). |
| ref_logprobs | torch.Tensor \| None | Per-token log-probabilities from the reference model. None if no reference model is used. |
| config | ExperimentConfig | Experiment configuration containing loss_fn, clip_lower, clip_higher, kl_estimator, and beta. |
| tis_weights | torch.Tensor \| None | Optional truncated importance sampling weights for per-token loss scaling. |
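The `advantages` input is described as broadcast from per-completion group-relative advantages. A hedged sketch of how such advantages are typically produced in GRPO (the normalization details and the helper name `group_relative_advantages` are assumptions, not Open Instruct's API):

```python
import torch

def group_relative_advantages(rewards, group_size, seq_len, eps=1e-6):
    """rewards: (num_prompts * group_size,) scalar reward per completion.
    Returns per-token advantages of shape (num_prompts * group_size, seq_len).
    Sketch only; the exact normalization in open_instruct may differ."""
    grouped = rewards.view(-1, group_size)
    mean = grouped.mean(dim=1, keepdim=True)
    std = grouped.std(dim=1, keepdim=True)
    adv = ((grouped - mean) / (std + eps)).view(-1)    # per-completion advantage
    return adv.unsqueeze(1).expand(-1, seq_len)        # broadcast over tokens
```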
### Outputs

| Name | Type | Description |
|---|---|---|
| pg_losses | torch.Tensor | Unclipped policy gradient losses. Shape: (batch, seq_len). |
| pg_losses2 | torch.Tensor | Clipped policy gradient losses. Shape: (batch, seq_len). |
| pg_loss_max | torch.Tensor | Element-wise maximum of clipped and unclipped losses (the final PPO-style loss). Shape: (batch, seq_len). |
| kl | torch.Tensor | Per-token KL divergence estimate (zeros if no reference model). Shape: (batch, seq_len). |
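The `kl` output depends on the configured estimator. The mapping of kl_estimator indices 0-3 to specific forms is an assumption here; the sketch below shows the common k1-k3 approximations of per-token KL(π || π_ref):

```python
import torch

def kl_penalty(new_logprobs, ref_logprobs, estimator):
    """Common per-token KL estimators (sketch; index mapping is assumed)."""
    logr = new_logprobs - ref_logprobs        # log(pi / pi_ref) per token
    if estimator in (0, 1):
        return logr                           # k1: unbiased, high variance
    if estimator == 2:
        return 0.5 * logr ** 2                # k2: low variance, biased
    return torch.expm1(-logr) + logr          # k3: exp(-logr) - 1 + logr, nonnegative
```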
## Usage Examples

```python
import torch
from open_instruct.grpo_utils import compute_grpo_loss, ExperimentConfig

config = ExperimentConfig(
    loss_fn="dapo",
    clip_lower=0.2,
    clip_higher=0.28,
    beta=0.05,
    kl_estimator=2,
)

# Simulated tensors for a batch of 2 sequences, 10 tokens each.
# new_logprobs requires grad so the final .backward() call has a graph to traverse.
new_logprobs = torch.randn(2, 10, requires_grad=True)
old_logprobs = torch.randn(2, 10)
ratio = torch.exp(new_logprobs - old_logprobs)
advantages = torch.randn(2, 10)
ref_logprobs = torch.randn(2, 10)

pg_losses, pg_losses2, pg_loss_max, kl = compute_grpo_loss(
    new_logprobs=new_logprobs,
    ratio=ratio,
    advantages=advantages,
    ref_logprobs=ref_logprobs,
    config=config,
)

# Final loss per token = pg_loss_max + beta * kl
total_loss = pg_loss_max + config.beta * kl
mean_loss = total_loss.mean()
mean_loss.backward()
```