Principle: Lucidrains x-transformers Bidirectional Encoder Configuration
Metadata
| Field | Value |
|---|---|
| Sources | Paper: BERT: Pre-training of Deep Bidirectional Transformers; Paper: MaskGIT; Repo: x-transformers |
| Domains | Deep_Learning, NLP, Model_Architecture |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
Architecture configuration pattern for defining bidirectional (non-causal) encoder transformer models suitable for masked token prediction and iterative refinement generation.
Description
Unlike causal decoders, bidirectional encoders use non-causal self-attention where every position can attend to every other position. This is essential for masked token prediction tasks where the model predicts masked tokens based on surrounding context from both directions.
In the x-transformers library, the bidirectional encoder is composed from two principal classes:
- `AttentionLayers` -- The core class that constructs the stack of attention and feedforward layers. It accepts `dim`, `depth`, `heads`, normalization options, positional encoding flags, and dozens of other configuration knobs.
- `Encoder` -- A thin convenience subclass of `AttentionLayers` that forces `causal=False`. This ensures that no causal mask is applied during self-attention, allowing every token to attend to all other tokens in the sequence.
When wrapped with `TransformerWrapper`, the `Encoder` produces a model that takes token IDs (potentially including mask tokens) as input and outputs logits over the vocabulary for each position. The vocabulary size must include the mask token ID, since the mask token is treated as a regular token in the embedding table.
This configuration is the foundational first step for building non-autoregressive generation models. The `NonAutoregressiveWrapper` relies on a bidirectional encoder to iteratively predict masked tokens, refining the output over multiple passes in the style of MaskGIT.
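As a concrete sketch of this composition using the public x-transformers API (the hyperparameters below are illustrative, not prescriptive):

```python
import torch
from x_transformers import TransformerWrapper, Encoder

# 256 base tokens + 1 reserved mask token => num_tokens = 257
model = TransformerWrapper(
    num_tokens = 257,
    max_seq_len = 1024,
    attn_layers = Encoder(   # Encoder is AttentionLayers with causal=False
        dim = 512,
        depth = 6,
        heads = 8
    )
)

token_ids = torch.randint(0, 257, (1, 1024))  # may include mask token ID 256
logits = model(token_ids)                     # shape: (1, 1024, 257)
```

Note that the mask token is just another row in the embedding table, which is why `num_tokens` counts it.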
Usage
Use this principle when building models that require bidirectional context for prediction. The configuration determines:
- Model capacity -- Controlled primarily by `dim`, `depth`, and `heads`. Typical configurations mirror those of causal decoders (e.g., `dim=512, depth=6, heads=8` for small models).
- Vocabulary size -- Must be large enough to include the mask token ID. For example, if the base vocabulary has 256 tokens and the mask token ID is 256, then `num_tokens` must be at least 257.
- Sequence length -- `max_seq_len` sets the fixed generation length for `NonAutoregressiveWrapper`, defining how many tokens the model produces per generation step.
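Tying these knobs together, the encoder can be handed to the non-autoregressive wrapper for MaskGIT-style generation. This is a sketch: the argument names `steps` and `schedule` follow the x-transformers README, but should be checked against the installed version:

```python
import torch
from x_transformers import TransformerWrapper, Encoder
from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper

model = TransformerWrapper(
    num_tokens = 256 + 1,    # base vocabulary plus one mask token
    max_seq_len = 1024,      # fixed generation length per step
    attn_layers = Encoder(dim = 512, depth = 6, heads = 8)
)

nar = NonAutoregressiveWrapper(
    model,
    steps = 18,              # number of iterative refinement passes
    schedule = 'cosine'      # masking schedule, as in MaskGIT
)

# training: the wrapper masks a subset of tokens and returns the prediction loss
loss = nar(torch.randint(0, 256, (1, 1024)))
loss.backward()
```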
When to apply
- You are building a non-autoregressive generation model (MaskGIT-style iterative refinement).
- You are building a BERT-like masked language model for pre-training or fine-tuning.
- You need bidirectional context for any token-level prediction task where the model should see both left and right context.
When not to apply
- You are building an autoregressive (left-to-right) language model (use `Decoder` instead of `Encoder`).
- You are building an encoder-decoder model for sequence-to-sequence tasks (use `XTransformer` with both `Encoder` and `Decoder`).
- You need causal masking to prevent information leakage from future tokens.
Theoretical Basis
Bidirectional Self-Attention
The defining property of a bidirectional encoder is the absence of a causal mask. During self-attention, every position i can attend to every position j in the sequence, regardless of their relative order:
Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
Unlike the causal decoder, no mask matrix is added to the attention scores. This means:
Position i attends to all positions j in [0, ..., n-1]
No upper-triangular masking applied.
This allows every token to incorporate information from both left and right context, producing richer contextual representations compared to left-to-right only models.
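The difference between the two attention patterns can be made concrete with a toy plain-Python implementation of the formula above (a sketch for illustration only; real implementations are batched tensor code):

```python
import math

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def attention(Q, K, V, causal=False):
    """Scaled dot-product attention over lists of d_k-dim vectors.
    With causal=False (the encoder case), every query position i
    attends to all key positions j: no masking is applied."""
    d_k = len(K[0])
    out = []
    for i, q in enumerate(Q):
        scores = [sum(qc * kc for qc, kc in zip(q, k)) / math.sqrt(d_k)
                  for k in K]
        if causal:  # decoder case: block attention to future positions j > i
            scores = [s if j <= i else float('-inf')
                      for j, s in enumerate(scores)]
        w = softmax(scores)
        out.append([sum(wj * v[c] for wj, v in zip(w, V))
                    for c in range(len(V[0]))])
    return out
```

With `causal=True`, position 0 can only attend to itself, so its output is exactly `V[0]`; with `causal=False`, its output also mixes in information from later positions.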
Masked Token Prediction
For masked token prediction tasks, a subset of input tokens is replaced with a special [MASK] token. The model is trained to predict the original tokens at masked positions based on the surrounding (unmasked) context:
Input: [the, MASK, sat, on, MASK, mat]
Target: [ _, cat, _, _, the, _]
The bidirectional attention mechanism is critical here because:
- The prediction of MASK at position 1 benefits from seeing both "the" (left context) and "sat on" (right context).
- The prediction of MASK at position 4 benefits from seeing "on" (left context) and "mat" (right context).
A causal decoder would only see left context for each masked position, resulting in substantially weaker predictions.
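The input/target pair above can be constructed with a small helper (a plain-Python sketch; real pipelines operate on integer token IDs and use an ignore index in the loss rather than `None`):

```python
MASK = "[MASK]"

def mask_tokens(tokens, positions):
    """Replace the given positions with [MASK]; return (input, targets).
    Targets are None (ignored by the loss) at unmasked positions."""
    inp = [MASK if i in positions else t for i, t in enumerate(tokens)]
    tgt = [t if i in positions else None for i, t in enumerate(tokens)]
    return inp, tgt

inp, tgt = mask_tokens(["the", "cat", "sat", "on", "the", "mat"], {1, 4})
# inp: ['the', '[MASK]', 'sat', 'on', '[MASK]', 'mat']
# tgt: [None, 'cat', None, None, 'the', None]
```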
Iterative Refinement (MaskGIT)
In the MaskGIT paradigm (Chang et al., 2022), generation proceeds iteratively:
1. Start with a fully masked sequence: all tokens are [MASK].
2. Predict all masked positions simultaneously using the bidirectional encoder.
3. Unmask the most confident predictions (highest probability tokens).
4. Repeat steps 2-3 with the remaining masked positions until all tokens are unmasked.
This approach leverages bidirectional attention at each refinement step, allowing the model to condition on previously unmasked tokens from any position in the sequence, not just preceding positions.
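The loop above can be sketched in plain Python. Here `predict` is a hypothetical stand-in for the bidirectional encoder, and the per-step unmask count is a simple linear schedule rather than MaskGIT's cosine schedule:

```python
def maskgit_decode(predict, seq_len, steps):
    """Iterative unmasking. `predict` maps a partially masked sequence to a
    list of (token, confidence) pairs, one per position; None marks a
    still-masked position."""
    MASK = None
    seq = [MASK] * seq_len                     # step 1: fully masked
    for step in range(steps):
        masked = [i for i, t in enumerate(seq) if t is MASK]
        if not masked:
            break
        preds = predict(seq)                   # step 2: predict all positions at once
        # step 3: commit only the most confident predictions this round
        k = max(1, len(masked) // (steps - step))
        best = sorted(masked, key=lambda i: preds[i][1], reverse=True)[:k]
        for i in best:
            seq[i] = preds[i][0]               # step 4: repeat with the rest masked
    return seq
```

Because each pass re-runs the encoder over the whole sequence, tokens committed anywhere (left or right) condition every remaining prediction.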