Principle: Allenai Open-Instruct DPO Loss Dispatch
| Knowledge Sources | |
|---|---|
| Domains | Software Engineering, Machine Learning, Design Patterns |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
DPO loss dispatch is the pattern of routing a training step's loss computation to the appropriate DPO loss variant based on a runtime configuration parameter, enabling a single training loop to support multiple preference optimization objectives.
Description
The DPO family of loss functions shares a common interface -- all variants accept policy log-probabilities (and optionally reference log-probabilities) and return a tuple of per-example losses, chosen rewards, and rejected rewards. However, each variant computes the loss differently:
Standard DPO (dpo):
Uses summed log-probabilities from both the policy and reference models. The loss is the negative log-sigmoid of the scaled difference between policy and reference log-ratios. Requires a reference model (or cached reference logprobs).
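As a minimal scalar sketch of this computation (the real implementation operates on batched tensors; the function and argument names here are illustrative, not the library's API):

```python
import math

def log_sigmoid(x: float) -> float:
    # log(sigmoid(x)), computed stably for large |x|.
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))

def dpo_loss(policy_chosen_logp: float, policy_rejected_logp: float,
             ref_chosen_logp: float, ref_rejected_logp: float,
             beta: float = 0.1):
    # Log-ratio of policy vs. reference for each response (summed token logprobs).
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    # Negative log-sigmoid of the scaled difference of log-ratios.
    loss = -log_sigmoid(beta * (chosen_logratio - rejected_logratio))
    chosen_reward = beta * chosen_logratio
    rejected_reward = beta * rejected_logratio
    return loss, chosen_reward, rejected_reward
```

Note that the returned rewards are the scaled log-ratios themselves; they are used only for logging metrics such as reward accuracy and margin.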
Normalized DPO (dpo_norm):
Identical to standard DPO but uses averaged (per-token) log-probabilities instead of summed. This normalizes for response length, preventing the model from favoring shorter responses. Also requires a reference model.
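The only change relative to standard DPO is the length normalization of the inputs, which might be sketched as (a hypothetical helper, assuming a per-token mask marks which positions count toward the response):

```python
def average_logps(token_logps, mask):
    # Length-normalized logprob: sum of unmasked token logprobs / token count.
    total = sum(lp for lp, m in zip(token_logps, mask) if m)
    count = sum(mask)
    return total / count
```

These averaged values are then fed into the same DPO loss in place of the summed log-probabilities.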
SimPO (simpo):
Eliminates the reference model entirely. Uses the policy's averaged log-probabilities with a margin term (gamma_beta_ratio) as the target gap between chosen and rejected. The loss compares (policy_chosen_avg - policy_rejected_avg - gamma/beta).
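A scalar sketch of this reference-free objective (illustrative names; real code is batched):

```python
import math

def log_sigmoid(x: float) -> float:
    # log(sigmoid(x)), computed stably for large |x|.
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))

def simpo_loss(policy_chosen_avg_logp: float, policy_rejected_avg_logp: float,
               beta: float = 2.0, gamma_beta_ratio: float = 0.5):
    # Reference-free: the averaged policy logprobs must differ by at least
    # the target margin gamma/beta before the loss approaches zero.
    logits = policy_chosen_avg_logp - policy_rejected_avg_logp - gamma_beta_ratio
    loss = -log_sigmoid(beta * logits)
    # "Rewards" here are just the scaled averaged policy logprobs.
    return loss, beta * policy_chosen_avg_logp, beta * policy_rejected_avg_logp
```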
WPO (wpo):
Extends standard DPO by weighting the loss with a confidence factor derived from the policy model's average log-probabilities on the chosen and rejected responses. The weight is the exponential of the summed averages, clamped to [0, 1], which concentrates the training signal on preference pairs the policy itself is likely to generate, approximating on-policy training with off-policy data. Requires a reference model and access to token-level loss masks (to compute the averaged log-probabilities).
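A scalar sketch of the weighting, assuming the averaged policy log-probabilities are precomputed (the real implementation derives them from token-level loss masks, and all names here are illustrative):

```python
import math

def log_sigmoid(x: float) -> float:
    # log(sigmoid(x)), computed stably for large |x|.
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))

def wpo_loss(policy_chosen_logp: float, policy_rejected_logp: float,
             ref_chosen_logp: float, ref_rejected_logp: float,
             policy_chosen_avg_logp: float, policy_rejected_avg_logp: float,
             beta: float = 0.1):
    # Standard DPO term on summed logprobs.
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    dpo_term = -log_sigmoid(beta * (chosen_logratio - rejected_logratio))
    # Confidence weight: exp of the summed averaged policy logprobs,
    # clamped to [0, 1]; pairs the policy finds unlikely get down-weighted.
    weight = min(1.0, math.exp(policy_chosen_avg_logp + policy_rejected_avg_logp))
    loss = weight * dpo_term
    return loss, beta * chosen_logratio, beta * rejected_logratio
```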
The dispatch pattern centralizes the routing logic in a single function, so the training loop does not need to contain variant-specific branching. This follows the Strategy Pattern from software design, where the loss function is selected at configuration time and applied uniformly during training.
Usage
Use the loss dispatch pattern when:
- You want a single training script that supports multiple DPO loss variants.
- You want to switch between loss functions purely through configuration, without code changes.
- You are adding new loss variants and want to maintain a consistent interface.
Theoretical Basis
All DPO variants share the Bradley-Terry preference model, which gives the probability of preferring the chosen response y_w over the rejected response y_l:

p(y_w ≻ y_l | x) = σ(r(x, y_w) - r(x, y_l))

They differ in how they define the implicit reward r(x, y):
| Variant | Implicit Reward r(x, y) | Reference Model Required |
|---|---|---|
| DPO | β log(π_θ(y\|x) / π_ref(y\|x)), summed token logprobs | Yes |
| DPO-Norm | β log(π_θ(y\|x) / π_ref(y\|x)), averaged (per-token) logprobs | Yes |
| SimPO | (β/\|y\|) log π_θ(y\|x), no reference term | No |
| WPO | Same as DPO, with the loss scaled by a policy-confidence weight | Yes |
The dispatch function selects the appropriate reward definition based on the configured loss_type and passes the correct arguments to the corresponding loss function:
def compute_loss(config, batch, policy_chosen_logps, policy_rejected_logps, reference_cache):
match config.loss_type:
case "dpo" | "dpo_norm":
ref_logps = reference_cache[batch.indices]
return dpo_loss(policy_chosen_logps, policy_rejected_logps,
ref_logps.chosen, ref_logps.rejected, beta=config.beta)
case "simpo":
return simpo_loss(policy_chosen_logps, policy_rejected_logps,
beta=config.beta, gamma_beta_ratio=config.gamma_beta_ratio)
case "wpo":
ref_logps = reference_cache[batch.indices]
return wpo_loss(policy_chosen_logps, policy_rejected_logps,
ref_logps.chosen, ref_logps.rejected, beta=config.beta,
chosen_mask=batch.chosen_labels != -100,
rejected_mask=batch.rejected_labels != -100)
The DPOLossType enum also exposes computed properties:
- is_average_loss: Whether the variant uses averaged (per-token) log-probabilities. True for simpo and dpo_norm.
- needs_reference_model: Whether the variant requires reference logprobs. True for dpo, dpo_norm, and wpo.
- computes_reward_metrics: Whether reward accuracy/margin metrics are meaningful. True for dpo and dpo_norm.
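Such an enum could be sketched as follows (member and property names follow the description above; the string-valued base class is an assumption so that config strings compare directly):

```python
from enum import Enum

class DPOLossType(str, Enum):
    dpo = "dpo"
    dpo_norm = "dpo_norm"
    simpo = "simpo"
    wpo = "wpo"

    @property
    def is_average_loss(self) -> bool:
        # Variants that use averaged (per-token) log-probabilities.
        return self in (DPOLossType.simpo, DPOLossType.dpo_norm)

    @property
    def needs_reference_model(self) -> bool:
        # Variants that require reference logprobs (cached or live).
        return self in (DPOLossType.dpo, DPOLossType.dpo_norm, DPOLossType.wpo)

    @property
    def computes_reward_metrics(self) -> bool:
        # Variants for which reward accuracy/margin metrics are meaningful.
        return self in (DPOLossType.dpo, DPOLossType.dpo_norm)
```

Because the members are strings, a raw config value like "simpo" can be converted with DPOLossType("simpo") and then queried for its properties, keeping variant-specific knowledge out of the training loop.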