
Implementation:Lucidrains X transformers DPO Forward



Implementation: DPO_Forward

Metadata

Page Type: Implementation (API Doc)
Knowledge Sources: Repo (x-transformers)
Domains: NLP, Alignment, RLHF
Last Updated: 2026-02-08 18:00 GMT

Overview

A concrete tool from the x-transformers library for computing the Direct Preference Optimization (DPO) preference-alignment loss.

Description

DPO.forward() computes the DPO loss from preference pairs. The method performs the following steps:

  1. Validation: Asserts that preferred_seq is 2-dimensional and that both sequences have the same shape.
  2. Automatic padding masks: If pad_id was set during initialization and sequence masks are not provided, creates boolean masks where tokens not equal to pad_id are True.
  3. Reference model log probabilities: With torch.no_grad(), sets the reference model to eval mode and computes per-token log probabilities for both the preferred and unpreferred sequences using the frozen reference model.
  4. Policy model log probabilities: Computes per-token log probabilities for both sequences using the trainable policy model (gradients are tracked).
  5. Mask construction: Combines the prompt mask (inverted, since True means prompt token to exclude) with any padding masks using logical AND.
  6. Masked mean log probabilities: Computes the mean log probability over non-masked tokens for each sequence under each model. The masked_mean() helper handles the case where the mask has one more token than the log probs (due to the input/target offset).
  7. DPO loss computation: Computes policy_logratios = policy_preferred - policy_unpreferred and ref_logratios = ref_preferred - ref_unpreferred, then returns mean(-log_sigmoid(β · (policy_logratios - ref_logratios))). Steps 5-7 are sketched in code after this list.
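
To make steps 5-7 concrete, here is a minimal PyTorch sketch. This is not the library source: masked_mean is re-implemented approximately from the description above, dpo_loss is a hypothetical standalone function, and beta stands for the temperature β set at initialization. The per-token log probabilities produced by the policy and reference models in steps 3-4 are taken here as inputs.

import torch
import torch.nn.functional as F

def masked_mean(log_probs, mask = None):
    # Average per-token log probs over unmasked positions.
    if mask is None:
        return log_probs.mean(dim = -1)
    if mask.shape[-1] == log_probs.shape[-1] + 1:
        mask = mask[:, :-1]  # handle the one-token input/target offset (step 6)
    log_probs = log_probs.masked_fill(~mask, 0.)
    return log_probs.sum(dim = -1) / mask.sum(dim = -1).clamp(min = 1)

def dpo_loss(policy_pref, policy_unpref, ref_pref, ref_unpref,
             prompt_mask, pref_mask = None, unpref_mask = None, beta = 0.1):
    # Step 5: keep completion tokens only - invert the prompt mask, AND with padding masks.
    keep_pref = ~prompt_mask if pref_mask is None else ~prompt_mask & pref_mask
    keep_unpref = ~prompt_mask if unpref_mask is None else ~prompt_mask & unpref_mask

    # Step 6: masked mean log prob per sequence under each model.
    policy_logratios = masked_mean(policy_pref, keep_pref) - masked_mean(policy_unpref, keep_unpref)
    ref_logratios = masked_mean(ref_pref, keep_pref) - masked_mean(ref_unpref, keep_unpref)

    # Step 7: mean(-log_sigmoid(beta * (policy_logratios - ref_logratios)))
    return -F.logsigmoid(beta * (policy_logratios - ref_logratios)).mean()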

Code Reference

Source Location

x-transformers repo, file x_transformers/dpo.py, lines 71-117.

Signature

def forward(
    self,
    preferred_seq,
    unpreferred_seq,
    *,
    prompt_mask,
    preferred_seq_mask = None,
    unpreferred_seq_mask = None,
) -> Tensor:

Import

from x_transformers.dpo import DPO
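
The usage examples below assume an already-constructed DPO wrapper named dpo. The exact constructor signature is not documented on this page; the following sketch assumes DPO wraps a trained policy model and accepts beta and pad_id keyword arguments (pad_id at initialization is referenced in the Description, while beta's placement is an assumption).

import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.dpo import DPO

# Hypothetical policy model; the reference model is the frozen copy kept inside the wrapper.
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 512,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

dpo = DPO(
    model,
    beta = 0.1,   # DPO temperature used in the loss (assumed init argument)
    pad_id = 0    # enables automatic padding masks (step 2 in the Description)
)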

I/O Contract

Inputs

preferred_seq (Tensor, shape (B, seq_len); required): Batch of preferred (winner) token sequences. Each sequence contains the prompt followed by the preferred completion.
unpreferred_seq (Tensor, shape (B, seq_len); required): Batch of unpreferred (loser) token sequences. Must have the same shape as preferred_seq.
prompt_mask (Tensor, shape (B, seq_len); required): Boolean mask where True marks prompt tokens, which are excluded from the loss computation (see the mask sketch after this table).
preferred_seq_mask (Tensor, shape (B, seq_len), or None; default None): Optional padding mask for the preferred sequences. True marks valid (non-padding) tokens. Auto-created from pad_id if not provided.
unpreferred_seq_mask (Tensor, shape (B, seq_len), or None; default None): Optional padding mask for the unpreferred sequences. True marks valid (non-padding) tokens. Auto-created from pad_id if not provided.
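
The two mask arguments use opposite conventions, which the short sketch below illustrates (local tensors only, not library code): prompt_mask marks tokens to exclude, while the sequence masks mark tokens to keep.

import torch

B, seq_len, prompt_len, pad_id = 2, 10, 4, 0

seq = torch.randint(1, 256, (B, seq_len))
seq[0, 8:] = pad_id                                  # example: shorter completion in row 0

prompt_mask = torch.zeros(B, seq_len, dtype = torch.bool)
prompt_mask[:, :prompt_len] = True                   # True = prompt token -> excluded from loss

seq_mask = seq != pad_id                             # True = valid (non-pad) token -> kept

# Tokens that actually contribute to the loss are the completion tokens that are not padding.
contributes = ~prompt_mask & seq_mask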

Outputs

loss (Tensor, scalar): The mean DPO loss over the batch, mean(-log_sigmoid(β · (policy_logratios - ref_logratios))). This scalar can be backpropagated directly to update the policy model. A toy computation follows below.
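
As a toy check of the formula (illustration only, not library code): when the policy's preferred/unpreferred gap matches the reference's, the loss equals log(2) ≈ 0.693; when the policy widens the gap, the loss drops below that.

import torch
import torch.nn.functional as F

beta = 0.1
policy_logratios = torch.tensor([2.0, 1.5])  # policy: preferred minus unpreferred mean log prob
ref_logratios    = torch.tensor([0.5, 0.5])  # frozen reference: same quantity

loss = -F.logsigmoid(beta * (policy_logratios - ref_logratios)).mean()
print(loss)  # ≈ 0.63, below log(2) because the policy improved on the reference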

Usage Examples

Basic DPO Training Step

import torch

# Assumes `dpo` is a constructed DPO wrapper (see the sketch under Import) and
# `optimizer` is an optimizer over the policy model's parameters.

# Preference training step
preferred = torch.randint(0, 256, (8, 512)).cuda()
unpreferred = torch.randint(0, 256, (8, 512)).cuda()
prompt_mask = torch.zeros(8, 512).bool().cuda()
prompt_mask[:, :64] = True  # First 64 tokens are prompt (excluded from the loss)

loss = dpo(preferred, unpreferred, prompt_mask=prompt_mask)
loss.backward()
optimizer.step()
optimizer.zero_grad()

With Explicit Padding Masks

# When sequences have variable-length completions padded to max length
preferred_mask = preferred != 0  # 0 is pad token
unpreferred_mask = unpreferred != 0

loss = dpo(
    preferred,
    unpreferred,
    prompt_mask = prompt_mask,
    preferred_seq_mask = preferred_mask,
    unpreferred_seq_mask = unpreferred_mask,
)

Related Pages

Implements Principle

Requires Environment
