
Implementation:Lucidrains X transformers FreeTransformer

From Leeroopedia


Knowledge Sources
Domains NLP, Language_Modeling, Variational_Inference
Last Updated 2026-02-08 18:00 GMT

Overview

A concrete tool for variational autoregressive language modeling with binary discrete latent variables, provided by the x-transformers library.

Description

The FreeTransformer class implements the Free Transformer architecture from Fleuret (2025). It is a variational autoregressive language model that conditions text generation on binary discrete latent codes. The architecture consists of: (1) a decoder head that produces initial embeddings, (2) a cross-attention encoder that pools these into per-token latent representations, (3) a BinaryMapper that quantizes latents into binary codes via Bernoulli sampling with straight-through gradients, and (4) a decoder tail that generates text conditioned on the quantized latents. The binary latent space enables controllable generation by specifying latent codes at inference time.
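The straight-through Bernoulli quantization in step (3) can be sketched in plain Python. This is a conceptual illustration, not the BinaryMapper's actual code, and the `straight_through_bernoulli` name is hypothetical:

```python
import random

def straight_through_bernoulli(probs, rng=None):
    # Forward pass: draw a hard 0/1 bit from each latent probability.
    # In an autograd framework the straight-through estimator would
    # return hard + probs - probs.detach(), so the backward pass treats
    # the sampling as identity and gradients flow into probs unchanged.
    rng = rng or random.Random(0)
    return [1.0 if rng.random() < p else 0.0 for p in probs]

bits = straight_through_bernoulli([0.95, 0.05, 0.5, 0.5])
```

The hard samples give the decoder tail a strictly discrete code, while the pass-through gradient lets the encoder still be trained end to end.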

Usage

Import this class when you need a language model with a discrete latent space for controllable or steerable text generation. The binary codes can be used to condition generation on specific styles, topics, or behaviors; applications include diversity-promoting RLHF and MAP-Elites-style quality-diversity search.

Code Reference

Source Location

Signature

class FreeTransformer(Module):
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        dec_head_depth,
        dec_tail_depth,
        max_seq_len,
        enc_depth = 1,
        dim_latent = None,
        attn_dim_head = 64,
        heads = 8,
        latent_bits = 16,
        per_token_latents = True,
        kl_loss_threshold = NAT,
        binary_mapper_kwargs: dict = dict(),
        enc_kwargs: dict = dict(),
        dec_kwargs: dict = dict(),
        kl_loss_weight = 1.,
        latent_dropout_prob = 0.,
        pad_id = -1,
        **kwargs
    ):
        """
        Args:
            num_tokens: Vocabulary size.
            dim: Model dimension.
            dec_head_depth: Number of decoder head layers (before latent conditioning).
            dec_tail_depth: Number of decoder tail layers (after latent conditioning).
            max_seq_len: Maximum sequence length.
            enc_depth: Encoder depth for latent pooling (default 1).
            dim_latent: Latent dimension (defaults to dim).
            attn_dim_head: Dimension per attention head (default 64).
            heads: Number of attention heads (default 8).
            latent_bits: Number of binary bits per latent (default 16, giving 2^16 codes).
            per_token_latents: Whether to use per-token latents vs one for the entire sequence.
            kl_loss_threshold: KL divergence floor for binary entropy loss.
            binary_mapper_kwargs: Extra keyword arguments for the BinaryMapper.
            enc_kwargs: Extra keyword arguments for the encoder.
            dec_kwargs: Extra keyword arguments for the decoder head and tail.
            kl_loss_weight: Weight for KL auxiliary loss.
            latent_dropout_prob: Probability of dropping out latent conditioning.
            pad_id: Padding token ID to ignore in loss.
        """

Import

from x_transformers.free_transformer import FreeTransformer

I/O Contract

Inputs

Name Type Required Description
seq Tensor (b, n) Yes Input token sequence
seq_for_latents Tensor (b, m) No Separate sequence for encoding latents (defaults to seq)
return_all_losses bool No Return breakdown of AR loss and KL loss

Outputs

Name Type Description
forward() returns Tensor (scalar) Total loss (AR loss + KL loss * weight)
forward() with return_all_losses (Tensor, (Tensor, Tensor)) Total loss, plus (ar_loss, kl_loss) breakdown
generate() returns Tensor (b, n) Generated token sequence conditioned on optional latent codes
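Per the Outputs table, the scalar returned by forward() combines the two loss components. A one-line pure-Python sketch of that combination (illustrative only; the real values are tensors):

```python
def total_loss(ar_loss, kl_loss, kl_loss_weight=1.0):
    # total = autoregressive cross-entropy plus weighted KL auxiliary
    # term, mirroring the forward() row in the Outputs table
    return ar_loss + kl_loss * kl_loss_weight
```

Setting kl_loss_weight to 0 recovers a plain autoregressive objective, which can help when debugging latent collapse.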

Usage Examples

Training

import torch
from x_transformers.free_transformer import FreeTransformer

model = FreeTransformer(
    num_tokens=256,
    dim=256,
    dec_head_depth=3,
    dec_tail_depth=3,
    max_seq_len=512,
    latent_bits=16,
    kl_loss_weight=1.0,
    latent_dropout_prob=0.1
)

seq = torch.randint(0, 256, (4, 128))
loss = model(seq)
loss.backward()

Latent-Conditioned Generation

# Generate conditioned on a specific latent code index
prompt = torch.randint(0, 256, (1, 5))
generated = model.generate(
    prompts=prompt,
    seq_len=100,
    latents=42  # integer index into 2^bits code space
)
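The integer latents argument above indexes the 2^latent_bits code space. One way to picture the mapping from an index to its binary latent bits (LSB-first ordering is an assumption; the library may pack bits differently):

```python
def index_to_bits(index, latent_bits=16):
    # decompose an integer code index into its binary latent bits,
    # least-significant bit first (ordering is an assumption)
    assert 0 <= index < 2 ** latent_bits
    return [(index >> i) & 1 for i in range(latent_bits)]

index_to_bits(42, 16)[:6]  # 42 = 0b101010 -> [0, 1, 0, 1, 0, 1]
```

Sweeping the index across codes while keeping the prompt fixed is one way to probe which latent bits steer which generation behaviors.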
