Implementation: Lucidrains x-transformers TransformerWrapper Encoder Init
Metadata
| Field | Value |
|---|---|
| Source | Repo: x-transformers |
| Domains | NLP, Model_Architecture |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
Concrete tool from the x-transformers library for configuring bidirectional encoder transformer models for masked token prediction.
Description
`TransformerWrapper` with an `Encoder` attention layer produces a bidirectional transformer model suitable for masked token prediction and non-autoregressive generation. This uses the same `TransformerWrapper` class as the decoder configuration but with `Encoder` (which sets `causal=False`) as the attention layers.
The key differences from the causal decoder configuration are:
- No causal mask -- The `Encoder` class forces `causal=False`, so every position attends to every other position in the sequence.
- Vocabulary size must include the mask token -- Since the mask token is fed as input to the model, `num_tokens` must be large enough to include the mask token ID. For example, if the base vocabulary has 256 tokens and the mask token ID is 256, then `num_tokens` must be at least 257.
- `max_seq_len` sets the fixed generation length -- When used with `NonAutoregressiveWrapper`, the `max_seq_len` parameter defines the fixed number of tokens produced during generation.
The model takes token sequences (with mask tokens at positions to be predicted) and outputs logits over the vocabulary for each position. During training, the loss is computed only at masked positions. During inference, the model iteratively predicts and unmasks tokens in a MaskGIT-style refinement loop.
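The refinement loop unmasks a growing number of positions at each step according to a schedule. A minimal pure-Python sketch of the cosine schedule commonly used in MaskGIT-style decoding (illustrative only; `cosine_mask_schedule` is not part of x-transformers):

```python
import math

def cosine_mask_schedule(seq_len, steps):
    """Number of tokens still masked after each refinement step.

    MaskGIT-style cosine schedule: start fully masked, end fully
    unmasked. Illustrative sketch, not the library's implementation.
    """
    remaining = []
    for t in range(1, steps + 1):
        frac = math.cos(0.5 * math.pi * t / steps)  # decays from ~1 to 0
        remaining.append(math.floor(seq_len * frac))
    return remaining

schedule = cosine_mask_schedule(seq_len=256, steps=18)
assert schedule[-1] == 0  # everything unmasked after the final step
```

The schedule unmasks few tokens early (when predictions are least reliable) and many tokens late, which is why a small number of steps (e.g. 18) suffices.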
The wrapper also supports all the same optional features as the decoder configuration, including memory tokens, embedding dropout, tied embeddings, and advanced features such as recycling and mixture of softmax.
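As an illustration of one of those optional features, tied embeddings reuse the token embedding matrix as the output projection, halving the number of vocabulary-sized parameters. A minimal numpy sketch of the idea (independent of the library; all names here are illustrative):

```python
import numpy as np

num_tokens, dim = 257, 512  # 256 base tokens + 1 mask token
embedding = np.random.randn(num_tokens, dim) * 0.02  # shared weight matrix

token_ids = np.array([3, 17, 256])  # 256 could be the mask token id
hidden = embedding[token_ids]       # embedding lookup: (seq, dim)
# ... the transformer layers would transform `hidden` here ...
logits = hidden @ embedding.T       # tied output head: (seq, num_tokens)
assert logits.shape == (3, num_tokens)
```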
Usage
Import TransformerWrapper and Encoder when building a non-autoregressive masked prediction model. Configure the model via:
- `num_tokens` -- Vocabulary size, which must include the mask token (required).
- `max_seq_len` -- Fixed sequence length for generation (required).
- `attn_layers` -- An `Encoder(dim=..., depth=..., heads=...)` instance that defines the bidirectional transformer stack (required).
All other parameters are optional and provide fine-grained control over embeddings, output heads, and advanced features.
Code Reference
Repository
| Field | Value |
|---|---|
| Repository | x-transformers |
| File | x_transformers/x_transformers.py |
| Lines | L3266-3308 (TransformerWrapper.__init__), L3090-3093 (Encoder) |
Import
```python
from x_transformers import TransformerWrapper, Encoder
```
Encoder Class
```python
class Encoder(AttentionLayers):
    def __init__(self, **kwargs):
        assert 'causal' not in kwargs, 'cannot set causality on encoder'
        super().__init__(causal = False, **kwargs)
```
The `Encoder` class is a minimal subclass of `AttentionLayers` that enforces `causal=False`. It raises an assertion error if the user attempts to explicitly pass `causal` as a keyword argument, since the encoder is always non-causal by definition. This is the mirror image of the `Decoder` class, which enforces `causal=True`.
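The guard pattern can be demonstrated with a stand-in base class (`FakeAttentionLayers`, `FakeEncoder`, and `FakeDecoder` below are illustrative placeholders, not the library's classes):

```python
class FakeAttentionLayers:
    """Stand-in for x-transformers' AttentionLayers, for illustration only."""
    def __init__(self, causal=False, **kwargs):
        self.causal = causal
        self.kwargs = kwargs

class FakeEncoder(FakeAttentionLayers):
    def __init__(self, **kwargs):
        # reject an explicit `causal` kwarg, then pin causal=False
        assert 'causal' not in kwargs, 'cannot set causality on encoder'
        super().__init__(causal=False, **kwargs)

class FakeDecoder(FakeAttentionLayers):
    def __init__(self, **kwargs):
        # mirror image: pin causal=True
        assert 'causal' not in kwargs, 'cannot set causality on decoder'
        super().__init__(causal=True, **kwargs)

assert FakeEncoder(dim=512, depth=6).causal is False
assert FakeDecoder(dim=512, depth=6).causal is True
```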
TransformerWrapper.__init__ Signature
The TransformerWrapper signature is identical to the decoder configuration (see L3266-3308). The same class is used for both encoder and decoder models; the difference lies entirely in the attn_layers argument:
```python
class TransformerWrapper(Module):
    def __init__(
        self,
        *,
        num_tokens,
        max_seq_len,
        attn_layers: AttentionLayers,
        embed_num_tokens: dict[str, int] = dict(),
        emb_dim = None,
        max_mem_len = 0,
        shift_mem_down = 0,
        emb_dropout = 0.,
        post_emb_norm = False,
        num_memory_tokens = None,
        memory_tokens_interspersed_every = None,
        tie_embedding = False,
        logits_dim = None,
        return_only_embed = False,
        num_output_heads = 1,
        use_abs_pos_emb = True,
        scaled_sinu_pos_emb = False,
        l2norm_embed = False,
        recycling = False,
        train_max_recycle_steps = 4,
        emb_frac_gradient = 1.,
        attn_z_loss_weight = 1e-4,
        average_pool_embed = False,
        use_cls_token = False,
        num_cls_tokens = 1,
        attn_pool = False,
        num_pooled_tokens = 1,
        attn_pool_depth = 1,
        dim_pooled_tokens = None,
        squeeze_out_last_dim = False,
        token_emb: TokenEmbedding | None = None,
        mixture_of_softmax = False,
        mixture_of_softmax_k = 4,
        sigsoftmax_logits = False,
        ff_deep_embed = False,
        to_logits: Module | None = None,
        add_continuous_pred_head = False,
        input_not_include_cache = False
    ):
```
I/O Contract
Inputs
| Parameter | Type | Required | Description |
|---|---|---|---|
| `num_tokens` | `int` | Yes | Vocabulary size (must include the mask token). Determines the embedding table size and the output logits dimension. |
| `max_seq_len` | `int` | Yes | Fixed sequence length for generation. Determines the size of the absolute positional embedding table (if used). |
| `attn_layers` | `Encoder` | Yes | Encoder attention layers (`causal=False`). Pass an `Encoder(dim=..., depth=..., heads=...)` instance. |
Outputs
The constructed TransformerWrapper instance is a torch.nn.Module. When called in forward mode:
| Input | Type | Description |
|---|---|---|
| `x` | `torch.LongTensor` of shape `(batch, seq_len)` | Integer token IDs, potentially containing mask tokens at positions to be predicted. |

| Output | Type | Description |
|---|---|---|
| `logits` | `torch.FloatTensor` of shape `(batch, seq_len, num_tokens)` | Prediction logits over the vocabulary at each position. For masked token prediction, the logits at masked positions are used to predict the original tokens. |
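The shape contract above can be checked with a stand-in model that mirrors these shapes (numpy only, so no torch or x-transformers needed; `DummyEncoderModel` is purely illustrative):

```python
import numpy as np

BATCH, SEQ_LEN, NUM_TOKENS = 2, 16, 257
MASK_TOKEN_ID = 256

class DummyEncoderModel:
    """Illustrative stand-in matching TransformerWrapper's I/O shapes."""
    def __call__(self, x):
        # (batch, seq_len) integer ids -> (batch, seq_len, num_tokens) logits
        assert x.dtype.kind == 'i' and x.ndim == 2
        return np.random.randn(*x.shape, NUM_TOKENS)

x = np.random.randint(0, NUM_TOKENS, size=(BATCH, SEQ_LEN))
x[:, ::4] = MASK_TOKEN_ID  # place mask tokens at every fourth position
logits = DummyEncoderModel()(x)
assert logits.shape == (BATCH, SEQ_LEN, NUM_TOKENS)
```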
Usage Examples
Basic Bidirectional Encoder for Masked Prediction
```python
from x_transformers import TransformerWrapper, Encoder

NUM_TOKENS = 256
MASK_TOKEN_ID = NUM_TOKENS  # mask token is last

model = TransformerWrapper(
    num_tokens = NUM_TOKENS + 1,  # +1 for mask token
    max_seq_len = 512,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)
```
This creates a bidirectional encoder with:
- 257 tokens (256 base vocabulary + 1 mask token).
- 512 fixed sequence length for generation.
- 512-dimensional hidden states.
- 6 transformer layers with bidirectional (non-causal) self-attention.
- 8 attention heads (64 dimensions per head).
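For training, inputs to such a model carry mask tokens at the positions to be predicted. A hedged pure-Python sketch of preparing one (`mask_tokens` is an illustrative helper, not part of x-transformers; `NonAutoregressiveWrapper` performs this masking internally):

```python
import random

NUM_TOKENS = 256
MASK_TOKEN_ID = NUM_TOKENS  # mask token is last, as in the example above

def mask_tokens(token_ids, mask_prob=0.15, seed=0):
    """Replace a random subset of positions with the mask token.

    Returns (masked_sequence, loss_positions): the model is trained to
    predict the original tokens only at the returned positions.
    """
    rng = random.Random(seed)
    masked, positions = [], []
    for i, tok in enumerate(token_ids):
        if rng.random() < mask_prob:
            masked.append(MASK_TOKEN_ID)
            positions.append(i)
        else:
            masked.append(tok)
    return masked, positions

seq = list(range(32))
masked, positions = mask_tokens(seq)
assert all(masked[i] == MASK_TOKEN_ID for i in positions)
assert all(masked[i] == seq[i] for i in range(32) if i not in positions)
```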
Encoder with Rotary Embeddings and NonAutoregressiveWrapper
```python
from x_transformers import TransformerWrapper, Encoder
from x_transformers import NonAutoregressiveWrapper

NUM_TOKENS = 1024
MASK_TOKEN_ID = NUM_TOKENS

model = TransformerWrapper(
    num_tokens = NUM_TOKENS + 1,
    max_seq_len = 256,
    attn_layers = Encoder(
        dim = 512,
        depth = 8,
        heads = 8,
        rotary_pos_emb = True
    )
)

# Wrap for non-autoregressive training and generation
nar_wrapper = NonAutoregressiveWrapper(
    model,
    mask_id = MASK_TOKEN_ID,
    steps = 18
)
```
This demonstrates the full pattern for building a MaskGIT-style non-autoregressive model:
- The encoder provides bidirectional context for masked token prediction.
- Rotary position embeddings are used for relative position awareness.
- The `NonAutoregressiveWrapper` handles masking, loss computation, and iterative refinement generation over 18 steps.
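The refinement loop itself can be sketched with a stub predictor (pure Python; `stub_predict` stands in for the model's forward pass, and the real logic lives in `NonAutoregressiveWrapper`'s generation method):

```python
import math
import random

SEQ_LEN, STEPS = 16, 6
MASK_ID = 999

def stub_predict(tokens):
    """Stand-in for the encoder: (predicted_token, confidence) per position."""
    rng = random.Random(sum(t for t in tokens if t != MASK_ID))
    return [(rng.randrange(256), rng.random()) for _ in tokens]

def generate(seq_len=SEQ_LEN, steps=STEPS):
    tokens = [MASK_ID] * seq_len  # start fully masked
    for t in range(1, steps + 1):
        preds = stub_predict(tokens)
        # cosine schedule: how many positions stay masked after this step
        keep_masked = math.floor(seq_len * math.cos(0.5 * math.pi * t / steps))
        # unmask the most confident masked positions; the rest stay masked
        masked_idx = [i for i, tok in enumerate(tokens) if tok == MASK_ID]
        masked_idx.sort(key=lambda i: preds[i][1], reverse=True)
        n_unmask = len(masked_idx) - keep_masked
        for i in masked_idx[:n_unmask]:
            tokens[i] = preds[i][0]
    return tokens

out = generate()
assert MASK_ID not in out and len(out) == SEQ_LEN
```

The final step's schedule value is zero, so every remaining masked position is filled and the output contains no mask tokens.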