Implementation:Lucidrains X transformers Preference Dataset Pattern
| Field | Value |
|---|---|
| Repo | x-transformers |
| Domains | Data_Engineering, NLP, Alignment |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
Pattern specification for creating preference pair datasets for DPO alignment training with x-transformers.
Description
This is a Pattern Doc. The interface is derived from DPO.forward() parameters at dpo.py:L71-117. The forward method signature is:
```python
def forward(
    self,
    preferred_seq,
    unpreferred_seq,
    *,
    prompt_mask,
    preferred_seq_mask=None,
    unpreferred_seq_mask=None,
):
    ...
```
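The method asserts `preferred_seq.ndim == 2` and `preferred_seq.shape == unpreferred_seq.shape`, so both sequences must be 2D `(batch, seq_len)` tensors of identical shape. `prompt_mask` is keyword-only (it follows the `*` in the signature) and has no default, so it must always be passed by name.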
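Before feeding batches into `DPO.forward()`, it can help to check them against this contract. The helper `check_dpo_batch` below is hypothetical (it is not part of x-transformers). Its first two asserts mirror the ones in `DPO.forward()`; the last two are extra checks this sketch assumes are sensible, since a single boolean `prompt_mask` is passed for both sequences.

```python
import torch

def check_dpo_batch(preferred_seq, unpreferred_seq, prompt_mask):
    """Hypothetical pre-flight check for a DPO batch (not an x-transformers API)."""
    # These two conditions are asserted by DPO.forward() itself
    assert preferred_seq.ndim == 2, "expected (batch, seq_len)"
    assert preferred_seq.shape == unpreferred_seq.shape, "pair shapes differ"
    # Extra assumptions of this sketch: one boolean mask shaped like the sequences
    assert prompt_mask.shape == preferred_seq.shape, "prompt_mask shape mismatch"
    assert prompt_mask.dtype == torch.bool, "prompt_mask must be boolean"
    return True

# Toy batch: 2 pairs of length 8, where the first 3 positions are prompt
preferred = torch.randint(0, 256, (2, 8))
unpreferred = torch.randint(0, 256, (2, 8))
prompt_mask = torch.arange(8).expand(2, 8) < 3
check_dpo_batch(preferred, unpreferred, prompt_mask)
```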
Code Reference
File: x_transformers/dpo.py, Lines: L71-117 (DPO.forward())
Interface Specification
PreferenceDataset (Derived Interface)
```python
import torch
from torch.utils.data import Dataset

class PreferenceDataset(Dataset):
    """Dataset for DPO training.
    Returns (preferred_seq, unpreferred_seq, prompt_mask) tuples.
    All tensors must have the same shape (seq_len,).
    prompt_mask: True where tokens are prompt (excluded from loss).
    """
    def __init__(self, data):
        # list of dicts: 'preferred', 'unpreferred' (LongTensor), 'prompt_len' (int)
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        preferred = self.data[index]['preferred']      # (seq_len,) LongTensor
        unpreferred = self.data[index]['unpreferred']  # (seq_len,) LongTensor
        prompt_len = self.data[index]['prompt_len']
        prompt_mask = torch.arange(len(preferred)) < prompt_len  # True = prompt
        return preferred, unpreferred, prompt_mask
```
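The dataset reads pre-built records from `self.data`, each holding `preferred`, `unpreferred`, and `prompt_len`. One way to produce such a record from a prompt and two completions is sketched below. The helper name `make_record`, the list-of-token-ids input format, and the padding id `0` are illustrative assumptions, not part of x-transformers.

```python
import torch

def make_record(prompt, chosen, rejected, pad_id=0):
    """Hypothetical helper: build one record in the format PreferenceDataset reads.

    prompt, chosen, rejected are lists of token ids (assumed input format).
    """
    # Both sequences are right-padded to the longer prompt + completion length
    target_len = len(prompt) + max(len(chosen), len(rejected))

    def build(completion):
        seq = prompt + completion
        seq = seq + [pad_id] * (target_len - len(seq))
        return torch.tensor(seq, dtype=torch.long)

    return {
        'preferred': build(chosen),
        'unpreferred': build(rejected),
        # The shared prompt occupies the first prompt_len positions of both sequences
        'prompt_len': len(prompt),
    }

record = make_record([5, 6, 7], [8, 9], [10, 11, 12])
```

If records are padded this way, the same id would need to be passed as `pad_id` to `DPO` so padded positions are masked, and that id should never occur as a real token.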
Key details:
- Both `preferred` and `unpreferred` must be the same length (the `DPO.forward()` method asserts shape equality).
- The `prompt_mask` is `True` for prompt positions and `False` for completion positions. The DPO loss is computed only where `prompt_mask` is `False`. Because a single `prompt_mask` is passed for both sequences, the prompt must occupy the same positions in `preferred` and `unpreferred`.
- If sequences have different lengths, they must be padded to the same length. The optional `preferred_seq_mask` and `unpreferred_seq_mask` arguments of `DPO.forward()` mark the valid (non-padding) positions, or the `pad_id` constructor argument can be used to generate these masks automatically.
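The masking rules in the list above can be combined into one "counts toward the loss" mask per sequence. The sketch below, with the hypothetical name `effective_loss_mask`, expresses those documented semantics directly: a position counts if it is not prompt and not padding, where the padding mask is either passed explicitly or derived as `seq != pad_id`. Letting an explicit mask take precedence over `pad_id` is a choice of this sketch, and how x-transformers aligns these masks with next-token log-probabilities internally is not shown here.

```python
import torch

def effective_loss_mask(seq, prompt_mask, pad_id=None, seq_mask=None):
    """Sketch: True where a position contributes to the DPO loss."""
    # In this sketch an explicit seq_mask wins; otherwise derive it from pad_id
    if seq_mask is None and pad_id is not None:
        seq_mask = seq != pad_id          # True = real token
    mask = ~prompt_mask                   # completion positions only
    if seq_mask is not None:
        mask = mask & seq_mask            # drop padding positions
    return mask

seq = torch.tensor([[5, 6, 7, 8, 9, 0]])        # 3 prompt tokens, 2 completion, 1 pad
prompt_mask = torch.arange(6).unsqueeze(0) < 3
loss_mask = effective_loss_mask(seq, prompt_mask, pad_id=0)
```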
Usage with DPO
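```python
import torch
from torch.utils.data import DataLoader
from x_transformers import TransformerWrapper, Decoder
from x_transformers.dpo import DPO

NUM_TOKENS = 256   # vocabulary size (example value)
SEQ_LEN = 1024     # maximum sequence length (example value)

# Create base model
model = TransformerWrapper(
    num_tokens=NUM_TOKENS,
    max_seq_len=SEQ_LEN,
    attn_layers=Decoder(dim=512, depth=6, heads=8)
)

# Wrap with DPO (creates policy + frozen reference copy)
dpo = DPO(model, beta=0.1)

# Only the policy model's parameters are optimized
optimizer = torch.optim.Adam(dpo.parameters(), lr=1e-5)

# `dataset` is a PreferenceDataset instance; the default collate_fn
# requires every item to have the same seq_len
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# Training loop
for preferred, unpreferred, prompt_mask in dataloader:
    loss = dpo(preferred, unpreferred, prompt_mask=prompt_mask)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```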
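PyTorch's default collation stacks items, so it fails when items in a batch have different lengths. A minimal padding `collate_fn` is sketched below. The name `pad_collate` and the padding id `0` are assumptions for illustration; the sketch pads the preferred and unpreferred halves together so the batch passes the shape asserts in `DPO.forward()`.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_collate(batch, pad_id=0):
    """Hypothetical collate_fn for (preferred, unpreferred, prompt_mask) items."""
    preferred, unpreferred, prompt_masks = zip(*batch)
    n = len(batch)
    # Pad both halves in one call so they share a single max length
    seqs = pad_sequence([*preferred, *unpreferred], batch_first=True, padding_value=pad_id)
    max_len = seqs.shape[1]
    # Extend each boolean prompt_mask with False: padded positions are never prompt
    masks = torch.stack([
        torch.cat([m, m.new_zeros(max_len - m.shape[0])]) for m in prompt_masks
    ])
    return seqs[:n], seqs[n:], masks
```

With this sketch, the loader would be built as `DataLoader(dataset, batch_size=4, collate_fn=pad_collate)`, and `DPO(model, beta=0.1, pad_id=0)` would exclude the padded positions from the loss, provided `0` is never a real token.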
Key details about the DPO wrapper:
- `DPO.__init__` creates a deep copy of the model as the frozen reference model.
- The `parameters()` method returns only the policy model parameters (not the reference).
- The `beta` parameter controls the strength of the KL divergence constraint (default: 0.1).
- An optional `pad_id` can be specified to automatically generate sequence masks from padding tokens.
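To make the role of `beta` concrete, the sketch below computes the standard DPO objective from per-sequence log-probabilities of the completions under the policy and the frozen reference model. It assumes those log-probabilities are already aggregated over the completion tokens; the aggregation choice is left to the caller here and may differ from what x-transformers does internally. The function name `dpo_loss` is hypothetical.

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy_preferred_logp, policy_unpreferred_logp,
             ref_preferred_logp, ref_unpreferred_logp, beta=0.1):
    """Sketch of the DPO loss; each argument is a (batch,) tensor of log-probs."""
    # How much more the policy prefers the preferred completion...
    policy_logratios = policy_preferred_logp - policy_unpreferred_logp
    # ...compared with the frozen reference model
    ref_logratios = ref_preferred_logp - ref_unpreferred_logp
    # beta scales the margin before the sigmoid
    return -F.logsigmoid(beta * (policy_logratios - ref_logratios)).mean()

zeros = torch.zeros(3)
loss_at_init = dpo_loss(zeros, zeros, zeros, zeros)
```

When the policy still equals the reference (as right after `DPO.__init__` copies the model), the margin is zero and the loss is `-log(sigmoid(0)) = log 2`, about 0.693; it falls below that as the policy's preference margin grows relative to the reference.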
Input / Output
| Direction | Name | Type | Shape | Description |
|---|---|---|---|---|
| Output | preferred | LongTensor | (B, seq_len) | Preferred sequences (prompt + preferred completion) |
| Output | unpreferred | LongTensor | (B, seq_len) | Unpreferred sequences (prompt + unpreferred completion) |
| Output | prompt_mask | BoolTensor | (B, seq_len) | True = prompt position (excluded from DPO loss) |
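The table gives batched shapes, while each dataset item is `(seq_len,)`. The short sketch below uses random toy tokens purely for illustration and checks that PyTorch's default `DataLoader` collation stacks such items into `(B, seq_len)` tensors, keeping `prompt_mask` boolean, as long as every item shares one length.

```python
import torch
from torch.utils.data import DataLoader

seq_len = 6
# Three toy items in the (preferred, unpreferred, prompt_mask) format, each (seq_len,)
items = [
    (torch.randint(1, 100, (seq_len,)),
     torch.randint(1, 100, (seq_len,)),
     torch.arange(seq_len) < 2)
    for _ in range(3)
]

# A plain list works as a map-style dataset; default collation stacks each field
loader = DataLoader(items, batch_size=3)
preferred, unpreferred, prompt_mask = next(iter(loader))
print(preferred.shape, prompt_mask.dtype)  # torch.Size([3, 6]) torch.bool
```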