Implementation: Lucidrains x-transformers FreeTransformer
| Knowledge Sources | |
|---|---|
| Domains | NLP, Language_Modeling, Variational_Inference |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
A variational autoregressive language model with binary discrete latent variables, provided by the x-transformers library.
Description
The FreeTransformer class implements the FREE architecture from Fleuret (2025). It is a variational autoregressive language model that conditions text generation on binary discrete latent codes. The architecture consists of: (1) a decoder head that produces initial embeddings, (2) a cross-attention encoder that pools these into per-token latent representations, (3) a BinaryMapper that quantizes latents into binary codes via Bernoulli sampling with straight-through gradients, and (4) a decoder tail that generates text conditioned on the quantized latents. The binary latent space enables controllable generation by specifying latent codes at inference time.
Usage
Import this class when you need a language model with a discrete latent space for controllable or steerable text generation. The binary codes can be used to condition generation on specific styles, topics, or behaviors. Useful for diversity-promoting RLHF, MAP-Elites, and steerable generation.
Code Reference
Source Location
- Repository: Lucidrains_X_transformers
- File: x_transformers/free_transformer.py
- Lines: 132-416
Signature
```python
class FreeTransformer(Module):
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        dec_head_depth,
        dec_tail_depth,
        max_seq_len,
        enc_depth = 1,
        dim_latent = None,
        attn_dim_head = 64,
        heads = 8,
        latent_bits = 16,
        per_token_latents = True,
        kl_loss_threshold = NAT,
        binary_mapper_kwargs: dict = dict(),
        enc_kwargs: dict = dict(),
        dec_kwargs: dict = dict(),
        kl_loss_weight = 1.,
        latent_dropout_prob = 0.,
        pad_id = -1,
        **kwargs
    ):
        """
        Args:
            num_tokens: Vocabulary size.
            dim: Model dimension.
            dec_head_depth: Number of decoder head layers (before latent conditioning).
            dec_tail_depth: Number of decoder tail layers (after latent conditioning).
            max_seq_len: Maximum sequence length.
            enc_depth: Encoder depth for latent pooling (default 1).
            dim_latent: Latent dimension (defaults to dim).
            attn_dim_head: Dimension per attention head (default 64).
            heads: Number of attention heads (default 8).
            latent_bits: Number of binary bits per latent (default 16, giving 2^16 codes).
            per_token_latents: Whether to use per-token latents vs. one latent for the entire sequence.
            kl_loss_threshold: KL divergence floor for the binary entropy loss.
            kl_loss_weight: Weight for the auxiliary KL loss.
            latent_dropout_prob: Probability of dropping out the latent conditioning.
            pad_id: Padding token ID to ignore in the loss.
        """
```
Import
```python
from x_transformers.free_transformer import FreeTransformer
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| seq | Tensor (b, n) | Yes | Input token sequence |
| seq_for_latents | Tensor (b, m) | No | Separate sequence for encoding latents (defaults to seq) |
| return_all_losses | bool | No | Return breakdown of AR loss and KL loss |
Outputs
| Name | Type | Description |
|---|---|---|
| forward() returns | Tensor (scalar) | Total loss (AR loss + KL loss * weight) |
| forward() with return_all_losses | (Tensor, (Tensor, Tensor)) | Total loss, plus (ar_loss, kl_loss) breakdown |
| generate() returns | Tensor (b, n) | Generated token sequence conditioned on optional latent codes |
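The loss contract above (total = AR loss + weighted KL loss, with an optional breakdown) can be illustrated with a self-contained toy function; this mimics the documented return shapes, not the library's internals:

```python
import torch
import torch.nn.functional as F

def toy_forward(logits, targets, kl_loss, kl_loss_weight=1.0, return_all_losses=False):
    # Toy stand-in for the documented contract: forward() returns
    # ar_loss + kl_loss_weight * kl_loss; with return_all_losses=True
    # the two terms are also returned separately.
    ar_loss = F.cross_entropy(logits, targets)
    total = ar_loss + kl_loss_weight * kl_loss
    if return_all_losses:
        return total, (ar_loss, kl_loss)
    return total

logits = torch.randn(4, 10)
targets = torch.randint(0, 10, (4,))
kl = torch.tensor(0.25)
total, (ar, kl_out) = toy_forward(logits, targets, kl, return_all_losses=True)
```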
Usage Examples
Training
```python
import torch
from x_transformers.free_transformer import FreeTransformer

model = FreeTransformer(
    num_tokens=256,
    dim=256,
    dec_head_depth=3,
    dec_tail_depth=3,
    max_seq_len=512,
    latent_bits=16,
    kl_loss_weight=1.0,
    latent_dropout_prob=0.1
)

seq = torch.randint(0, 256, (4, 128))
loss = model(seq)
loss.backward()
```
Latent-Conditioned Generation
```python
# Generate conditioned on a specific latent code index
prompt = torch.randint(0, 256, (1, 5))

generated = model.generate(
    prompts=prompt,
    seq_len=100,
    latents=42  # integer index into the 2^latent_bits code space
)
```
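The integer latents argument selects one of the 2^latent_bits binary codes. A sketch of how such an index maps to a bit vector (illustrative; the library's internal bit ordering and encoding may differ):

```python
import torch

def index_to_bits(index: int, num_bits: int = 16) -> torch.Tensor:
    # little-endian bit decomposition of a code index into a {0, 1} vector
    # (hypothetical helper; shown only to explain the code space)
    bits = [(index >> i) & 1 for i in range(num_bits)]
    return torch.tensor(bits, dtype=torch.float32)

code = index_to_bits(42, num_bits=16)  # 42 = 0b101010
```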