Implementation:Lucidrains X transformers DPO Forward
Implementation: DPO_Forward
Metadata
| Field | Value |
|---|---|
| Page Type | Implementation (API Doc) |
| Knowledge Sources | Repo (x-transformers) |
| Domains | NLP, Alignment, RLHF |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
A concrete tool, provided by the x-transformers library, for computing the DPO preference-alignment loss from preference pairs.
Description
DPO.forward() computes the DPO loss from preference pairs. The method performs the following steps:
- Validation: asserts that `preferred_seq` is 2-dimensional and that both sequences have the same shape.
- Automatic padding masks: if `pad_id` was set during initialization and sequence masks are not provided, creates boolean masks where tokens not equal to `pad_id` are `True`.
- Reference model log probabilities: with `torch.no_grad()`, sets the reference model to eval mode and computes per-token log probabilities for both the preferred and unpreferred sequences using the frozen reference model.
- Policy model log probabilities: computes per-token log probabilities for both sequences using the trainable policy model (gradients are tracked).
- Mask construction: combines the prompt mask (inverted, since `True` means a prompt token to exclude) with any padding masks using logical AND.
- Masked mean log probabilities: computes the mean log probability over non-masked tokens for each sequence under each model. The `masked_mean()` helper handles the case where the mask has one more token than the log probs (due to the input/target offset).
- DPO loss computation: computes `policy_logratios = policy_preferred - policy_unpreferred`, `ref_logratios = ref_preferred - ref_unpreferred`, and returns `mean(-log_sigmoid(β · (policy_logratios - ref_logratios)))`, as sketched below.
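The last two steps can be pictured with a short, self-contained sketch. The helper below is illustrative only (it mirrors the described behavior of `masked_mean()` but is not the library's implementation), and the β value is an arbitrary assumption:

```python
import torch
import torch.nn.functional as F

def masked_mean(log_probs, mask):
    # Average per-token log probs over unmasked positions.
    # If the mask carries one extra position (input/target offset), drop it.
    if mask.shape[-1] == log_probs.shape[-1] + 1:
        mask = mask[:, 1:]
    mask = mask.float()
    return (log_probs * mask).sum(dim = -1) / mask.sum(dim = -1).clamp(min = 1)

def dpo_loss(policy_pref, policy_unpref, ref_pref, ref_unpref, beta = 0.1):
    # Each argument is a (batch,) tensor of masked-mean sequence log probabilities
    policy_logratios = policy_pref - policy_unpref
    ref_logratios = ref_pref - ref_unpref
    return -F.logsigmoid(beta * (policy_logratios - ref_logratios)).mean()
```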
Code Reference
Source Location
x-transformers repo, file: `x_transformers/dpo.py`, lines 71-117.
Signature
```python
def forward(
    self,
    preferred_seq,
    unpreferred_seq,
    *,
    prompt_mask,
    preferred_seq_mask = None,
    unpreferred_seq_mask = None,
) -> Tensor:
```
Import
```python
from x_transformers.dpo import DPO
```
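For orientation, a minimal setup sketch. The `TransformerWrapper`/`Decoder` usage follows the standard x-transformers pattern, while the exact `DPO` constructor arguments shown here (`beta`, `pad_id`) are assumptions inferred from the description above and may differ from the current API:

```python
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.dpo import DPO

# Policy model to be aligned; the frozen reference model described above
# is handled by the DPO wrapper
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 512,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

# Assumed constructor arguments: beta scales the log-ratio margin,
# pad_id enables the automatic padding masks described below
dpo = DPO(model, beta = 0.1, pad_id = 0)
```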
I/O Contract
Inputs
| Parameter | Type | Required | Default | Description |
|---|---|---|---|---|
| `preferred_seq` | Tensor `(B, seq_len)` | Yes | -- | Batch of preferred (winner) token sequences. Each sequence contains the prompt followed by the preferred completion. |
| `unpreferred_seq` | Tensor `(B, seq_len)` | Yes | -- | Batch of unpreferred (loser) token sequences. Must have the same shape as `preferred_seq`. |
| `prompt_mask` | Tensor `(B, seq_len)` | Yes | -- | Boolean mask where `True` marks prompt tokens, which are excluded from the loss computation. |
| `preferred_seq_mask` | Tensor `(B, seq_len)` or `None` | No | `None` | Optional padding mask for preferred sequences; `True` marks valid (non-padding) tokens. Auto-created from `pad_id` if not provided. |
| `unpreferred_seq_mask` | Tensor `(B, seq_len)` or `None` | No | `None` | Optional padding mask for unpreferred sequences; `True` marks valid (non-padding) tokens. Auto-created from `pad_id` if not provided. |
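As an illustration, one way to build `prompt_mask` when prompt lengths differ per example (plain PyTorch, not a library helper):

```python
import torch

seq_len = 512
prompt_lens = torch.tensor([64, 32, 80, 64, 48, 64, 72, 40])  # prompt length per row (batch of 8)

# True where the position is part of the prompt (excluded from the loss)
positions = torch.arange(seq_len).unsqueeze(0)       # (1, seq_len)
prompt_mask = positions < prompt_lens.unsqueeze(1)   # (batch, seq_len), bool
```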
Outputs
| Output | Type | Description |
|---|---|---|
| loss | Tensor (scalar) | The mean DPO loss over the batch: `mean(-log_sigmoid(β · (policy_logratios - ref_logratios)))`. This scalar can be backpropagated directly to update the policy model. |
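As a quick sanity check on the formula: before the policy drifts from the reference model, the two log-ratios cancel and the loss starts at `-log_sigmoid(0) = log 2 ≈ 0.693`. A small numerical sketch (the β value and log-ratio values are arbitrary):

```python
import torch
import torch.nn.functional as F

beta = 0.1
policy_logratios = torch.tensor([0.4, -0.2, 0.1])
ref_logratios = policy_logratios.clone()   # policy still equals the reference

loss = -F.logsigmoid(beta * (policy_logratios - ref_logratios)).mean()
print(loss.item())  # ~0.6931, i.e. log 2
```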
Usage Examples
Basic DPO Training Step
```python
import torch

# Preference training step; assumes `dpo` is an initialized DPO instance
# (see the setup sketch above) and `optimizer` updates the policy model's parameters
preferred = torch.randint(0, 256, (8, 512)).cuda()
unpreferred = torch.randint(0, 256, (8, 512)).cuda()

prompt_mask = torch.zeros(8, 512).bool().cuda()
prompt_mask[:, :64] = True  # first 64 tokens are the prompt

loss = dpo(preferred, unpreferred, prompt_mask = prompt_mask)
loss.backward()

optimizer.step()
optimizer.zero_grad()
```
With Explicit Padding Masks
```python
# When sequences have variable-length completions padded to a max length,
# pass the padding masks explicitly (here 0 is the pad token)
preferred_mask = preferred != 0
unpreferred_mask = unpreferred != 0

loss = dpo(
    preferred,
    unpreferred,
    prompt_mask = prompt_mask,
    preferred_seq_mask = preferred_mask,
    unpreferred_seq_mask = unpreferred_mask,
)
```
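If the raw preference pairs have different lengths to begin with, they first need to be padded to a shared length before the masks above can be built; a small illustrative sketch with pad token 0 (not part of the library API):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Variable-length winner/loser sequences (token id 0 reserved for padding)
preferred_list   = [torch.randint(1, 256, (n,)) for n in (300, 512, 420)]
unpreferred_list = [torch.randint(1, 256, (n,)) for n in (280, 512, 390)]

# pad_sequence pads each batch to its own longest sequence; both batches here
# share a max length of 512, satisfying the same-shape requirement above
preferred   = pad_sequence(preferred_list,   batch_first = True, padding_value = 0)
unpreferred = pad_sequence(unpreferred_list, batch_first = True, padding_value = 0)
```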
Related Pages
Implements Principle
Requires Environment