Implementation: Lucidrains x-transformers TransformerWrapper Decoder Init
Metadata
| Field | Value |
|---|---|
| Source | Repo: x-transformers |
| Domains | NLP, Model_Architecture |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
A concrete tool from the x-transformers library for configuring causal, decoder-only transformer models.
Description
TransformerWrapper is the main entry point in the x-transformers library for assembling a complete transformer model. It combines four concerns into a single torch.nn.Module:
- Token embeddings -- Maps integer token IDs to dense vectors of dimension `emb_dim` (which defaults to the model dimension `dim`). Supports optional L2 normalization of embeddings and fractional gradient scaling.
- Positional encodings -- Adds absolute positional information to the token embeddings. By default, learned absolute positional embeddings are used; alternatively, scaled sinusoidal embeddings can be enabled with `scaled_sinu_pos_emb=True`. When the underlying `AttentionLayers` uses rotary or ALiBi position encoding, absolute position embeddings are automatically disabled.
- Attention layers -- The core transformer stack, provided as an `AttentionLayers` (or subclass) instance via the `attn_layers` parameter. This is where the actual multi-head attention and feedforward computations occur.
- Output projection -- A linear layer that projects the final hidden states back to vocabulary-sized logits. Supports weight tying with the token embedding layer via `tie_embedding=True`.
When paired with Decoder -- which is simply AttentionLayers with causal=True hardcoded -- the resulting module is a GPT-style autoregressive language model that accepts token ID sequences and returns next-token logits.
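The sketch below illustrates this pairing end to end; the dimensions are illustrative only, and the output shape follows the I/O contract documented later in this section.

```python
import torch
from x_transformers import TransformerWrapper, Decoder

# Minimal sketch (illustrative sizes): a small causal decoder-only LM.
model = TransformerWrapper(
    num_tokens = 256,        # vocabulary size
    max_seq_len = 128,       # maximum supported sequence length
    attn_layers = Decoder(
        dim = 128,
        depth = 2,
        heads = 4
    )
)

tokens = torch.randint(0, 256, (2, 64))  # (batch, seq_len) integer token IDs
logits = model(tokens)                   # (batch, seq_len, num_tokens) next-token logits
```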
The wrapper also handles a variety of optional features:
- Memory tokens (`num_memory_tokens`) -- Learnable tokens prepended to the sequence, acting as persistent memory (similar to the approach in Memory Transformers).
- Recycling (`recycling=True`) -- Re-feeds the output embeddings back through the model for multiple passes, inspired by AlphaFold2.
- Mixture of Softmax (`mixture_of_softmax=True`) -- Replaces the single softmax output layer with a mixture of `k` softmax distributions for increased expressiveness.
- CLS token pooling (`use_cls_token=True`) -- Prepends a learnable CLS token for classification tasks.
- Attention pooling (`attn_pool=True`) -- Cross-attention-based pooling over the sequence for producing fixed-size representations.
- Multiple output heads (`num_output_heads`) -- Produces multiple independent logit heads from the same backbone.
- Embedding dropout (`emb_dropout`) -- Dropout applied after the embedding + positional encoding sum.
- Deep feedforward embedding (`ff_deep_embed=True`) -- Passes the token embeddings through a feedforward layer before entering the attention stack.
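Several of these features are plain keyword arguments on the constructor. The sketch below enables a few of them; the specific values are illustrative, not recommendations.

```python
from x_transformers import TransformerWrapper, Decoder

# Sketch: enabling a few of the optional features listed above.
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    num_memory_tokens = 16,  # learnable memory tokens prepended to the sequence
    tie_embedding = True,    # share weights between token embedding and output projection
    emb_dropout = 0.1,       # dropout after embedding + positional encoding
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)
```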
Usage
Import TransformerWrapper and Decoder when building an autoregressive decoder-only language model from scratch. Configure model size via:
- `num_tokens` -- Vocabulary size (required).
- `max_seq_len` -- Maximum sequence length the model supports (required).
- `attn_layers` -- A `Decoder(dim=..., depth=..., heads=...)` instance that defines the transformer stack (required).
All other parameters are optional and provide fine-grained control over embeddings, output heads, and advanced features.
Code Reference
Repository
| Field | Value |
|---|---|
| Repository | x-transformers |
| File | x_transformers/x_transformers.py |
| Lines | L3266-3308 (`TransformerWrapper.__init__`), L3095-3098 (`Decoder`), L2226-2303 (`AttentionLayers.__init__`) |
Import
from x_transformers import TransformerWrapper, Decoder
TransformerWrapper.__init__ Signature
class TransformerWrapper(Module):
def __init__(
self,
*,
num_tokens,
max_seq_len,
attn_layers: AttentionLayers,
embed_num_tokens: dict[str, int] = dict(),
emb_dim = None,
max_mem_len = 0,
shift_mem_down = 0,
emb_dropout = 0.,
post_emb_norm = False,
num_memory_tokens = None,
memory_tokens_interspersed_every = None,
tie_embedding = False,
logits_dim = None,
return_only_embed = False,
num_output_heads = 1,
use_abs_pos_emb = True,
scaled_sinu_pos_emb = False,
l2norm_embed = False,
recycling = False,
train_max_recycle_steps = 4,
emb_frac_gradient = 1.,
attn_z_loss_weight = 1e-4,
average_pool_embed = False,
use_cls_token = False,
num_cls_tokens = 1,
attn_pool = False,
num_pooled_tokens = 1,
attn_pool_depth = 1,
dim_pooled_tokens = None,
squeeze_out_last_dim = False,
token_emb: TokenEmbedding | None = None,
mixture_of_softmax = False,
mixture_of_softmax_k = 4,
sigsoftmax_logits = False,
ff_deep_embed = False,
to_logits: Module | None = None,
add_continuous_pred_head = False,
input_not_include_cache = False
):
Decoder Class
class Decoder(AttentionLayers):
def __init__(self, **kwargs):
assert 'causal' not in kwargs, 'cannot set causality on decoder'
super().__init__(causal = True, **kwargs)
The Decoder class is a minimal subclass of AttentionLayers that enforces causal=True. It raises an assertion error if the user attempts to explicitly pass causal as a keyword argument, since the decoder is always causal by definition.
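A small sketch of that guard in action, based on the assertion shown above:

```python
from x_transformers import Decoder

# Decoder hardcodes causal=True; passing `causal` explicitly trips the assertion.
try:
    Decoder(dim = 512, depth = 6, heads = 8, causal = True)
except AssertionError as err:
    print(err)  # cannot set causality on decoder

# Correct usage: omit `causal` entirely.
decoder = Decoder(dim = 512, depth = 6, heads = 8)
```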
AttentionLayers.__init__ Key Parameters
The following are the primary parameters accepted by AttentionLayers.__init__ (and therefore by Decoder):
class AttentionLayers(Module):
def __init__(
self,
dim, # model dimension (required)
depth = None, # number of transformer layers
heads = 8, # number of attention heads
causal = False, # causal masking (set True by Decoder)
cross_attend = False, # enable cross-attention layers
only_cross = False, # use only cross-attention (no self-attention)
use_scalenorm = False, # use ScaleNorm instead of LayerNorm
use_rmsnorm = False, # use RMSNorm instead of LayerNorm
use_dynamic_tanh = False, # use dynamic tanh normalization
use_simple_rmsnorm = False, # use simplified RMSNorm
use_adaptive_layernorm = False, # adaptive layer norm (e.g., DiT-style)
use_adaptive_rmsnorm = False, # adaptive RMSNorm
use_adaptive_layerscale = False, # adaptive layer scale (ada-ln-zero)
norm_add_unit_offset = True, # add unit offset to norm
dim_condition = None, # dimension for conditioning signal
alibi_pos_bias = False, # use ALiBi positional bias
alibi_num_heads = None, # number of heads for ALiBi
rel_pos_bias = False, # use T5-style relative position bias
rel_pos_num_buckets = 32, # number of relative position buckets
rel_pos_max_distance = 128, # max distance for relative position
dynamic_pos_bias = False, # use dynamic positional bias
rotary_pos_emb = False, # use Rotary Position Embeddings (RoPE)
rotary_emb_dim = None, # dimension for rotary embeddings
rotary_xpos = False, # use xPos rotary variant
rotary_interpolation_factor = 1., # RoPE interpolation factor
rotary_xpos_scale_base = 512, # xPos scale base
rotary_base_rescale_factor = 1., # RoPE base rescale factor
weight_tie_layers = False, # tie weights across layers
custom_layers = None, # custom layer type specification
sandwich_coef = None, # sandwich transformer coefficient
par_ratio = None, # parallel layers ratio
residual_attn = False, # residual attention connections
cross_residual_attn = False, # cross-attention residual connections
macaron = False, # macaron-style FFN (half before, half after attn)
pre_norm = True, # pre-norm (True) vs post-norm (False)
pre_norm_has_final_norm = True, # add final norm after last layer
gate_residual = False, # gated residual connections
scale_residual = False, # scaled residual connections
scale_residual_constant = 1., # residual scaling constant
shift_tokens = 0, # number of tokens to shift for token shifting
sandwich_norm = False, # sandwich normalization
zero_init_branch_output = False, # zero-init output projections
layer_dropout = 0., # stochastic depth / layer dropout
cross_attn_tokens_dropout = 0., # dropout on cross-attention tokens
# ... additional advanced parameters
):
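As an illustrative sketch only, several of these options can be combined when building the Decoder handed to TransformerWrapper; the particular combination below is arbitrary rather than a recommended recipe.

```python
from x_transformers import TransformerWrapper, Decoder

# Sketch: a decoder stack combining a few of the AttentionLayers options above.
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 8,
        heads = 8,
        rel_pos_bias = True,   # T5-style relative position bias
        macaron = True,        # macaron-style FFN halves around attention
        sandwich_norm = True,  # sandwich normalization
        layer_dropout = 0.1    # stochastic depth over layers
    )
)
```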
I/O Contract
Inputs
| Parameter | Type | Required | Description |
|---|---|---|---|
| `num_tokens` | int | Yes | Size of the token vocabulary. Determines the embedding table size and the output logits dimension (unless overridden by `logits_dim`). |
| `max_seq_len` | int | Yes | Maximum sequence length the model can process. Determines the size of the absolute positional embedding table (if used). |
| `attn_layers` | `AttentionLayers` / `Decoder` | Yes | The transformer layer stack. For causal language modeling, pass a `Decoder(...)` instance. |
| all others | various | No | Optional configuration parameters for embeddings, output heads, memory, pooling, and advanced features (see full signature above). |
Outputs
The constructed TransformerWrapper instance is a torch.nn.Module. When called in forward mode:
| Input | Type | Description |
|---|---|---|
| `x` | `torch.LongTensor` of shape `(batch, seq_len)` | Integer token IDs. |

| Output | Type | Description |
|---|---|---|
| `logits` | `torch.FloatTensor` of shape `(batch, seq_len, num_tokens)` | Next-token prediction logits over the vocabulary at each position. |
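A short sketch of this contract, assuming a basic decoder configuration like those in the examples below:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

x = torch.randint(0, 256, (4, 1024))    # (batch, seq_len) token IDs
logits = model(x)
assert logits.shape == (4, 1024, 256)   # (batch, seq_len, num_tokens)
```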
Usage Examples
Basic Causal Decoder (enwik8 character-level model)
This example is derived from the train_enwik8.py training script included in the x-transformers repository:
from x_transformers import TransformerWrapper, Decoder
model = TransformerWrapper(
num_tokens = 256,
max_seq_len = 1024,
attn_layers = Decoder(
dim = 512,
depth = 6,
heads = 8,
rotary_pos_emb = True
)
)
This creates a character-level autoregressive model with:
- 256 tokens (byte-level vocabulary).
- 1024 maximum sequence length.
- 512-dimensional hidden states.
- 6 transformer layers.
- 8 attention heads (64 dimensions per head).
- Rotary position embeddings for relative position awareness.
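As a hedged sketch, a single next-token training step for this model can be written by shifting the sequence one position and applying cross-entropy; the library also ships an AutoregressiveWrapper that packages this pattern, but the manual version below makes the mechanics explicit. The shapes assume the model defined just above.

```python
import torch
import torch.nn.functional as F

# `model` is the TransformerWrapper defined in the example above.
seq = torch.randint(0, 256, (8, 1025))   # byte sequences, one token longer than the context
inp, target = seq[:, :-1], seq[:, 1:]    # shift by one for next-token prediction

logits = model(inp)                      # (batch, 1024, 256)
loss = F.cross_entropy(
    logits.transpose(1, 2),              # cross_entropy expects (batch, classes, positions)
    target
)
loss.backward()
```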
Larger Model Configuration
A larger configuration suitable for subword-tokenized language modeling:
from x_transformers import TransformerWrapper, Decoder
model = TransformerWrapper(
num_tokens = 50304, # vocabulary size (padded to multiple of 64)
max_seq_len = 2048, # context window
tie_embedding = True, # tie input/output embeddings to save parameters
emb_dropout = 0.1, # embedding dropout
attn_layers = Decoder(
dim = 2048, # model dimension
depth = 24, # 24 transformer layers
heads = 16, # 16 attention heads (128 dim per head)
rotary_pos_emb = True, # rotary position embeddings
use_rmsnorm = True, # RMSNorm instead of LayerNorm
ff_mult = 4, # feedforward hidden dim = 4 * dim
attn_flash = True, # use Flash Attention for efficiency
)
)
This creates a model with approximately the scale of a 1B-parameter language model, using:
- RMSNorm for faster and more stable normalization (as in LLaMA).
- Tied embeddings to reduce parameter count.
- Flash Attention for memory-efficient and faster attention computation.
- Rotary position embeddings for length generalization.
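A quick sketch for sanity-checking the size of such a configuration (the ~1B figure above is approximate):

```python
# `model` is the TransformerWrapper defined in the example above.
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'{num_params / 1e9:.2f}B trainable parameters')
```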
Configuration with ALiBi Position Bias
from x_transformers import TransformerWrapper, Decoder
model = TransformerWrapper(
num_tokens = 32000,
max_seq_len = 4096,
use_abs_pos_emb = False, # disable absolute position embeddings
attn_layers = Decoder(
dim = 1024,
depth = 16,
heads = 16,
alibi_pos_bias = True, # use ALiBi instead of rotary
alibi_num_heads = 8, # ALiBi applied to 8 of 16 heads
)
)
When using ALiBi, absolute positional embeddings should be disabled since ALiBi encodes position information directly in the attention bias. The alibi_num_heads parameter allows applying ALiBi to a subset of heads, with the remaining heads using no positional encoding.