Implementation: DPO_Init
Metadata
| Field | Value |
|---|---|
| Page Type | Implementation (API Doc) |
| Knowledge Sources | Repo (x-transformers) |
| Domains | NLP, Alignment |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
A concrete tool from the x-transformers library for initializing Direct Preference Optimization (DPO): it wraps a pretrained model as a trainable policy alongside a frozen reference copy.
Description
The DPO class takes a pretrained TransformerWrapper, stores it as self.policy_model, deep-copies it to create self.ref_model (frozen, all parameters set to requires_grad=False), and stores the DPO temperature beta.
Key implementation details (a minimal sketch follows this list):
- The deep copy is performed with Python's `copy.deepcopy`, ensuring the reference model is a completely independent copy of the original weights.
- The helper function `freeze_all_layers_()` iterates over all parameters in the reference model and sets `param.requires_grad = False`.
- The `.parameters()` method is overridden to return only the policy model's parameters, so any optimizer created from `dpo.parameters()` will update only the trainable policy.
- The optional `pad_id` parameter enables automatic padding-mask creation during the forward pass: any token equal to `pad_id` is excluded from the loss computation.
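The construction logic above can be condensed into the following sketch. It paraphrases the behavior described in this section and is not the verbatim source of x_transformers/dpo.py:

import copy
import torch.nn as nn

def freeze_all_layers_(module: nn.Module):
    # set requires_grad = False on every parameter of the module
    for param in module.parameters():
        param.requires_grad = False

class DPO(nn.Module):
    def __init__(self, model, *, beta = 0.1, pad_id = None):
        super().__init__()
        self.policy_model = model               # stays trainable
        self.ref_model = copy.deepcopy(model)   # fully independent copy
        freeze_all_layers_(self.ref_model)      # frozen reference
        self.beta = beta                        # DPO temperature
        self.pad_id = pad_id                    # optional padding token id

    def parameters(self):
        # expose only the policy's parameters, so optimizers never touch the reference
        return self.policy_model.parameters()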
Code Reference
Source Location
x-transformers repo, file x_transformers/dpo.py, lines 51-66.
Signature
class DPO(Module):
def __init__(
self,
model: TransformerWrapper,
*,
beta = 0.1,
pad_id = None
):
Import
from x_transformers.dpo import DPO
I/O Contract
Constructor Inputs
| Parameter | Type | Required | Default | Description |
|---|---|---|---|---|
| `model` | TransformerWrapper | Yes | -- | A pretrained policy model. This TransformerWrapper instance is stored as the trainable policy and deep-copied to create the frozen reference. |
| `beta` | float | No | 0.1 | DPO temperature parameter controlling how strongly the policy is penalized for deviating from the reference distribution. Higher values keep the policy closer to the reference; lower values permit greater divergence (see the objective below). |
| `pad_id` | int or None | No | None | When set, padding masks are created automatically during the forward pass by comparing token values to this ID; tokens equal to `pad_id` are excluded from the loss. |
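For reference, `beta` is the β of the standard DPO objective (Rafailov et al., 2023), which this wrapper's loss computation is built around:

$$
\mathcal{L}_{\mathrm{DPO}} = -\,\mathbb{E}_{(x,\,y_w,\,y_l)}\left[ \log \sigma\!\left( \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)} \right) \right]
$$

where π_θ is `.policy_model`, π_ref is the frozen `.ref_model`, and (y_w, y_l) are the preferred and unpreferred completions for prompt x.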
Constructor Outputs
| Output | Type | Description |
|---|---|---|
| instance | DPO | A DPO wrapper instance with `.policy_model` (trainable TransformerWrapper) and `.ref_model` (frozen TransformerWrapper). The `.parameters()` method returns only the policy model's parameters. |
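Once an instance is constructed (see the usage examples below), both properties in this table can be checked directly:

# the reference is fully frozen; parameters() exposes only the trainable policy
assert all(not p.requires_grad for p in dpo.ref_model.parameters())
assert all(p.requires_grad for p in dpo.parameters())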
Usage Examples
Basic DPO Initialization
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.dpo import DPO
# Pretrained base model
base_model = TransformerWrapper(
num_tokens = 256,
max_seq_len = 512,
attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
).cuda()
# ... pretrain base_model ...
# Initialize DPO
dpo = DPO(
base_model,
beta = 0.1
).cuda()
# Only policy_model parameters are trainable
optimizer = torch.optim.Adam(dpo.parameters(), lr=1e-6)
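A hypothetical training step is sketched below. The forward signature shown (preferred_seq, unpreferred_seq, prompt_mask) is an assumption inferred from the DPO formulation; check x_transformers/dpo.py for the actual argument names:

# hypothetical preference batch: a shared prompt plus chosen / rejected continuations
preferred_seq   = torch.randint(0, 256, (4, 512)).cuda()
unpreferred_seq = torch.randint(0, 256, (4, 512)).cuda()

# assumed mask marking the prompt positions, which carry no preference signal
prompt_mask = torch.zeros(4, 512, dtype = torch.bool).cuda()
prompt_mask[:, :128] = True

loss = dpo(preferred_seq, unpreferred_seq, prompt_mask = prompt_mask)  # assumed signature
loss.backward()
optimizer.step()
optimizer.zero_grad()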
With Automatic Padding Masks
dpo = DPO(
base_model,
beta = 0.1,
pad_id = 0 # token 0 is padding
).cuda()
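Conceptually, the automatic mask is presumably derived by comparing each token against pad_id, along these lines (illustrative, not library source):

import torch

seq = torch.tensor([[5, 9, 3, 0, 0]])  # with pad_id = 0, the last two tokens are padding
mask = seq != 0                         # tensor([[True, True, True, False, False]])
# positions where mask is False are excluded from the loss computation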
Related Pages
Implements Principle
Requires Environment