Implementation: Lucidrains x-transformers BeliefStateWrapper
| Knowledge Sources | |
|---|---|
| Domains | NLP, Language_Modeling, Generative_Models |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
A concrete tool from the x-transformers library for training bidirectional belief state transformers and performing fill-in-the-middle generation.
Description
The BeliefStateWrapper class implements the Belief State Transformer from Hu et al. (2024). It wraps a forward (causal) decoder and a backward (reverse-causal) decoder to jointly train a model that can predict tokens from both directions. The key idea: for every valid pair of a forward-prefix position and a backward-suffix position, the model predicts the token that follows the prefix and the token that precedes the suffix. This enables fill-in-the-middle generation, where the model generates text conditioned on both a prefix and a suffix goal. The wrapper also supports optional prediction of the distance between the forward and backward positions, as well as distance-conditioned generation.
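As a rough illustration of that pairwise objective, consider the minimal sketch below. This is not the library's implementation (which vectorizes over pairs rather than looping); the tensor layout, the names h_fwd/h_bwd, and the toy text head mapping to 2 * num_tokens logits are all assumptions made for illustration.
import torch
import torch.nn.functional as F
# Sketch only: h_fwd[:, i] encodes tokens 0..i causally,
# h_bwd[:, j] encodes tokens j..n-1 reverse-causally.
def belief_state_loss(h_fwd, h_bwd, seq, text_head):
    b, n = seq.shape
    losses = []
    for i in range(n - 1):
        for j in range(i + 2, n):  # every valid (prefix end, suffix start) pair
            pair = torch.cat((h_fwd[:, i], h_bwd[:, j]), dim=-1)
            logits = text_head(pair)                     # (b, 2 * num_tokens)
            fwd_logits, bwd_logits = logits.chunk(2, dim=-1)
            losses.append(F.cross_entropy(fwd_logits, seq[:, i + 1]))  # token after prefix
            losses.append(F.cross_entropy(bwd_logits, seq[:, j - 1]))  # token before suffix
    return torch.stack(losses).mean()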
Usage
Import this class when you need to train a transformer that can perform fill-in-the-middle generation, or when you want goal-conditioned text generation where the model generates text to bridge a prefix and a suffix. This is particularly useful for code infilling, constrained text generation, and planning tasks.
Code Reference
Source Location
- Repository: Lucidrains_X_transformers
- File: x_transformers/belief_state_wrapper.py
- Lines: 79-432
Signature
class BeliefStateWrapper(Module):
def __init__(
self,
forward_decoder: TransformerWrapper,
backward_decoder: TransformerWrapper | None = None,
train_frac_forward_backward_pairs: float = 1.,
text_head: Module | None = None,
backward_ar_loss_weight: float = 1.,
pred_distance: bool = False,
pred_distance_loss_weight: float = 1.,
cond_on_distance: bool = False,
cond_on_distance_prob: float = 0.5,
max_pred_distance: int | None = None
):
"""
Args:
forward_decoder: TransformerWrapper for causal forward decoding.
backward_decoder: TransformerWrapper for reverse-causal decoding. If None, reuses forward_decoder.
train_frac_forward_backward_pairs: Fraction of valid forward-backward pairs to train on (for memory efficiency).
text_head: Custom prediction head. If None, a default MLP is created.
backward_ar_loss_weight: Weight for backward autoregressive loss relative to forward loss.
pred_distance: Whether to predict the distance between forward and backward positions.
pred_distance_loss_weight: Weight for distance prediction loss.
cond_on_distance: Whether to condition generation on distance.
cond_on_distance_prob: Probability of conditioning on distance during training.
max_pred_distance: Maximum distance to predict (defaults to max_seq_len).
"""
Import
from x_transformers.belief_state_wrapper import BeliefStateWrapper
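Because backward_decoder defaults to the forward decoder (per the constructor docstring above), a single decoder can be shared across both directions. A minimal sketch:
from x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state_wrapper import BeliefStateWrapper
shared_dec = TransformerWrapper(
    num_tokens=256,
    max_seq_len=512,
    attn_layers=Decoder(dim=256, depth=6, heads=8)
)
# Omitting backward_decoder reuses the forward decoder's weights
# for reverse-causal decoding
model = BeliefStateWrapper(forward_decoder=shared_dec)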
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| forward_decoder | TransformerWrapper | Yes | Causal forward decoder model |
| backward_decoder | TransformerWrapper | No | Reverse-causal decoder (defaults to forward_decoder) |
| train_frac_forward_backward_pairs | float | No | Fraction of pairs to train on (default 1.0) |
| backward_ar_loss_weight | float | No | Loss weight for backward direction (default 1.0) |
| pred_distance | bool | No | Enable distance prediction auxiliary task (default False) |
forward() Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| seq | Tensor (b, n) | Yes | Input token sequence |
| lens | Tensor (b,) | No | Actual lengths per batch element for variable-length sequences (see the sketch after this table) |
| loss_weight_by_fb_indices | callable | No | Function to weight loss by forward-backward pair indices |
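The lens argument above supports padded, variable-length batches. A minimal sketch, assuming model is an instantiated BeliefStateWrapper:
import torch
# Pad sequences to a common length, then pass each element's true length
seq = torch.randint(0, 256, (4, 128))       # padded batch
lens = torch.tensor([128, 100, 64, 90])     # true lengths per batch element
loss = model(seq, lens=lens)
loss.backward()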
Outputs
| Name | Type | Description |
|---|---|---|
| forward() returns | Tensor (scalar) | Combined cross-entropy loss for forward and backward predictions (includes the weighted distance loss when pred_distance=True) |
| generate_with_suffix_cond() returns | Tensor (b, seq_len) | Generated token sequence conditioned on optional suffix |
Usage Examples
Basic Training
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state_wrapper import BeliefStateWrapper
# Create forward and backward decoders
forward_dec = TransformerWrapper(
num_tokens=256,
max_seq_len=512,
attn_layers=Decoder(dim=256, depth=6, heads=8)
)
backward_dec = TransformerWrapper(
num_tokens=256,
max_seq_len=512,
attn_layers=Decoder(dim=256, depth=6, heads=8)
)
# Wrap with belief state
model = BeliefStateWrapper(
forward_decoder=forward_dec,
backward_decoder=backward_dec,
train_frac_forward_backward_pairs=0.5,
backward_ar_loss_weight=1.0
)
# Training step
seq = torch.randint(0, 256, (4, 128))
loss = model(seq)
loss.backward()
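To enable the optional distance objectives, set the flags documented in the signature at construction time. A sketch using those documented parameters (the loss weight of 0.5 is an arbitrary example, not a recommendation):
# Enable the auxiliary distance-prediction objective and
# distance-conditioned training
model = BeliefStateWrapper(
    forward_decoder=forward_dec,
    backward_decoder=backward_dec,
    pred_distance=True,
    pred_distance_loss_weight=0.5,   # example value; tune for your task
    cond_on_distance=True,
    cond_on_distance_prob=0.5
)
loss = model(seq)
loss.backward()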
Fill-in-the-Middle Generation
# Generate with suffix conditioning
prompt = torch.randint(0, 256, (1, 10))
suffix = torch.randint(0, 256, (1, 5))
generated = model.generate_with_suffix_cond(
prompts=prompt,
seq_len=50,
suffix=suffix,
temperature=1.25,
cache_kv=True
)
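Since suffix is optional per the I/O contract, omitting it falls back to ordinary prompt-only generation. A minimal sketch:
# Without a suffix, generation is conditioned on the prompt alone
unconditional = model.generate_with_suffix_cond(
    prompts=prompt,
    seq_len=50,
    temperature=1.0,
    cache_kv=True
)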