# Implementation: Lucidrains x-transformers GPTVAE
| Knowledge Sources | |
|---|---|
| Domains | NLP, Language_Modeling, Variational_Inference |
| Last Updated | 2026-02-08 18:00 GMT |
## Overview
A concrete tool, provided by the x-transformers library, that combines a conditional variational autoencoder (CVAE) with an autoregressive GPT decoder for controllable text generation.
## Description
The GPTVAE class implements a GPT-VAE architecture inspired by the CVAE + DETR design from ACT (Zhao et al., 2023), adapted for language modeling. An encoder (a TransformerWrapper with average pooling) compresses the input sequence into a Gaussian latent space via the reparameterization trick. The latent vector is then projected into a conditioning token that is prepended to the decoder input, and the decoder, wrapped in an AutoregressiveWrapper, generates text conditioned on this latent. Training uses a combined loss of autoregressive cross-entropy and KL divergence, with an optional free-bits-style floor (Kingma et al., 2016; also used in the Free Transformer) to prevent posterior collapse. Latent dropout is applied during training so the model does not over-rely on the latent signal.
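The reparameterization trick and the floored KL term described above can be sketched in plain PyTorch. This is a minimal illustration of the technique, not the library's actual code; the function names here are made up:

```python
import torch

def reparameterize(mean, log_var):
    # z = mean + sigma * eps with eps ~ N(0, I); keeps sampling differentiable
    std = (0.5 * log_var).exp()
    return mean + std * torch.randn_like(std)

def kl_loss_with_floor(mean, log_var, floor = 0.):
    # KL(N(mean, sigma^2) || N(0, I)) per latent dimension
    kl = 0.5 * (mean.pow(2) + log_var.exp() - 1. - log_var)
    # free-bits style floor (Kingma et al., 2016): clamping stops the optimizer
    # from pushing KL below the floor, which discourages posterior collapse
    kl = kl.clamp(min = floor)
    return kl.sum(dim = -1).mean()

mean = torch.zeros(4, 256)
log_var = torch.zeros(4, 256)
z = reparameterize(mean, log_var)                     # shape (4, 256)
kl = kl_loss_with_floor(mean, log_var, floor = 0.1)   # floored at 0.1 per dim
```

With a perfectly standard-normal posterior (mean 0, log-variance 0) the raw KL is zero, so the floored loss is exactly the floor times the latent dimension.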
## Usage
Import this class when you need a GPT-style language model with a continuous latent space for controllable generation. The latent vectors can be used to steer generation toward particular styles, topics, or behaviors, and can be sampled or interpolated for diverse outputs.
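For instance, interpolating between two latent vectors yields a sweep of conditioning signals. The snippet below is plain tensor arithmetic; the latent dimension of 256 and the commented-out generate call are assumptions based on the API documented here:

```python
import torch

# two latents, e.g. obtained by encoding two reference sequences
z_a = torch.randn(1, 256)
z_b = torch.randn(1, 256)

# linear interpolation sweep between whatever the two latents encode
alphas = torch.linspace(0., 1., 5)
interpolated = [torch.lerp(z_a, z_b, alpha) for alpha in alphas]

# each interpolated latent could then condition generation, e.g.
# model.generate(prompt, seq_len = 100, latents = z)
```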
## Code Reference

### Source Location
- Repository: Lucidrains_X_transformers
- File: x_transformers/gpt_vae.py
- Lines: 32-221
### Signature
```python
class GPTVAE(Module):
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        depth,
        enc_depth,
        max_seq_len,
        dim_latent = None,
        attn_dim_head = 64,
        heads = 8,
        enc_kwargs: dict = dict(),
        dec_kwargs: dict = dict(),
        vae_kl_loss_weight = 1.,
        vae_kl_div_floor = 0.,
        latents_dropout_prob = 0.5,
        pad_id = -1,
        encoder: Module | None = None,
        **kwargs
    ):
        """
        Args:
            num_tokens: Vocabulary size.
            dim: Model dimension.
            depth: Decoder depth.
            enc_depth: Encoder depth.
            max_seq_len: Maximum sequence length.
            dim_latent: Latent dimension (defaults to dim).
            attn_dim_head: Attention head dimension.
            heads: Number of attention heads.
            vae_kl_loss_weight: Weight for KL divergence loss.
            vae_kl_div_floor: KL divergence floor (Kingma et al., 2016).
            latents_dropout_prob: Probability of dropping latents entirely during training.
            pad_id: Token ID to ignore in loss computation.
            encoder: Optional custom encoder module.
        """
```
### Import

```python
from x_transformers.gpt_vae import GPTVAE
```
## I/O Contract

### Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| seq | Tensor (b, n) | Yes | Input token sequence |
| seq_for_latents | Tensor (b, m) | No | Separate sequence for encoding latents (defaults to seq) |
| return_all_losses | bool | No | Return breakdown of AR and KL losses |
### Outputs
| Name | Type | Description |
|---|---|---|
| forward() returns | Tensor (scalar) | Total loss (AR loss + KL loss * weight) |
| forward() with return_all_losses | (Tensor, (Tensor, Tensor)) | Total loss, plus (ar_loss, vae_kl_loss) |
| generate() returns | Tensor (b, n) | Generated token sequence conditioned on latent |
| encode_to_latents() returns | Tensor (b, d) | Sampled latent vector (optionally with mean and log_var) |
## Usage Examples

### Training
```python
import torch
from x_transformers.gpt_vae import GPTVAE

model = GPTVAE(
    num_tokens=256,
    dim=256,
    depth=6,
    enc_depth=2,
    max_seq_len=512,
    vae_kl_loss_weight=1.0,
    vae_kl_div_floor=0.1,
    latents_dropout_prob=0.5
)

seq = torch.randint(0, 256, (4, 128))
loss = model(seq)
loss.backward()
```
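The loss above plugs into a standard optimizer loop. The sketch below uses a stand-in module with the same contract as GPTVAE's forward (token sequence in, scalar loss out) so it runs without the library; in practice you would swap in the GPTVAE instance from the example above:

```python
import torch
from torch import nn

class StandInModel(nn.Module):
    # placeholder honoring the documented contract: forward(seq) -> scalar loss
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(256, 16)

    def forward(self, seq):
        return self.embed(seq).pow(2).mean()

model = StandInModel()                     # with the library: model = GPTVAE(...)
optim = torch.optim.Adam(model.parameters(), lr = 3e-4)

for step in range(3):
    seq = torch.randint(0, 256, (4, 128))
    loss = model(seq)
    loss.backward()
    optim.step()
    optim.zero_grad()
```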
### Latent-Conditioned Generation
```python
# generate conditioned on a specific latent vector
prompt = torch.randint(0, 256, (1, 5))
latent = torch.randn(1, 256)  # sample or specify a latent (dim_latent defaults to dim)
generated = model.generate(prompt, seq_len=100, latents=latent)
```
### Encode-then-Generate
```python
# encode a reference sequence, then generate from its latent
reference = torch.randint(0, 256, (1, 100))
prompt = torch.randint(0, 256, (1, 5))
generated = model.generate(prompt, seq_len=100, seq_for_latents=reference)
```