
Implementation:Lucidrains X transformers GPTVAE

From Leeroopedia


Knowledge Sources
Domains NLP, Language_Modeling, Variational_Inference
Last Updated 2026-02-08 18:00 GMT

Overview

A concrete tool, provided by the x-transformers library, that combines a conditional variational autoencoder (CVAE) with an autoregressive GPT decoder for controllable text generation.

Description

The GPTVAE class implements a GPT-VAE architecture inspired by the CVAE + DETR design of ACT (Zhao et al.), adapted for language modeling. An encoder (a TransformerWrapper with average pooling) compresses the input sequence into a Gaussian latent space via the reparameterization trick. The latent vector is then projected into a conditioning token that is prepended to the decoder input. The decoder, wrapped in an AutoregressiveWrapper, generates text conditioned on this latent. Training minimizes a combined loss of autoregressive cross-entropy and KL divergence, with an optional KL floor (free bits, Kingma et al. 2016; also used in the Free Transformer) to prevent posterior collapse. Latent dropout is applied during training so the model does not over-rely on the latent signal.
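The reparameterization trick and the floored KL term described above can be sketched in a few lines. This is a minimal pure-Python illustration of the math, not the library's implementation:

```python
import math

def reparameterize(mu, log_var, eps):
    # z = mu + sigma * eps, with eps ~ N(0, 1) supplied by the caller,
    # so the sampling step stays differentiable w.r.t. mu and log_var.
    return [m + math.exp(0.5 * lv) * e for m, lv, e in zip(mu, log_var, eps)]

def kl_with_floor(mu, log_var, floor=0.0):
    # KL( N(mu, sigma^2) || N(0, I) ), summed over latent dimensions:
    #   0.5 * (mu^2 + sigma^2 - 1 - log sigma^2)
    # Clamping at a floor (free bits, Kingma et al. 2016) removes gradient
    # pressure once the posterior is close enough to the prior, which
    # helps prevent posterior collapse.
    kl = sum(0.5 * (m * m + math.exp(lv) - 1.0 - lv)
             for m, lv in zip(mu, log_var))
    return max(kl, floor)
```

With mu = 0 and log_var = 0 the posterior matches the prior and the KL is zero, so the floor value is what keeps the loss term from vanishing entirely.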

Usage

Import this class when you need a GPT-style language model with a continuous latent space for controllable generation. The latent vectors can be used to steer generation toward particular styles, topics, or behaviors, and can be sampled or interpolated for diverse outputs.

Code Reference

Source Location

Signature

class GPTVAE(Module):
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        depth,
        enc_depth,
        max_seq_len,
        dim_latent = None,
        attn_dim_head = 64,
        heads = 8,
        enc_kwargs: dict = dict(),
        dec_kwargs: dict = dict(),
        vae_kl_loss_weight = 1.,
        vae_kl_div_floor = 0.,
        latents_dropout_prob = 0.5,
        pad_id = -1,
        encoder: Module | None = None,
        **kwargs
    ):
        """
        Args:
            num_tokens: Vocabulary size.
            dim: Model dimension.
            depth: Decoder depth.
            enc_depth: Encoder depth.
            max_seq_len: Maximum sequence length.
            dim_latent: Latent dimension (defaults to dim).
            attn_dim_head: Attention head dimension.
            heads: Number of attention heads.
            vae_kl_loss_weight: Weight for KL divergence loss.
            vae_kl_div_floor: KL divergence floor (Kingma 2016).
            latents_dropout_prob: Probability of dropping latents entirely during training.
            pad_id: Token ID to ignore in loss computation.
            encoder: Optional custom encoder module.
        """

Import

from x_transformers.gpt_vae import GPTVAE

I/O Contract

Inputs

Name               Type           Required  Description
seq                Tensor (b, n)  Yes       Input token sequence
seq_for_latents    Tensor (b, m)  No        Separate sequence for encoding latents (defaults to seq)
return_all_losses  bool           No        Return breakdown of AR and KL losses

Outputs

Name                              Type                        Description
forward() returns                 Tensor (scalar)             Total loss (AR loss + KL loss * weight)
forward() with return_all_losses  (Tensor, (Tensor, Tensor))  Total loss, plus (ar_loss, vae_kl_loss)
generate() returns                Tensor (b, n)               Generated token sequence conditioned on latent
encode_to_latents() returns       Tensor (b, d)               Sampled latent vector (optionally with mean and log_var)
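The total loss in the contract above combines the two components as AR loss plus weighted KL. A trivial sketch of that arithmetic, with hypothetical values:

```python
def combine_losses(ar_loss, vae_kl_loss, vae_kl_loss_weight=1.0):
    # Total objective returned by forward(): autoregressive
    # cross-entropy plus the weighted KL divergence term.
    return ar_loss + vae_kl_loss_weight * vae_kl_loss
```

Lowering vae_kl_loss_weight trades regularization toward the prior for a more informative (but potentially less smooth) latent space.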

Usage Examples

Training

import torch
from x_transformers.gpt_vae import GPTVAE

model = GPTVAE(
    num_tokens=256,
    dim=256,
    depth=6,
    enc_depth=2,
    max_seq_len=512,
    vae_kl_loss_weight=1.0,
    vae_kl_div_floor=0.1,
    latents_dropout_prob=0.5
)

seq = torch.randint(0, 256, (4, 128))
loss = model(seq)
loss.backward()

Latent-Conditioned Generation

# Generate conditioned on a specific latent vector
prompt = torch.randint(0, 256, (1, 5))
latent = torch.randn(1, 256)  # sample or specify a latent

generated = model.generate(prompt, seq_len=100, latents=latent)

Encode-then-Generate

# Encode a reference sequence, then generate from its latent
reference = torch.randint(0, 256, (1, 100))
prompt = torch.randint(0, 256, (1, 5))

generated = model.generate(prompt, seq_len=100, seq_for_latents=reference)
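Because the latents live in a continuous space, two encoded references can also be interpolated to sweep generation between their styles. A minimal sketch of the interpolation itself (pure Python, with hypothetical latent values; in practice each latent would come from encode_to_latents and be passed to model.generate(..., latents=...) as above):

```python
def lerp_latent(a, b, alpha):
    # Elementwise linear interpolation between two latent vectors;
    # alpha = 0 recovers a, alpha = 1 recovers b.
    return [x + alpha * (y - x) for x, y in zip(a, b)]

# Hypothetical latents for illustration
latent_a = [0.0, 1.0, -2.0]
latent_b = [1.0, 0.0, 2.0]
sweep = [lerp_latent(latent_a, latent_b, t) for t in (0.0, 0.5, 1.0)]
```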
