

Implementation:Lucidrains X transformers TransformerWrapper Encoder Init

From Leeroopedia


Metadata

Field Value
Source Repo x-transformers
Domains NLP, Model_Architecture
Last Updated 2026-02-08 18:00 GMT

Overview

A concrete tool, provided by the x-transformers library, for configuring bidirectional encoder transformer models for masked token prediction.

Description

Wrapping Encoder attention layers in TransformerWrapper produces a bidirectional transformer model suitable for masked token prediction and non-autoregressive generation. It is the same TransformerWrapper class used in the decoder configuration; the only change is that the attention layers are an Encoder, which sets causal=False.

The key differences from the causal decoder configuration are:

  • No causal mask -- The Encoder class forces causal=False, so every position attends to every other position in the sequence.
  • Vocabulary size must include the mask token -- Since the mask token is fed as input to the model, num_tokens must be large enough to include the mask token ID. For example, if the base vocabulary has 256 tokens and the mask token ID is 256, then num_tokens must be at least 257.
  • max_seq_len sets the fixed generation length -- When used with NonAutoregressiveWrapper, the max_seq_len parameter defines the fixed number of tokens produced during generation.
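
The vocabulary-size rule above is plain arithmetic: token IDs are 0-based, so an embedding table can only index the mask token if num_tokens >= mask_id + 1. The helpers below are hypothetical (they are not part of x-transformers); they are a minimal sketch of how one might compute the smallest valid num_tokens and reject an undersized value before building the model.

```python
def min_num_tokens(base_vocab_size: int, mask_token_id: int) -> int:
    """Smallest num_tokens that covers base IDs 0..base_vocab_size - 1 and the mask ID."""
    # The largest ID in use must be strictly below num_tokens.
    return max(base_vocab_size, mask_token_id + 1)


def check_num_tokens(num_tokens: int, base_vocab_size: int, mask_token_id: int) -> None:
    """Raise ValueError if num_tokens cannot accommodate the mask token."""
    required = min_num_tokens(base_vocab_size, mask_token_id)
    if num_tokens < required:
        raise ValueError(
            f"num_tokens={num_tokens} is too small: mask_token_id={mask_token_id} "
            f"requires num_tokens >= {required}"
        )


# The example from the text: 256 base tokens, mask token appended at ID 256.
print(min_num_tokens(256, 256))   # 257
check_num_tokens(257, 256, 256)   # passes silently
```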

The model takes token sequences (with mask tokens at positions to be predicted) and outputs logits over the vocabulary for each position. During training, the loss is computed only at masked positions. During inference, the model iteratively predicts and unmasks tokens in a MaskGIT-style refinement loop.
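
On the training side, the input is a copy of a real sequence with some positions replaced by the mask token, and the model is scored only at those positions. The sketch below shows that corruption step in plain Python. It assumes a MaskGIT-style cosine schedule for the masking ratio; this is an illustrative choice, not a transcription of NonAutoregressiveWrapper's source, and corrupt_for_training is a hypothetical name.

```python
import math
import random

MASK_TOKEN_ID = 256  # assumption: mask token appended after a 256-token base vocabulary


def corrupt_for_training(tokens, mask_token_id, rng):
    """Replace a random subset of positions with mask_token_id.

    Returns (corrupted_tokens, is_masked), where is_masked[i] is True for the
    positions whose original token the model must predict.
    """
    n = len(tokens)
    # Assumed cosine schedule: draw u ~ U(0, 1) and mask a cos(u * pi / 2)
    # fraction of the sequence, always masking at least one position.
    ratio = math.cos(rng.random() * math.pi / 2)
    num_masked = max(1, math.ceil(ratio * n))
    masked_positions = set(rng.sample(range(n), num_masked))
    corrupted = [mask_token_id if i in masked_positions else t for i, t in enumerate(tokens)]
    is_masked = [i in masked_positions for i in range(n)]
    return corrupted, is_masked


rng = random.Random(0)
seq = [rng.randrange(256) for _ in range(16)]  # toy sequence of base-vocabulary IDs
corrupted, is_masked = corrupt_for_training(seq, MASK_TOKEN_ID, rng)
```

The is_masked flags are what restrict the loss to masked positions; unmasked positions pass through unchanged and provide bidirectional context.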

The wrapper also supports the same optional features as the decoder configuration, including memory tokens, embedding dropout, tied embeddings, and more advanced options such as recycling and mixture of softmax.

Usage

Import TransformerWrapper and Encoder when building a non-autoregressive masked prediction model. Configure the model via:

  • num_tokens -- Vocabulary size, which must include the mask token (required).
  • max_seq_len -- Fixed sequence length for generation (required).
  • attn_layers -- An Encoder(dim=..., depth=..., heads=...) instance that defines the bidirectional transformer stack (required).

All other parameters are optional and provide fine-grained control over embeddings, output heads, and advanced features.

Code Reference

Repository

Field Value
Repository x-transformers
File x_transformers/x_transformers.py
Lines L3266-3308 (TransformerWrapper.__init__), L3090-3093 (Encoder)

Import

from x_transformers import TransformerWrapper, Encoder

Encoder Class

class Encoder(AttentionLayers):
    def __init__(self, **kwargs):
        assert 'causal' not in kwargs, 'cannot set causality on encoder'
        super().__init__(causal = False, **kwargs)

The Encoder class is a minimal subclass of AttentionLayers that enforces causal=False. It raises an assertion error if the user attempts to explicitly pass causal as a keyword argument, since the encoder is always non-causal by definition. This is the mirror image of the Decoder class, which enforces causal=True.
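
The subclassing pattern can be exercised without the library. In the sketch below, AttentionLayers is a hypothetical stub that only records its arguments, standing in for the real class. The Encoder body matches the excerpt above, and Decoder is written as its mirror image per the description, with its own assertion message as an assumption.

```python
class AttentionLayers:
    """Hypothetical stand-in for x_transformers.AttentionLayers: it only
    records its configuration so the subclassing pattern can run on its own."""

    def __init__(self, dim, depth, heads=8, causal=False, **kwargs):
        self.dim = dim
        self.depth = depth
        self.heads = heads
        self.causal = causal
        self.extra = kwargs


class Encoder(AttentionLayers):
    def __init__(self, **kwargs):
        # Causality is fixed by the subclass, so callers may not override it.
        assert 'causal' not in kwargs, 'cannot set causality on encoder'
        super().__init__(causal=False, **kwargs)


class Decoder(AttentionLayers):
    def __init__(self, **kwargs):
        # Mirror image of Encoder (message text is an assumption).
        assert 'causal' not in kwargs, 'cannot set causality on decoder'
        super().__init__(causal=True, **kwargs)


enc = Encoder(dim=512, depth=6, heads=8)
print(enc.causal)  # False

try:
    Encoder(dim=512, depth=6, causal=True)
except AssertionError as err:
    print(err)  # cannot set causality on encoder
```

Because the check is a plain assert, it is skipped when Python runs with -O. The pattern only guards against accidental misuse; it is not a hard validation layer.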

TransformerWrapper.__init__ Signature

The TransformerWrapper signature is the same as in the decoder configuration (see L3266-3308). A single class builds both encoder and decoder models; the difference lies entirely in the attn_layers argument:

class TransformerWrapper(Module):
    def __init__(
        self,
        *,
        num_tokens,
        max_seq_len,
        attn_layers: AttentionLayers,
        embed_num_tokens: dict[str, int] = dict(),
        emb_dim = None,
        max_mem_len = 0,
        shift_mem_down = 0,
        emb_dropout = 0.,
        post_emb_norm = False,
        num_memory_tokens = None,
        memory_tokens_interspersed_every = None,
        tie_embedding = False,
        logits_dim = None,
        return_only_embed = False,
        num_output_heads = 1,
        use_abs_pos_emb = True,
        scaled_sinu_pos_emb = False,
        l2norm_embed = False,
        recycling = False,
        train_max_recycle_steps = 4,
        emb_frac_gradient = 1.,
        attn_z_loss_weight = 1e-4,
        average_pool_embed = False,
        use_cls_token = False,
        num_cls_tokens = 1,
        attn_pool = False,
        num_pooled_tokens = 1,
        attn_pool_depth = 1,
        dim_pooled_tokens = None,
        squeeze_out_last_dim = False,
        token_emb: TokenEmbedding | None = None,
        mixture_of_softmax = False,
        mixture_of_softmax_k = 4,
        sigsoftmax_logits = False,
        ff_deep_embed = False,
        to_logits: Module | None = None,
        add_continuous_pred_head = False,
        input_not_include_cache = False
    ):

I/O Contract

Inputs

Parameter Type Required Description
num_tokens int Yes Vocabulary size (must include mask token). Determines the embedding table size and the output logits dimension.
max_seq_len int Yes Fixed sequence length for generation. Determines the size of the absolute positional embedding table (if used).
attn_layers Encoder Yes Encoder attention layers (causal=False). Pass an Encoder(dim=..., depth=..., heads=...) instance.

Outputs

The constructed TransformerWrapper instance is a torch.nn.Module. When called in forward mode, it takes:

Input Type Description
x torch.LongTensor of shape (batch, seq_len) Integer token IDs, potentially containing mask tokens at positions to be predicted.

and returns:

Output Type Description
logits torch.FloatTensor of shape (batch, seq_len, num_tokens) Prediction logits over the vocabulary at each position (the last dimension is logits_dim instead when that argument is set). For masked token prediction, the logits at masked positions are used to predict the original tokens.
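
To make "the logits at masked positions are used" concrete, the sketch below computes a mean cross-entropy over masked positions only, for a single sequence, in plain Python. Cross-entropy is assumed here as the usual choice; masked_cross_entropy is a hypothetical helper, not a library function.

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax over one vector of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_z for v in row]


def masked_cross_entropy(logits, targets, is_masked):
    """Mean negative log-likelihood over masked positions only.

    logits:    seq_len x num_tokens nested lists (one sequence of the batch)
    targets:   original token IDs, length seq_len
    is_masked: booleans, True where the input held the mask token
    """
    losses = [
        -log_softmax(row)[target]
        for row, target, masked in zip(logits, targets, is_masked)
        if masked  # unmasked positions contribute nothing to the loss
    ]
    if not losses:
        raise ValueError("no masked positions to score")
    return sum(losses) / len(losses)


# Position 1 is unmasked, so its (confidently wrong) logits are ignored;
# the two masked positions have uniform logits, so the loss is log(3).
logits = [[0.0, 0.0, 0.0], [5.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
print(masked_cross_entropy(logits, [0, 1, 2], [True, False, True]))  # ≈ 1.0986
```

With real tensors, the same computation is typically written as F.cross_entropy(logits[is_masked], targets[is_masked]). Boolean indexing of a (batch, seq_len, num_tokens) tensor with a (batch, seq_len) mask yields one row of logits per masked position.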

Usage Examples

Basic Bidirectional Encoder for Masked Prediction

from x_transformers import TransformerWrapper, Encoder

NUM_TOKENS = 256
MASK_TOKEN_ID = NUM_TOKENS  # mask token is last

model = TransformerWrapper(
    num_tokens = NUM_TOKENS + 1,  # +1 for mask token
    max_seq_len = 512,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

This creates a bidirectional encoder with:

  • A vocabulary of 257 tokens (256 base tokens + 1 mask token).
  • A fixed generation length of 512 tokens.
  • 512-dimensional hidden states.
  • 6 transformer layers with bidirectional (non-causal) self-attention.
  • 8 attention heads (64 dimensions per head).
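
The sizes in the list above determine the embedding tables directly. The arithmetic below rests on two assumptions: emb_dim (left at None) falls back to dim, and use_abs_pos_emb=True adds a learned table with one dim-sized row per position up to max_seq_len. The per-head width is the 64 quoted in the list.

```python
# Assumptions: emb_dim defaults to dim, and the absolute positional
# embedding (use_abs_pos_emb=True) stores one dim-sized row per position.
num_tokens = 256 + 1   # base vocabulary + mask token
max_seq_len = 512
dim = 512
heads = 8
dim_head = 64          # per-head width quoted in the list above

token_emb_params = num_tokens * dim      # one row per token ID
pos_emb_params = max_seq_len * dim       # one row per position
attn_inner_dim = heads * dim_head        # width of the concatenated heads

print(token_emb_params)  # 131584
print(pos_emb_params)    # 262144
print(attn_inner_dim)    # 512
```

Raising max_seq_len grows the positional table linearly, which is one reason it doubles as the fixed generation length in this setup.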

Encoder with Rotary Embeddings and NonAutoregressiveWrapper

from x_transformers import TransformerWrapper, Encoder
from x_transformers import NonAutoregressiveWrapper

NUM_TOKENS = 1024
MASK_TOKEN_ID = NUM_TOKENS

model = TransformerWrapper(
    num_tokens = NUM_TOKENS + 1,
    max_seq_len = 256,
    attn_layers = Encoder(
        dim = 512,
        depth = 8,
        heads = 8,
        rotary_pos_emb = True
    )
)

# Wrap for non-autoregressive training and generation
nar_wrapper = NonAutoregressiveWrapper(
    model,
    mask_id = MASK_TOKEN_ID,
    steps = 18
)

This demonstrates the full pattern for building a MaskGIT-style non-autoregressive model:

  • The encoder provides bidirectional context for masked token prediction.
  • Rotary position embeddings are used for relative position awareness.
  • The NonAutoregressiveWrapper handles masking, loss computation, and iterative refinement generation over 18 steps.
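
The iterative refinement loop can be sketched independently of the library. The code below illustrates the general MaskGIT decoding idea: start fully masked, predict every masked position, then re-mask the least confident predictions according to a cosine schedule until nothing is masked after the final step. It is not NonAutoregressiveWrapper's implementation. The schedule and the confidence-based re-masking are assumptions taken from the MaskGIT recipe, and predict and toy_predict are hypothetical stand-ins for a forward pass of the encoder.

```python
import math


def num_still_masked(seq_len, step, steps):
    """Cosine schedule: how many positions stay masked after `step` of `steps`."""
    return math.floor(seq_len * math.cos(math.pi / 2 * step / steps))


def iterative_unmask(predict, seq_len, mask_id, steps=18):
    """MaskGIT-style decoding sketch.

    predict(seq) is a hypothetical callable returning one (token, confidence)
    pair per position; it must never predict mask_id itself.
    """
    seq = [mask_id] * seq_len
    for step in range(1, steps + 1):
        preds = predict(seq)
        masked = [i for i, tok in enumerate(seq) if tok == mask_id]
        # Tentatively fill every masked position with its prediction.
        for i in masked:
            seq[i] = preds[i][0]
        # Re-mask the least confident of the newly filled positions.
        keep_masked = num_still_masked(seq_len, step, steps)
        masked.sort(key=lambda i: preds[i][1])
        for i in masked[:keep_masked]:
            seq[i] = mask_id
    return seq


def toy_predict(seq):
    # Hypothetical stand-in for the model: predicts token i % 7 at position i,
    # with confidence growing along the sequence.
    return [(i % 7, i / len(seq)) for i in range(len(seq))]


result = iterative_unmask(toy_predict, seq_len=32, mask_id=1024, steps=18)
print(result[:8])  # [0, 1, 2, 3, 4, 5, 6, 0]
```

Because cos(pi/2 * step / steps) falls from 1 to 0, early steps commit only a few high-confidence tokens and later steps fill in the rest. With steps = 18, every position is unmasked by the last step.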

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
