Implementation: DPO Policy Model Evaluation (lucidrains/x-transformers)
Metadata
| Field | Value |
|---|---|
| Page Type | Implementation (Wrapper Doc) |
| Knowledge Sources | Repo (x-transformers) |
| Domains | NLP, Alignment, Inference |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
A concrete pattern for extracting the DPO-aligned policy model provided by the x-transformers library and using it for text generation.
Description
This is a Wrapper Doc documenting the pattern of extracting the policy model from a trained DPO instance and wrapping it with AutoregressiveWrapper for generation.
The dpo.policy_model attribute is the aligned TransformerWrapper whose weights have been updated by DPO training. Since the DPO class trains the raw transformer (not an autoregressive wrapper), the model must be explicitly wrapped with AutoregressiveWrapper to access the .generate() method for text generation.
Key details:
- Policy model access: `dpo.policy_model` returns the trainable TransformerWrapper instance.
- Reference model access: `dpo.ref_model` returns the frozen TransformerWrapper instance (useful for comparison).
- Generation wrapping: Pass the policy model to AutoregressiveWrapper to gain access to all generation strategies (top-k, top-p, min-p, beam search, contrastive decoding).
- No retraining needed: The wrapper simply provides a generation interface around the already-trained model.
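To make the sampling strategies above concrete, the top-p (nucleus) filter can be sketched in plain NumPy. This is an illustrative standalone sketch, not the x-transformers implementation; the `thres` convention here is the standard nucleus-sampling one (keep the smallest set of tokens whose cumulative probability reaches the threshold), which may differ from the library's parameterization.

```python
import numpy as np

def top_p_filter(logits, thres=0.9):
    """Nucleus (top-p) filtering: keep the smallest set of tokens whose
    cumulative probability reaches `thres`; mask the rest to -inf.
    Illustrative sketch only -- not the x-transformers implementation."""
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    order = np.argsort(probs)[::-1]           # token indices, most probable first
    cumulative = np.cumsum(probs[order])
    # keep tokens until cumulative probability first reaches thres (always keep >= 1)
    cutoff = np.searchsorted(cumulative, thres) + 1
    keep = order[:cutoff]
    filtered = np.full_like(logits, -np.inf)
    filtered[keep] = logits[keep]
    return filtered

logits = np.array([2.0, 1.0, 0.5, -1.0])
filtered = top_p_filter(logits, thres=0.9)   # masks only the least probable token here
```

Sampling then proceeds by applying softmax to the filtered logits and drawing from the surviving tokens.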
Code Reference
Source Locations
| Component | File | Lines | Description |
|---|---|---|---|
| policy_model attribute | x_transformers/dpo.py | L60 | The aligned TransformerWrapper stored during DPO initialization. |
| AutoregressiveWrapper | x_transformers/autoregressive_wrapper.py | L156-183 | Wrapper class that provides generation capabilities. |
| .generate() method | x_transformers/autoregressive_wrapper.py | L351-509 | Autoregressive generation with multiple sampling strategies. |
Pattern
```python
# Access aligned model
aligned_model = dpo.policy_model

# Wrap for generation
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
gen_model = AutoregressiveWrapper(aligned_model)

# Generate
output = gen_model.generate(prompt, seq_len = 256)
```
Imports
```python
from x_transformers.dpo import DPO
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
```
I/O Contract
Inputs
| Input | Type | Required | Description |
|---|---|---|---|
| Trained DPO instance | DPO | Yes | A DPO instance that has been trained on preference pairs. The .policy_model attribute contains the aligned TransformerWrapper. |
| Prompt tokens | Tensor (B, prompt_len) | Yes | Tokenized prompt sequences to generate continuations from. |
Outputs
| Output | Type | Description |
|---|---|---|
| Aligned model | TransformerWrapper | The DPO-aligned TransformerWrapper extracted via dpo.policy_model. |
| Generated text | Tensor (B, seq_len) | Token sequences produced by AutoregressiveWrapper.generate(); the prompt is stripped before returning, so only the generated continuation is included. |
Usage Examples
Basic Generation from Aligned Model
```python
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
import torch

# After DPO training
aligned_model = dpo.policy_model
gen_model = AutoregressiveWrapper(aligned_model)

# Generate from aligned model
prompt = torch.tensor([[1, 5, 10, 22]]).cuda()

output = gen_model.generate(
    prompts = prompt,
    seq_len = 256,
    temperature = 0.7,
    filter_logits_fn = 'top_p',
    filter_kwargs = dict(thres = 0.9)
)
```
Comparing Aligned vs Reference Model Outputs
```python
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
import torch

# Wrap both models for generation
aligned_gen = AutoregressiveWrapper(dpo.policy_model)
reference_gen = AutoregressiveWrapper(dpo.ref_model)

prompt = torch.tensor([[1, 5, 10, 22]]).cuda()

# Generate from both
aligned_output = aligned_gen.generate(prompts = prompt, seq_len = 256, temperature = 0.7)
reference_output = reference_gen.generate(prompts = prompt, seq_len = 256, temperature = 0.7)

# Compare outputs to evaluate alignment quality
```
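One minimal way to quantify how far the aligned model's output has drifted from the reference is a token-level agreement score over the two generations. The `token_overlap` helper below is hypothetical, a crude illustrative metric only; real alignment evaluation would rely on reward models or human preference judgments rather than raw token agreement.

```python
import numpy as np

def token_overlap(a, b):
    """Fraction of positions where two equal-length token sequences agree.
    Hypothetical helper for illustration -- not part of x-transformers."""
    a, b = np.asarray(a), np.asarray(b)
    assert a.shape == b.shape, "sequences must have the same length"
    return float((a == b).mean())

# Toy sequences standing in for aligned_output / reference_output rows
aligned_tokens   = [1, 5, 10, 22, 7, 3]
reference_tokens = [1, 5, 10, 22, 9, 3]
overlap = token_overlap(aligned_tokens, reference_tokens)  # 5 of 6 positions agree
```

With real tensors, `aligned_output[0].cpu().tolist()` and `reference_output[0].cpu().tolist()` would be compared the same way. A low overlap indicates DPO training has substantially changed the policy's behavior on this prompt; whether the change is an improvement still requires preference-based judgment.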