
Implementation:Lucidrains X transformers Preference Dataset Pattern

From Leeroopedia


Repo: x-transformers
Domains: Data_Engineering, NLP, Alignment
Last Updated: 2026-02-08 18:00 GMT

Overview

Pattern specification for creating preference pair datasets for DPO alignment training with x-transformers.

Description

This is a Pattern Doc: the interface described here is derived from the parameters of DPO.forward() in dpo.py (L71-117). The forward method signature is:

def forward(
    self,
    preferred_seq,
    unpreferred_seq,
    *,
    prompt_mask,
    preferred_seq_mask=None,
    unpreferred_seq_mask=None,
):
    ...  # body omitted here; see x_transformers/dpo.py

The method asserts preferred_seq.ndim == 2 and preferred_seq.shape == unpreferred_seq.shape, so both sequences must be 2D tensors of identical shape (batch, seq_len).
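Before converting prepared data to tensors, it can help to check it against those two assertions. The sketch below mirrors them on nested Python lists; check_dpo_inputs is a hypothetical helper, not part of x-transformers, and its prompt_mask shape check is an extra safeguard rather than one of the assertions quoted above.

```python
def check_dpo_inputs(preferred, unpreferred, prompt_mask):
    """Mirror the DPO.forward() shape checks on nested lists (hypothetical helper).

    Each argument is a list of B rows, each row a list of seq_len items.
    Returns (B, seq_len) when the inputs are consistent, else raises ValueError.
    """
    def shape_2d(rows, name):
        # Require a non-empty, rectangular 2D structure (the ndim == 2 check)
        if not rows or not all(isinstance(r, list) for r in rows):
            raise ValueError(f"{name} must be a non-empty list of rows")
        lengths = {len(r) for r in rows}
        if len(lengths) != 1:
            raise ValueError(f"{name} rows have unequal lengths {sorted(lengths)}")
        return (len(rows), lengths.pop())

    p_shape = shape_2d(preferred, "preferred")
    u_shape = shape_2d(unpreferred, "unpreferred")
    m_shape = shape_2d(prompt_mask, "prompt_mask")
    # Mirrors: preferred_seq.shape == unpreferred_seq.shape
    if p_shape != u_shape:
        raise ValueError(f"shape mismatch: {p_shape} vs {u_shape}")
    # Extra safeguard (not asserted by DPO.forward() as quoted): mask matches too
    if m_shape != p_shape:
        raise ValueError(f"prompt_mask shape {m_shape} != {p_shape}")
    return p_shape
```

For example, two pairs of length 3 with a 2-token prompt pass and return (2, 3), while a length-3 preferred row paired with a length-2 unpreferred row raises ValueError.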

Code Reference

File: x_transformers/dpo.py, Lines: L71-117 (DPO.forward())

Interface Specification

PreferenceDataset (Derived Interface)

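import torch
from torch.utils.data import Dataset

class PreferenceDataset(Dataset):
    """Dataset for DPO training.

    Returns (preferred_seq, unpreferred_seq, prompt_mask) tuples.
    All three tensors have the same shape (seq_len,).
    prompt_mask: True where tokens are prompt (excluded from loss).
    """
    def __init__(self, data):
        # data: list of dicts with 'preferred' and 'unpreferred' (LongTensors of
        # shape (seq_len,), both starting with the same prompt) and 'prompt_len' (int)
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]
        preferred = item['preferred']      # (seq_len,) LongTensor
        unpreferred = item['unpreferred']  # (seq_len,) LongTensor
        prompt_len = item['prompt_len']
        # True for the first prompt_len positions, False for the completion
        prompt_mask = torch.arange(len(preferred)) < prompt_len
        return preferred, unpreferred, prompt_mask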

Key details:

  • Both preferred and unpreferred must be the same length (the DPO.forward() method asserts shape equality).
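  • The prompt_mask is True at prompt positions and False at completion positions; only completion positions contribute to the DPO loss. Because a single prompt_mask is applied to both sequences, the prompt must occupy the same leading positions (the first prompt_len tokens) in both preferred and unpreferred.
  • If the two completions differ in length, pad the shorter sequence so both have the same length. Padding can be marked either by passing the optional preferred_seq_mask and unpreferred_seq_mask arguments to DPO.forward() (True for real tokens, False for padding) or by setting the pad_id constructor argument, in which case DPO derives each missing mask as seq != pad_id.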
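The rules above can be sketched as a record builder that works on plain token-id lists. build_pair is a hypothetical helper, and right-padding is an assumption; it produces the 'preferred', 'unpreferred' and 'prompt_len' keys that PreferenceDataset reads, plus explicit masks suitable for preferred_seq_mask and unpreferred_seq_mask.

```python
def build_pair(prompt, chosen, rejected, pad_id):
    """Build one DPO record from token-id lists (hypothetical helper).

    preferred = prompt + chosen and unpreferred = prompt + rejected; the
    shorter sequence is right-padded with pad_id so both lengths match.
    """
    preferred = list(prompt) + list(chosen)
    unpreferred = list(prompt) + list(rejected)
    seq_len = max(len(preferred), len(unpreferred))
    # Explicit padding masks: True for real tokens, False for padding
    preferred_mask = [True] * len(preferred) + [False] * (seq_len - len(preferred))
    unpreferred_mask = [True] * len(unpreferred) + [False] * (seq_len - len(unpreferred))
    preferred += [pad_id] * (seq_len - len(preferred))
    unpreferred += [pad_id] * (seq_len - len(unpreferred))
    # One prompt mask serves both sequences because they share the same prompt
    prompt_mask = [i < len(prompt) for i in range(seq_len)]
    return {
        "preferred": preferred,
        "unpreferred": unpreferred,
        "prompt_len": len(prompt),
        "prompt_mask": prompt_mask,
        "preferred_seq_mask": preferred_mask,
        "unpreferred_seq_mask": unpreferred_mask,
    }
```

For example, build_pair([5, 6], [7, 8, 9], [10], pad_id=0) yields preferred [5, 6, 7, 8, 9], unpreferred [5, 6, 10, 0, 0] and prompt_mask [True, True, False, False, False]. Building masks from lengths rather than from seq != pad_id stays correct even if pad_id also occurs as a real token. Before use, the lists would be wrapped with torch.tensor(..., dtype=torch.long).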

Usage with DPO

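import torch
from torch.utils.data import DataLoader
from x_transformers import TransformerWrapper, Decoder
from x_transformers.dpo import DPO

NUM_TOKENS = 256   # vocabulary size (placeholder value)
SEQ_LEN = 1024     # maximum sequence length (placeholder value)

# Create base model
model = TransformerWrapper(
    num_tokens=NUM_TOKENS,
    max_seq_len=SEQ_LEN,
    attn_layers=Decoder(dim=512, depth=6, heads=8)
)

# Wrap with DPO (creates policy + frozen reference copy)
dpo = DPO(model, beta=0.1)

# dpo.parameters() yields only the policy model's parameters
optimizer = torch.optim.Adam(dpo.parameters(), lr=1e-5)

# preference_data: your list of records for PreferenceDataset (placeholder, not defined here).
# The default collate stacks items, so every item in a batch needs the same seq_len.
dataloader = DataLoader(PreferenceDataset(preference_data), batch_size=4, shuffle=True)

# Training loop
for preferred, unpreferred, prompt_mask in dataloader:
    loss = dpo(preferred, unpreferred, prompt_mask=prompt_mask)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()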

Key details about the DPO wrapper:

  • DPO.__init__ creates a deep copy of the model as the frozen reference model.
  • The parameters() method returns only the policy model parameters (not the reference).
  • The beta parameter (default: 0.1) sets the strength of the implicit KL-divergence constraint that keeps the policy close to the reference model.
  • An optional pad_id can be specified to automatically generate sequence masks from padding tokens.
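To make the role of beta concrete, the sketch below evaluates the standard per-pair DPO objective from the DPO paper (Rafailov et al., 2023) on scalar log-probabilities. It assumes DPO.forward() follows this formula; how dpo.py aggregates per-token log-probabilities over the completion positions is not reproduced here, and dpo_loss and neg_log_sigmoid are hypothetical names.

```python
import math

def neg_log_sigmoid(x):
    # -log(sigmoid(x)) computed stably (equivalently softplus(-x))
    if x >= 0:
        return math.log1p(math.exp(-x))
    return -x + math.log1p(math.exp(x))

def dpo_loss(policy_pref, policy_unpref, ref_pref, ref_unpref, beta=0.1):
    """Per-pair DPO loss from aggregated sequence log-probabilities.

    Inputs are assumed to be log-probabilities already aggregated over the
    completion positions (prompt_mask False, padding excluded).
    """
    policy_logratio = policy_pref - policy_unpref
    ref_logratio = ref_pref - ref_unpref
    return neg_log_sigmoid(beta * (policy_logratio - ref_logratio))
```

When the policy's log-ratio equals the reference's, the loss is log 2 (about 0.693) regardless of beta. For the same deviation toward the preferred completion, a larger beta drives the loss down faster, so the objective saturates after a smaller departure from the reference; this is the sense in which beta tightens the implicit KL constraint.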

Input / Output

Shapes are per batch, after the DataLoader stacks B dataset items; each individual item has shape (seq_len,).

Direction | Name | Type | Shape | Description
Output | preferred | LongTensor | (B, seq_len) | Preferred sequences (prompt + preferred completion)
Output | unpreferred | LongTensor | (B, seq_len) | Unpreferred sequences (prompt + unpreferred completion)
Output | prompt_mask | BoolTensor | (B, seq_len) | True = prompt position (excluded from the DPO loss)
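The (B, seq_len) shapes come from stacking items, and PyTorch's default collation stacks tensors, so items in one batch must already share seq_len. When they do not, a custom collate_fn can pad them to the batch maximum. The dependency-free sketch below shows that padding step on raw records; collate_records is a hypothetical helper, and right-padding is an assumption.

```python
def collate_records(records, pad_id):
    """Pad per-item records into (B, seq_len) nested lists (hypothetical helper).

    Each record has 'preferred' and 'unpreferred' (equal-length token lists)
    and 'prompt_len', as in the PreferenceDataset records.
    """
    for r in records:
        if len(r["preferred"]) != len(r["unpreferred"]):
            raise ValueError("each pair must already have equal lengths")
    seq_len = max(len(r["preferred"]) for r in records)
    preferred, unpreferred, prompt_mask = [], [], []
    for r in records:
        pad = [pad_id] * (seq_len - len(r["preferred"]))
        preferred.append(r["preferred"] + pad)
        unpreferred.append(r["unpreferred"] + pad)
        # Padding positions get prompt_mask False, so they must also be excluded
        # via sequence masks (or DPO's pad_id) to stay out of the loss
        prompt_mask.append([i < r["prompt_len"] for i in range(seq_len)])
    return preferred, unpreferred, prompt_mask
```

In an actual collate_fn the three results would be converted with torch.tensor (dtype=torch.long for the sequences) and the function passed as DataLoader(..., collate_fn=...). Constructing DPO with pad_id then masks the padded tail, provided pad_id never appears as a real token.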
