

Implementation:Lucidrains X transformers TransformerWrapper Decoder Init

From Leeroopedia


Metadata

Field Value
Source Repo: x-transformers
Domains NLP, Model_Architecture
Last Updated 2026-02-08 18:00 GMT

Overview

Concrete tool for configuring causal decoder-only transformer models provided by the x-transformers library.

Description

TransformerWrapper is the main entry point in the x-transformers library for assembling a complete transformer model. It combines four concerns into a single torch.nn.Module:

  1. Token embeddings -- Maps integer token IDs to dense vectors of dimension emb_dim (which defaults to the model dimension dim). Optionally L2-normalizes the embeddings (l2norm_embed=True) and scales down the gradient flowing back into them (emb_frac_gradient).
  2. Positional encodings -- Adds absolute positional information to the token embeddings. By default, learned absolute positional embeddings are used. Alternatively, scaled sinusoidal embeddings can be enabled with scaled_sinu_pos_emb=True. When the underlying AttentionLayers uses rotary position embeddings, absolute position embeddings are disabled automatically; they can also be switched off explicitly with use_abs_pos_emb=False, as the ALiBi example below does.
  3. Attention layers -- The core transformer stack, provided as an AttentionLayers (or subclass) instance via the attn_layers parameter. This is where the actual multi-head attention and feedforward computations occur.
  4. Output projection -- A linear layer that projects the final hidden states back to vocabulary-sized logits. Supports weight tying with the token embedding layer via tie_embedding=True.
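The four stages above can be illustrated with a toy, dependency-free sketch. It assumes a single unbatched sequence, learned absolute positional embeddings, an identity function standing in for the attention stack, and a tied output projection (the tie_embedding=True case). The names make_table, toy_forward and identity are hypothetical; x-transformers performs these steps on batched tensors inside TransformerWrapper.

```python
import random

def make_table(rows, cols, rng):
    # a small random matrix standing in for a learned embedding table
    return [[rng.gauss(0.0, 0.02) for _ in range(cols)] for _ in range(rows)]

def toy_forward(token_ids, token_emb, pos_emb, layers):
    # 1. token embeddings: one row of the table per token id
    h = [list(token_emb[t]) for t in token_ids]
    # 2. absolute positional encodings: add the row for each position
    h = [[a + b for a, b in zip(row, pos_emb[i])] for i, row in enumerate(h)]
    # 3. attention layers: any callables that map a sequence of vectors to a sequence
    for layer in layers:
        h = layer(h)
    # 4. tied output projection: score each hidden state against every embedding row
    return [[sum(a * b for a, b in zip(row, emb)) for emb in token_emb] for row in h]

def identity(h):
    # hypothetical placeholder for the Decoder stack
    return h

rng = random.Random(0)
num_tokens, max_seq_len, dim = 10, 8, 4
token_emb = make_table(num_tokens, dim, rng)
pos_emb = make_table(max_seq_len, dim, rng)
logits = toy_forward([1, 5, 3], token_emb, pos_emb, [identity])
print(len(logits), len(logits[0]))  # 3 10
```

With tying, the same num_tokens x dim table both embeds the inputs and scores the outputs, which is why the logits have one column per vocabulary entry.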

When paired with Decoder -- which is simply AttentionLayers with causal=True hardcoded -- the resulting module is a GPT-style autoregressive language model that accepts token ID sequences and returns next-token logits.
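Causal masking means that position i may only attend to positions j <= i, so the logits at each position depend only on earlier tokens. A minimal pure-Python sketch for a single head, assuming the usual approach of setting disallowed scores to negative infinity before the softmax; causal_mask and masked_softmax are hypothetical helpers, not library functions.

```python
import math

def causal_mask(seq_len):
    # mask[i][j] is True when query position i may attend to key position j, i.e. j <= i
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]

def masked_softmax(scores, allowed):
    # disallowed positions get -inf before the softmax, so their weight is exactly 0
    masked = [s if ok else float("-inf") for s, ok in zip(scores, allowed)]
    m = max(masked)
    exps = [math.exp(s - m) for s in masked]
    total = sum(exps)
    return [e / total for e in exps]

mask = causal_mask(4)
for row in mask:
    print(["x" if ok else "." for ok in row])

# position 1 spreads its attention over positions 0 and 1 only,
# even though later positions have much higher raw scores
print(masked_softmax([0.0, 0.0, 5.0, 5.0], mask[1]))  # [0.5, 0.5, 0.0, 0.0]
```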

The wrapper also handles a variety of optional features:

  • Memory tokens (num_memory_tokens) -- Learnable tokens prepended to the sequence, acting as persistent memory (similar to the approach in Memory Transformers).
  • Recycling (recycling=True) -- Re-feeds the output embeddings back through the model for multiple passes, inspired by AlphaFold2.
  • Mixture of Softmax (mixture_of_softmax=True) -- Replaces the single softmax output layer with a mixture of k softmax distributions for increased expressiveness.
  • CLS token pooling (use_cls_token=True) -- Prepends a learnable CLS token for classification tasks.
  • Attention pooling (attn_pool=True) -- Cross-attention-based pooling over the sequence for producing fixed-size representations.
  • Multiple output heads (num_output_heads) -- Produces multiple independent logit heads from the same backbone.
  • Embedding dropout (emb_dropout) -- Dropout applied after the embedding + positional encoding sum.
  • Deep feedforward embedding (ff_deep_embed=True) -- Passes the token embeddings through a feedforward layer before entering the attention stack.
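The Mixture of Softmax output can be sketched as a weighted average of k softmax distributions over the vocabulary. This is a generic sketch of the technique as formulated in "Breaking the Softmax Bottleneck", not the x-transformers code: how the library derives the k component logits and the mixture weights from the final hidden state is not shown, and the function names are hypothetical.

```python
import math

def softmax(xs):
    # numerically stable softmax over a list of scores
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def mixture_of_softmax(component_logits, mixture_logits):
    # component_logits: k lists of vocabulary-sized logits, one per mixture component
    # mixture_logits:   k scores that become the mixture weights pi_1..pi_k
    weights = softmax(mixture_logits)
    probs = [softmax(logits) for logits in component_logits]
    vocab = len(component_logits[0])
    # p(token) = sum_k pi_k * softmax(logits_k)[token]
    return [sum(w * p[v] for w, p in zip(weights, probs)) for v in range(vocab)]

k_logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 2.0]]  # k = 2 components, vocabulary of 3
p = mixture_of_softmax(k_logits, [0.0, 0.0])   # equal mixture weights
print(round(sum(p), 6))  # 1.0
```

Each component is a valid distribution and the mixture weights sum to 1, so the mixture is also a valid distribution.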

Usage

Import TransformerWrapper and Decoder when building an autoregressive decoder-only language model from scratch. Configure model size via:

  • num_tokens -- Vocabulary size (required).
  • max_seq_len -- Maximum sequence length the model supports (required).
  • attn_layers -- A Decoder(dim=..., depth=..., heads=...) instance that defines the transformer stack (required).

All other parameters are optional and provide fine-grained control over embeddings, output heads, and advanced features.

Code Reference

Repository

Field Value
Repository x-transformers
File x_transformers/x_transformers.py
Lines L3266-3308 (TransformerWrapper.__init__), L3095-3098 (Decoder), L2226-2303 (AttentionLayers.__init__)

Import

from x_transformers import TransformerWrapper, Decoder

TransformerWrapper.__init__ Signature

class TransformerWrapper(Module):
    def __init__(
        self,
        *,
        num_tokens,
        max_seq_len,
        attn_layers: AttentionLayers,
        embed_num_tokens: dict[str, int] = dict(),
        emb_dim = None,
        max_mem_len = 0,
        shift_mem_down = 0,
        emb_dropout = 0.,
        post_emb_norm = False,
        num_memory_tokens = None,
        memory_tokens_interspersed_every = None,
        tie_embedding = False,
        logits_dim = None,
        return_only_embed = False,
        num_output_heads = 1,
        use_abs_pos_emb = True,
        scaled_sinu_pos_emb = False,
        l2norm_embed = False,
        recycling = False,
        train_max_recycle_steps = 4,
        emb_frac_gradient = 1.,
        attn_z_loss_weight = 1e-4,
        average_pool_embed = False,
        use_cls_token = False,
        num_cls_tokens = 1,
        attn_pool = False,
        num_pooled_tokens = 1,
        attn_pool_depth = 1,
        dim_pooled_tokens = None,
        squeeze_out_last_dim = False,
        token_emb: TokenEmbedding | None = None,
        mixture_of_softmax = False,
        mixture_of_softmax_k = 4,
        sigsoftmax_logits = False,
        ff_deep_embed = False,
        to_logits: Module | None = None,
        add_continuous_pred_head = False,
        input_not_include_cache = False
    ):
        ...  # constructor body omitted

Decoder Class

class Decoder(AttentionLayers):
    def __init__(self, **kwargs):
        assert 'causal' not in kwargs, 'cannot set causality on decoder'
        super().__init__(causal = True, **kwargs)

The Decoder class is a minimal subclass of AttentionLayers that enforces causal=True. It raises an assertion error if the user attempts to explicitly pass causal as a keyword argument, since the decoder is always causal by definition.
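The guard can be exercised with a self-contained replica of the pattern. AttentionLayersStub and DecoderStub are hypothetical stand-ins that only record the causal flag; they mirror the Decoder code shown above but perform no attention computation.

```python
class AttentionLayersStub:
    # hypothetical stand-in for AttentionLayers: stores the causal flag and other kwargs
    def __init__(self, dim, causal=False, **kwargs):
        self.dim = dim
        self.causal = causal
        self.extra = kwargs

class DecoderStub(AttentionLayersStub):
    # same pattern as Decoder: the caller may not choose causality
    def __init__(self, **kwargs):
        assert "causal" not in kwargs, "cannot set causality on decoder"
        super().__init__(causal=True, **kwargs)

dec = DecoderStub(dim=512, depth=6)
print(dec.causal)  # True

try:
    DecoderStub(dim=512, causal=False)
except AssertionError as err:
    print("rejected:", err)  # rejected: cannot set causality on decoder
```

Even if assertions are stripped (python -O), passing causal would still fail, because super().__init__ would then receive causal twice and raise a TypeError.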

AttentionLayers.__init__ Key Parameters

The following are the primary parameters accepted by AttentionLayers.__init__ (and therefore by Decoder):

class AttentionLayers(Module):
    def __init__(
        self,
        dim,                                    # model dimension (required)
        depth = None,                           # number of transformer layers
        heads = 8,                              # number of attention heads
        causal = False,                         # causal masking (set True by Decoder)
        cross_attend = False,                   # enable cross-attention layers
        only_cross = False,                     # use only cross-attention (no self-attention)
        use_scalenorm = False,                  # use ScaleNorm instead of LayerNorm
        use_rmsnorm = False,                    # use RMSNorm instead of LayerNorm
        use_dynamic_tanh = False,               # use dynamic tanh normalization
        use_simple_rmsnorm = False,             # use simplified RMSNorm
        use_adaptive_layernorm = False,         # adaptive layer norm (e.g., DiT-style)
        use_adaptive_rmsnorm = False,           # adaptive RMSNorm
        use_adaptive_layerscale = False,        # adaptive layer scale (ada-ln-zero)
        norm_add_unit_offset = True,            # add unit offset to norm
        dim_condition = None,                   # dimension for conditioning signal
        alibi_pos_bias = False,                 # use ALiBi positional bias
        alibi_num_heads = None,                 # number of heads for ALiBi
        rel_pos_bias = False,                   # use T5-style relative position bias
        rel_pos_num_buckets = 32,               # number of relative position buckets
        rel_pos_max_distance = 128,             # max distance for relative position
        dynamic_pos_bias = False,               # use dynamic positional bias
        rotary_pos_emb = False,                 # use Rotary Position Embeddings (RoPE)
        rotary_emb_dim = None,                  # dimension for rotary embeddings
        rotary_xpos = False,                    # use xPos rotary variant
        rotary_interpolation_factor = 1.,       # RoPE interpolation factor
        rotary_xpos_scale_base = 512,           # xPos scale base
        rotary_base_rescale_factor = 1.,        # RoPE base rescale factor
        weight_tie_layers = False,              # tie weights across layers
        custom_layers = None,                   # custom layer type specification
        sandwich_coef = None,                   # sandwich transformer coefficient
        par_ratio = None,                       # PAR ("Pay Attention when Required") layer-pattern ratio
        residual_attn = False,                  # residual attention scores across layers (RealFormer)
        cross_residual_attn = False,            # residual attention scores for cross-attention
        macaron = False,                        # macaron-style FFN (half before, half after attn)
        pre_norm = True,                        # pre-norm (True) vs post-norm (False)
        pre_norm_has_final_norm = True,         # add final norm after last layer
        gate_residual = False,                  # gated residual connections
        scale_residual = False,                 # scaled residual connections
        scale_residual_constant = 1.,           # residual scaling constant
        shift_tokens = 0,                       # number of tokens to shift for token shifting
        sandwich_norm = False,                  # sandwich normalization
        zero_init_branch_output = False,        # zero-init output projections
        layer_dropout = 0.,                     # stochastic depth / layer dropout
        cross_attn_tokens_dropout = 0.,         # dropout on cross-attention tokens
        # ... additional advanced parameters
    ):
        ...  # constructor body omitted
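Several of the flags above swap the normalization layer. As one example, RMSNorm (use_rmsnorm=True) rescales each feature vector by its root-mean-square instead of centring it and dividing by the standard deviation as LayerNorm does. A minimal pure-Python sketch of both, assuming a plain per-feature gain; x-transformers' own RMSNorm may parameterize the gain differently (for example via norm_add_unit_offset), and the helper names are hypothetical.

```python
import math

def rms_norm(x, gain, eps=1e-8):
    # divide by the root-mean-square of the features (no mean subtraction, no bias)
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    # then apply a learned per-feature gain
    return [g * v / rms for g, v in zip(gain, x)]

def layer_norm(x, gain, bias, eps=1e-5):
    # LayerNorm, for contrast: centre on the mean, divide by the std, then add a bias
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [g * (v - mean) / math.sqrt(var + eps) + b for g, v, b in zip(gain, x, bias)]

x = [1.0, 2.0, 3.0, 4.0]
y = rms_norm(x, gain=[1.0] * 4)
print(math.sqrt(sum(v * v for v in y) / len(y)))  # approximately 1.0
```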

I/O Contract

Inputs

Parameter Type Required Description
num_tokens int Yes Size of the token vocabulary. Determines the embedding table size and the output logits dimension (unless overridden by logits_dim).
max_seq_len int Yes Maximum sequence length the model can process. Determines the size of the absolute positional embedding table (if used).
attn_layers AttentionLayers / Decoder Yes The transformer layer stack. For causal language modeling, pass a Decoder(...) instance.
all others various No Optional configuration parameters for embeddings, output heads, memory, pooling, and advanced features (see full signature above).

Outputs

The constructed TransformerWrapper instance is a torch.nn.Module. Its forward pass has the following contract:

Input Type Description
x torch.LongTensor of shape (batch, seq_len) Integer token IDs.
Output Type Description
logits torch.FloatTensor of shape (batch, seq_len, num_tokens) Next-token prediction logits over the vocabulary at each position (the last dimension is logits_dim instead, if that parameter is set).
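The logits at position t are the model's prediction for the token at position t + 1. A minimal pure-Python sketch of how such logits are typically scored in training, assuming one unbatched sequence and mean cross-entropy; next_token_loss is a hypothetical helper. x-transformers also provides an AutoregressiveWrapper for training loops, which this sketch does not reproduce.

```python
import math

def next_token_loss(logits, token_ids):
    # logits[t] scores the vocabulary for the token that follows position t,
    # so position t is compared with token_ids[t + 1]; the last position has no target
    total = 0.0
    for t in range(len(token_ids) - 1):
        row = logits[t]
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))  # log-sum-exp
        total += log_z - row[token_ids[t + 1]]                    # -log p(next token)
    return total / (len(token_ids) - 1)

vocab = 5
ids = [0, 3, 1, 4]
uniform = [[0.0] * vocab for _ in ids]  # a model with no preference at all
print(next_token_loss(uniform, ids), math.log(vocab))  # both approximately 1.6094
```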

Usage Examples

Basic Causal Decoder (enwik8 character-level model)

This example is derived from the train_enwik8.py training script included in the x-transformers repository:

from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        rotary_pos_emb = True
    )
)

This creates a character-level autoregressive model with:

  • 256 tokens (byte-level vocabulary).
  • 1024 maximum sequence length.
  • 512-dimensional hidden states.
  • 6 transformer layers.
  • 8 attention heads with 64 dimensions per head (the default attn_dim_head; 8 x 64 = 512 matches dim).
  • Rotary position embeddings for relative position awareness.
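With num_tokens = 256, each token ID can simply be one byte. A sketch of byte-level encoding and decoding, assuming UTF-8 text; encode_bytes and decode_bytes are hypothetical helpers, not part of x-transformers or the enwik8 training script.

```python
def encode_bytes(text):
    # every UTF-8 byte is an integer in [0, 255], i.e. a valid id when num_tokens = 256
    return list(text.encode("utf-8"))

def decode_bytes(ids):
    # errors="replace" keeps decoding robust if a generated id sequence is not valid UTF-8
    return bytes(ids).decode("utf-8", errors="replace")

ids = encode_bytes("naïve café")
assert all(0 <= i < 256 for i in ids)
print(len("naïve café"), len(ids))  # 10 12
print(decode_bytes(ids))            # naïve café
```

Non-ASCII characters take several bytes, so a window of max_seq_len = 1024 IDs covers at most 1024 characters, and fewer for text with many multi-byte characters.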

Larger Model Configuration

A larger configuration suitable for subword-tokenized language modeling:

from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 50304,          # vocabulary size (padded to multiple of 64)
    max_seq_len = 2048,          # context window
    tie_embedding = True,        # tie input/output embeddings to save parameters
    emb_dropout = 0.1,           # embedding dropout
    attn_layers = Decoder(
        dim = 2048,              # model dimension
        depth = 24,              # 24 transformer layers
        heads = 16,              # 16 attention heads
        attn_dim_head = 128,     # 128 dims per head (library default is 64), so 16 * 128 = 2048
        rotary_pos_emb = True,   # rotary position embeddings
        use_rmsnorm = True,      # RMSNorm instead of LayerNorm
        ff_mult = 4,             # feedforward hidden dim = 4 * dim
        attn_flash = True,       # use Flash Attention for efficiency
    )
)

This creates a model on the order of one billion parameters, using:

  • RMSNorm, which drops LayerNorm's mean-centring and bias and is therefore slightly cheaper (as in LLaMA).
  • Tied embeddings, so the output projection reuses the token embedding matrix instead of adding a second 50304 x 2048 matrix.
  • Flash Attention for memory-efficient and faster attention computation.
  • Rotary position embeddings, which encode relative position directly in the query-key dot products.
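A back-of-the-envelope parameter count for this configuration, assuming standard multi-head attention (separate bias-free q, k, v and output projections), a plain two-layer feedforward without GLU, and ignoring biases, norm gains and other small terms. estimate_params is a hypothetical helper, not an x-transformers function; sum(p.numel() for p in model.parameters()) on the constructed module gives the authoritative figure.

```python
def estimate_params(num_tokens, dim, depth, heads, dim_head, ff_mult, tie_embedding):
    # token embedding table: one dim-sized vector per vocabulary entry
    embed = num_tokens * dim
    # attention per layer: q, k, v (dim -> heads * dim_head) plus output (heads * dim_head -> dim)
    attn = 4 * dim * heads * dim_head
    # feedforward per layer: dim -> ff_mult * dim -> dim
    ff = 2 * dim * ff_mult * dim
    # an untied model carries a separate dim -> num_tokens output projection
    to_logits = 0 if tie_embedding else num_tokens * dim
    return embed + depth * (attn + ff) + to_logits

for dim_head in (64, 128):
    n = estimate_params(50304, 2048, 24, 16, dim_head, 4, tie_embedding=True)
    print(dim_head, f"{n / 1e9:.2f}B")
```

Under these assumptions, 64-dimensional heads give about 1.11B parameters and 128-dimensional heads about 1.31B; untying the embeddings would add another 50304 x 2048 (about 0.10B) parameters.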

Configuration with ALiBi Position Bias

from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 32000,
    max_seq_len = 4096,
    use_abs_pos_emb = False,     # disable absolute position embeddings
    attn_layers = Decoder(
        dim = 1024,
        depth = 16,
        heads = 16,
        alibi_pos_bias = True,   # use ALiBi instead of rotary
        alibi_num_heads = 8,     # ALiBi applied to 8 of 16 heads
    )
)

When using ALiBi, absolute positional embeddings should be disabled since ALiBi encodes position information directly in the attention bias. The alibi_num_heads parameter allows applying ALiBi to a subset of heads, with the remaining heads using no positional encoding.
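The per-head bias can be sketched as follows, assuming the geometric slope schedule from the ALiBi paper (restricted to power-of-two head counts for brevity) and zero bias on the heads beyond alibi_num_heads, as described above. alibi_slopes and alibi_bias are hypothetical helpers; x-transformers computes this bias internally on tensors.

```python
def alibi_slopes(num_heads):
    # geometric sequence 2^(-8/n), 2^(-16/n), ...; this sketch only handles powers of two
    assert num_heads > 0 and num_heads & (num_heads - 1) == 0, "sketch assumes a power of two"
    return [2 ** (-8 * (i + 1) / num_heads) for i in range(num_heads)]

def alibi_bias(total_heads, alibi_heads, seq_len):
    # bias[h][i][j] is added to the attention score of query i for key j in head h;
    # with causal masking only the entries j <= i are ever used
    slopes = alibi_slopes(alibi_heads)
    biased = [[[slope * -abs(i - j) for j in range(seq_len)] for i in range(seq_len)]
              for slope in slopes]
    # the remaining heads receive no positional bias at all
    unbiased = [[[0.0] * seq_len for _ in range(seq_len)]
                for _ in range(total_heads - alibi_heads)]
    return biased + unbiased

bias = alibi_bias(total_heads=16, alibi_heads=8, seq_len=4)
print(alibi_slopes(8))  # [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]
print(bias[0][3])       # [-1.5, -1.0, -0.5, 0.0]
print(bias[15][3])      # [0.0, 0.0, 0.0, 0.0]
```

Heads with small slopes penalize distance only mildly, and the unbiased heads can attend to distant tokens without any distance penalty.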

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
