Implementation:Allenai Open instruct DPO Compute Loss

From Leeroopedia


  • Component Type: Function
  • Source: open_instruct/dpo_utils.py (Lines 730-771)
  • Repository: Open Instruct
  • Dependencies: torch, open_instruct.model_utils
  • Last Updated: 2026-02-07 00:00 GMT

Overview

A concrete utility from the Open Instruct library that dispatches DPO loss computation to the appropriate loss variant based on the experiment configuration.

Description

compute_loss() is the central dispatcher that routes loss computation to the correct DPO variant function. Based on the loss_type field in the DPOConfig, it:

  • For dpo and dpo_norm: Retrieves cached reference log-probabilities from the reference_cache using the batch's sample indices, then calls dpo_loss() with the policy and reference logprobs, beta, and label smoothing.
  • For simpo: Calls simpo_loss() directly with only the policy logprobs, beta, and gamma/beta ratio. No reference cache is needed.
  • For wpo: Retrieves cached reference log-probabilities, extracts token-level loss masks from the batch labels (tokens where labels != -100), and calls wpo_loss() with all inputs including the loss masks.

If an unknown loss type is provided, the function raises a ValueError.
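
The routing just described can be sketched as a plain conditional. This is an illustrative simplification, not the library's code: the real function calls dpo_loss(), simpo_loss(), or wpo_loss() with their full argument lists, while this sketch only returns the name of the branch taken.

```python
# Hedged sketch of the dispatch logic described above; dispatch_loss is a
# hypothetical helper, not part of open_instruct.
def dispatch_loss(loss_type: str) -> str:
    if loss_type in ("dpo", "dpo_norm"):
        return "dpo_loss"    # needs cached reference logprobs
    elif loss_type == "simpo":
        return "simpo_loss"  # reference-free
    elif loss_type == "wpo":
        return "wpo_loss"    # needs reference logprobs plus token-level loss masks
    raise ValueError(f"Unknown loss type: {loss_type}")
```

The ValueError branch mirrors the documented behavior for unrecognized loss types.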

The distinction between dpo and dpo_norm is handled upstream: the training loop passes average_log_prob=True for dpo_norm during the forward pass, so the policy and reference logprobs are already averaged when they reach compute_loss(). The loss function itself is identical for both.
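
The effect of average_log_prob can be illustrated with a toy sketch (not library code): dpo hands the summed token log-probabilities to the loss, while dpo_norm hands the length-normalized mean. The helper name and mask layout here are illustrative assumptions.

```python
import torch

# Illustrative sketch of how a forward pass could reduce token-level
# log-probabilities to a per-sequence value for dpo vs dpo_norm.
def sequence_logps(token_logps: torch.Tensor, loss_mask: torch.Tensor,
                   average_log_prob: bool) -> torch.Tensor:
    # token_logps: (batch, seq_len); loss_mask is 1 where labels != -100
    summed = (token_logps * loss_mask).sum(-1)
    if average_log_prob:            # dpo_norm: length-normalized
        return summed / loss_mask.sum(-1)
    return summed                   # dpo: raw sum

token_logps = torch.tensor([[-1.0, -2.0, -3.0, 0.0]])
mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])
print(sequence_logps(token_logps, mask, False))  # tensor([-6.])
print(sequence_logps(token_logps, mask, True))   # tensor([-2.])
```

Because this reduction happens before compute_loss() is called, the dispatcher can treat both variants identically.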

Usage

Import and call compute_loss() inside the DPO training loop to compute losses for any supported DPO variant. The calling code does not need to know which variant is active; the dispatch is handled internally based on the configuration.

Code Reference

Source Location

  • Repository: Open Instruct
  • File: open_instruct/dpo_utils.py (Lines 730-771)

Signature

def compute_loss(
    args: DPOConfig,
    batch: dict[str, torch.Tensor],
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    reference_cache: model_utils.TensorCache | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

Import

from open_instruct.dpo_utils import compute_loss

I/O Contract

Inputs

  • args (DPOConfig): DPO configuration containing loss_type, beta, gamma_beta_ratio, and label_smoothing.
  • batch (dict[str, torch.Tensor]): Training batch dictionary. Must contain:
      • "index": Sample indices into the full dataset (for cache lookup).
      • "chosen_labels": Token-level labels for chosen responses (for WPO loss masks).
      • "rejected_labels": Token-level labels for rejected responses (for WPO loss masks).
  • policy_chosen_logps (torch.Tensor): Policy model log-probabilities for chosen responses. Shape: (batch_size,).
  • policy_rejected_logps (torch.Tensor): Policy model log-probabilities for rejected responses. Shape: (batch_size,).
  • reference_cache (TensorCache or None): Precomputed reference logprobs cache. Required for dpo, dpo_norm, and wpo; may be None for simpo.

Outputs

  • losses (torch.Tensor): Per-example losses. Shape: (batch_size,).
  • chosen_rewards (torch.Tensor): Implicit rewards for chosen responses (detached). Shape: (batch_size,).
  • rejected_rewards (torch.Tensor): Implicit rewards for rejected responses (detached). Shape: (batch_size,).
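
For orientation, the standard sigmoid-form DPO loss produces exactly this output triple: per-example losses plus beta-scaled, detached implicit rewards. The sketch below uses a hypothetical name (dpo_loss_sketch); the library's dpo_loss() additionally supports label smoothing and may differ in detail.

```python
import torch
import torch.nn.functional as F

# Minimal sketch of the standard sigmoid-form DPO loss and its
# implicit rewards (no label smoothing); illustrative, not library code.
def dpo_loss_sketch(policy_chosen, policy_rejected,
                    ref_chosen, ref_rejected, beta=0.1):
    chosen_ratio = policy_chosen - ref_chosen        # log-ratio for chosen
    rejected_ratio = policy_rejected - ref_rejected  # log-ratio for rejected
    logits = chosen_ratio - rejected_ratio
    losses = -F.logsigmoid(beta * logits)            # per-example losses
    chosen_rewards = (beta * chosen_ratio).detach()
    rejected_rewards = (beta * rejected_ratio).detach()
    return losses, chosen_rewards, rejected_rewards
```

The detach() calls match the contract above: rewards are reported for logging only and carry no gradient.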

Dispatch Table

  • dpo: calls dpo_loss(); reference cache required; no extra inputs.
  • dpo_norm: calls dpo_loss(); reference cache required; no extra inputs (averaging handled in forward pass).
  • simpo: calls simpo_loss(); no reference cache; extra input: gamma_beta_ratio.
  • wpo: calls wpo_loss(); reference cache required; extra inputs: chosen_loss_mask, rejected_loss_mask.

Usage Examples

from open_instruct.dpo_utils import compute_loss, DPOConfig, DPOLossType
from open_instruct.model_utils import TensorCache

# Example: computing DPO loss inside the training loop
args = DPOConfig(loss_type=DPOLossType.dpo, beta=0.1, label_smoothing=0.0)

# Forward pass (done by the training loop)
policy_chosen_logps, policy_rejected_logps, _ = forward_fn(model, batch)

# Compute loss (dispatches to dpo_loss internally)
losses, chosen_rewards, rejected_rewards = compute_loss(
    args=args,
    batch=batch,
    policy_chosen_logps=policy_chosen_logps,
    policy_rejected_logps=policy_rejected_logps,
    reference_cache=reference_cache,
)

loss = losses.mean()
loss.backward()

# Example: SimPO (no reference cache needed)
args = DPOConfig(loss_type=DPOLossType.simpo, beta=2.0, gamma_beta_ratio=0.3)

losses, chosen_rewards, rejected_rewards = compute_loss(
    args=args,
    batch=batch,
    policy_chosen_logps=policy_chosen_logps,
    policy_rejected_logps=policy_rejected_logps,
    reference_cache=None,
)
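
As a complement to the examples above, the WPO branch needs token-level loss masks derived from the batch labels. The snippet below is a hedged sketch of that derivation, assuming the usual convention that -100 marks tokens excluded from the loss; it is illustrative, not the library's code.

```python
import torch

# Sketch: deriving WPO-style token-level loss masks from labels,
# where -100 marks ignored tokens (padding / prompt positions).
chosen_labels = torch.tensor([[-100, -100, 42, 7]])
rejected_labels = torch.tensor([[-100, 13, 13, -100]])

chosen_loss_mask = (chosen_labels != -100).float()
rejected_loss_mask = (rejected_labels != -100).float()
print(chosen_loss_mask)    # tensor([[0., 0., 1., 1.]])
print(rejected_loss_mask)  # tensor([[0., 1., 1., 0.]])
```

compute_loss() performs this extraction internally for the wpo loss type, so callers only supply "chosen_labels" and "rejected_labels" in the batch.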
