
Implementation:Lucidrains X transformers DPO Init

From Leeroopedia


Implementation: DPO_Init

Metadata

| Field | Value |
|---|---|
| Page Type | Implementation (API Doc) |
| Knowledge Sources | Repo (x-transformers) |
| Domains | NLP, Alignment |
| Last Updated | 2026-02-08 18:00 GMT |

Overview

Concrete tool, provided by the x-transformers library, for initializing Direct Preference Optimization (DPO) with a trainable policy model and a frozen reference model.

Description

The DPO class takes a pretrained TransformerWrapper, stores it as self.policy_model, deep-copies it to create self.ref_model (frozen, all parameters set to requires_grad=False), and stores the DPO temperature beta.

Key implementation details:

  • The deep copy is performed using Python's copy.deepcopy, ensuring the reference model is a completely independent copy of the original weights.
  • The helper function freeze_all_layers_() iterates over all parameters in the reference model and sets param.requires_grad = False.
  • The .parameters() method is overridden to return only the policy model's parameters, so any optimizer created from dpo.parameters() will only update the trainable policy.
  • The optional pad_id parameter enables automatic padding mask creation during the forward pass: any token equal to pad_id will be excluded from the loss computation.
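The four details above can be illustrated with a minimal sketch of the initialization logic. This is not the actual x-transformers source; the helper name `freeze_all_layers_` follows the description on this page, and the class body is reduced to what the bullets describe.

```python
# Minimal sketch of the DPO initialization described above (illustrative,
# not the actual x-transformers implementation).
from copy import deepcopy

import torch
from torch.nn import Module

def freeze_all_layers_(module: Module):
    # Iterate over all parameters and disable gradient tracking.
    for param in module.parameters():
        param.requires_grad = False

class DPO(Module):
    def __init__(self, model: Module, *, beta=0.1, pad_id=None):
        super().__init__()
        self.policy_model = model            # stored as the trainable policy
        self.ref_model = deepcopy(model)     # independent copy of the weights
        freeze_all_layers_(self.ref_model)   # frozen reference model
        self.beta = beta
        self.pad_id = pad_id

    def parameters(self):
        # Override so optimizers built from dpo.parameters() only ever
        # update the policy, never the frozen reference.
        return self.policy_model.parameters()
```

Because `.parameters()` is overridden, `torch.optim.Adam(dpo.parameters(), ...)` sees only the policy model's parameters, while the deep-copied reference stays untouched.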

Code Reference

Source Location

x-transformers repo, file: x_transformers/dpo.py, lines L51-66.

Signature

class DPO(Module):
    def __init__(
        self,
        model: TransformerWrapper,
        *,
        beta = 0.1,
        pad_id = None
    ):

Import

from x_transformers.dpo import DPO

I/O Contract

Constructor Inputs

| Parameter | Type | Required | Default | Description |
|---|---|---|---|---|
| model | TransformerWrapper | Yes | -- | A pretrained policy model. This TransformerWrapper instance will be stored as the trainable policy and deep-copied to create the frozen reference. |
| beta | float | No | 0.1 | DPO temperature parameter. Controls how much the policy is allowed to deviate from the reference distribution. Higher values allow greater divergence. |
| pad_id | int or None | No | None | When set, automatically creates padding masks during the forward pass by comparing token values to this ID. Tokens equal to pad_id are excluded from the loss. |
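The `pad_id` comparison described in the table amounts to an elementwise mask over the token tensor. A small sketch of the idea (the exact masking code inside the library's forward pass is not shown on this page):

```python
# Sketch of automatic padding-mask creation from pad_id: tokens equal to
# pad_id are marked False so they can be excluded from the loss.
import torch

pad_id = 0
seq = torch.tensor([[5, 7, 2, 0, 0]])
mask = seq != pad_id  # True at real tokens, False at padding
# mask -> tensor([[ True,  True,  True, False, False]])
```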

Constructor Outputs

| Output | Type | Description |
|---|---|---|
| instance | DPO | A DPO wrapper instance with .policy_model (trainable TransformerWrapper) and .ref_model (frozen TransformerWrapper). The .parameters() method returns only the policy model's parameters. |

Usage Examples

Basic DPO Initialization

import torch

from x_transformers import TransformerWrapper, Decoder
from x_transformers.dpo import DPO

# Pretrained base model
base_model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 512,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
).cuda()

# ... pretrain base_model ...

# Initialize DPO
dpo = DPO(
    base_model,
    beta = 0.1
).cuda()

# Only policy_model parameters are trainable
optimizer = torch.optim.Adam(dpo.parameters(), lr=1e-6)
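To make the role of `beta` concrete, here is the standard DPO objective it scales, written as a standalone function. This is the textbook formulation, not code taken from x-transformers: beta multiplies the gap between the policy/reference log-ratios of the preferred and unpreferred completions, so a higher beta penalizes divergence from the reference less per unit of log-ratio.

```python
# Standard DPO loss (illustrative; not the x-transformers source).
# Inputs are summed log-probabilities of each sequence under the
# policy and the frozen reference model.
import torch
import torch.nn.functional as F

def dpo_loss(policy_pref_logp, policy_unpref_logp,
             ref_pref_logp, ref_unpref_logp, beta=0.1):
    pref_ratio = policy_pref_logp - ref_pref_logp       # preferred log-ratio
    unpref_ratio = policy_unpref_logp - ref_unpref_logp # unpreferred log-ratio
    # -log sigmoid(beta * margin), averaged over the batch
    return -F.logsigmoid(beta * (pref_ratio - unpref_ratio)).mean()

loss = dpo_loss(torch.tensor([-1.0]), torch.tensor([-2.0]),
                torch.tensor([-1.5]), torch.tensor([-1.8]), beta=0.1)
```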

With Automatic Padding Masks

dpo = DPO(
    base_model,
    beta = 0.1,
    pad_id = 0  # token 0 is padding
).cuda()

Related Pages

Implements Principle

Requires Environment
