Implementation: AllenAI Open Instruct DPO compute_loss()
| Component Type | Function |
|---|---|
| Source | open_instruct/dpo_utils.py (Lines 730-771) |
| Repository | Open Instruct |
| Dependencies | torch, open_instruct.model_utils |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
Concrete tool for dispatching DPO loss computation to the appropriate loss variant based on the experiment configuration, provided by the Open Instruct library.
Description
compute_loss() is the central dispatcher that routes loss computation to the correct DPO variant function. Based on the loss_type field in the DPOConfig, it:
- For dpo and dpo_norm: retrieves cached reference log-probabilities from the reference_cache using the batch's sample indices, then calls dpo_loss() with the policy and reference logprobs, beta, and label smoothing.
- For simpo: calls simpo_loss() directly with only the policy logprobs, beta, and gamma/beta ratio. No reference cache is needed.
- For wpo: retrieves cached reference log-probabilities, extracts token-level loss masks from the batch labels (tokens where labels != -100), and calls wpo_loss() with all inputs, including the loss masks.

If an unknown loss type is provided, the function raises a ValueError.
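The dispatch described above can be sketched in a few lines. This is a simplified stand-in, not the actual Open Instruct implementation: the loss functions below operate on scalar log-probabilities for clarity, and dpo_norm shares dpo's path because the averaging happens upstream, as noted below.

```python
import math

def _logsigmoid(x: float) -> float:
    # Numerically stable log(sigmoid(x))
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))

def dpo_loss_sketch(pi_c, pi_r, ref_c, ref_r, beta):
    # Standard DPO objective on the policy/reference log-ratio margin
    return -_logsigmoid(beta * ((pi_c - ref_c) - (pi_r - ref_r)))

def simpo_loss_sketch(pi_c, pi_r, beta, gamma_beta_ratio):
    # SimPO: reference-free, with target margin gamma = beta * gamma_beta_ratio
    return -_logsigmoid(beta * (pi_c - pi_r - gamma_beta_ratio))

def compute_loss_sketch(loss_type, pi_c, pi_r, ref_c=None, ref_r=None,
                        beta=0.1, gamma_beta_ratio=0.0):
    # Route to the variant named by loss_type, mirroring the dispatch above
    if loss_type in ("dpo", "dpo_norm"):
        if ref_c is None or ref_r is None:
            raise ValueError("dpo/dpo_norm need cached reference logprobs")
        return dpo_loss_sketch(pi_c, pi_r, ref_c, ref_r, beta)
    if loss_type == "simpo":
        return simpo_loss_sketch(pi_c, pi_r, beta, gamma_beta_ratio)
    raise ValueError(f"Unknown loss_type: {loss_type}")
```

The real function additionally handles wpo (with loss masks) and returns per-example reward tensors alongside the losses.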
The distinction between dpo and dpo_norm is handled upstream: the training loop passes average_log_prob=True for dpo_norm during the forward pass, so the policy and reference logprobs are already averaged when they reach compute_loss(). The loss function itself is identical for both.
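For intuition, the effect of average_log_prob can be sketched with a hypothetical helper (plain lists instead of tensors; this is not the actual forward-pass code):

```python
def sequence_logprob(token_logps, mask, average_log_prob=False):
    # token_logps: per-token log-probabilities for one response
    # mask: 1 where labels != -100 (tokens that count), 0 elsewhere
    kept = [lp for lp, m in zip(token_logps, mask) if m]
    total = sum(kept)
    # dpo sums over tokens; dpo_norm averages over the unmasked length
    return total / len(kept) if average_log_prob else total
```

With averaging, longer responses no longer accumulate larger-magnitude logprobs, which is the length-normalization dpo_norm provides.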
Usage
Import and call compute_loss() inside the DPO training loop to compute losses for any supported DPO variant. The calling code does not need to know which variant is active -- the dispatch is handled internally based on the configuration.
Code Reference
Source Location
- Repository: Open Instruct
- File: open_instruct/dpo_utils.py (Lines 730-771)
Signature
```python
def compute_loss(
    args: DPOConfig,
    batch: dict[str, torch.Tensor],
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    reference_cache: model_utils.TensorCache | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
```
Import
```python
from open_instruct.dpo_utils import compute_loss
```
I/O Contract
Inputs
| Parameter | Type | Description |
|---|---|---|
| args | DPOConfig | DPO configuration containing loss_type, beta, gamma_beta_ratio, and label_smoothing. |
| batch | dict[str, torch.Tensor] | Training batch dictionary. Must contain the sample indices used for the reference-cache lookup and, for wpo, the labels used to build the loss masks. |
| policy_chosen_logps | torch.Tensor | Policy model log-probabilities for chosen responses. Shape: (batch_size,). |
| policy_rejected_logps | torch.Tensor | Policy model log-probabilities for rejected responses. Shape: (batch_size,). |
| reference_cache | TensorCache or None | Precomputed reference logprobs cache. Required for dpo, dpo_norm, and wpo. Can be None for simpo. |
Outputs
| Output | Type | Description |
|---|---|---|
| losses | torch.Tensor | Per-example losses. Shape: (batch_size,). |
| chosen_rewards | torch.Tensor | Implicit rewards for chosen responses (detached). Shape: (batch_size,). |
| rejected_rewards | torch.Tensor | Implicit rewards for rejected responses (detached). Shape: (batch_size,). |
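The "implicit rewards" above are the quantities a DPO-style trainer typically logs for monitoring. A sketch assuming the standard definition, beta times the policy/reference log-ratio (this formula is an assumption, not quoted from the source):

```python
def implicit_reward(policy_logp: float, reference_logp: float, beta: float) -> float:
    # Assumed standard DPO implicit reward: beta * log(pi_policy / pi_ref),
    # computed in log space (detached from the graph in a real trainer)
    return beta * (policy_logp - reference_logp)
```

A growing gap between chosen and rejected implicit rewards is the usual sign that preference training is making progress.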
Dispatch Table
| Loss Type | Function Called | Reference Cache Required | Extra Inputs |
|---|---|---|---|
| dpo | dpo_loss() | Yes | -- |
| dpo_norm | dpo_loss() | Yes | -- (averaging handled in forward pass) |
| simpo | simpo_loss() | No | gamma_beta_ratio |
| wpo | wpo_loss() | Yes | chosen_loss_mask, rejected_loss_mask |
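For the wpo row, the loss masks come from the batch labels as the Description notes (tokens where labels != -100). A minimal sketch of that mask extraction, using a plain list instead of a tensor:

```python
IGNORE_INDEX = -100  # label value marking tokens excluded from the loss

def loss_mask_from_labels(labels):
    # 1 for tokens that contribute to the loss, 0 for ignored positions
    # (prompt tokens and padding are typically labeled IGNORE_INDEX)
    return [0 if t == IGNORE_INDEX else 1 for t in labels]
```

In the tensor version this is simply an elementwise comparison against -100 applied to the chosen and rejected label tensors.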
Usage Examples
```python
from open_instruct.dpo_utils import compute_loss, DPOConfig, DPOLossType
from open_instruct.model_utils import TensorCache

# Example: computing DPO loss inside the training loop
args = DPOConfig(loss_type=DPOLossType.dpo, beta=0.1, label_smoothing=0.0)

# Forward pass (done by the training loop)
policy_chosen_logps, policy_rejected_logps, _ = forward_fn(model, batch)

# Compute loss (dispatches to dpo_loss internally)
losses, chosen_rewards, rejected_rewards = compute_loss(
    args=args,
    batch=batch,
    policy_chosen_logps=policy_chosen_logps,
    policy_rejected_logps=policy_rejected_logps,
    reference_cache=reference_cache,
)
loss = losses.mean()
loss.backward()

# Example: SimPO (no reference cache needed)
args = DPOConfig(loss_type=DPOLossType.simpo, beta=2.0, gamma_beta_ratio=0.3)
losses, chosen_rewards, rejected_rewards = compute_loss(
    args=args,
    batch=batch,
    policy_chosen_logps=policy_chosen_logps,
    policy_rejected_logps=policy_rejected_logps,
    reference_cache=None,
)
```