
Implementation:Lucidrains X transformers BeliefStateWrapper



Knowledge Sources
Domains NLP, Language_Modeling, Generative_Models
Last Updated 2026-02-08 18:00 GMT

Overview

A concrete tool, provided by the x-transformers library, for bidirectional belief state transformer training and fill-in-the-middle generation.

Description

The BeliefStateWrapper class implements the Belief State Transformer from Hu et al. (2024). It wraps a forward (causal) decoder and a backward (reverse-causal) decoder and trains them jointly so that the model can predict tokens from both directions. The key idea is that for any pair of a forward prefix position and a backward suffix position, the model predicts both the token that follows the prefix and the token that precedes the suffix. This enables fill-in-the-middle generation, where the model generates text conditioned on both a prefix and a suffix goal. The wrapper also supports optional prediction of the distance between forward and backward positions, as well as distance-conditioned generation.
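
The joint objective can be illustrated with a minimal conceptual sketch (not the library's implementation). It assumes the forward and backward hidden states have already been computed, and that a hypothetical text_head maps a concatenated (prefix, suffix) state pair to logits for both predictions:

import torch
import torch.nn.functional as F

def belief_state_loss(f_hidden, b_hidden, seq, text_head):
    # f_hidden[:, i] : forward decoder state after reading tokens 0..i
    # b_hidden[:, j] : backward decoder state after reading tokens j..n-1 in reverse
    b, n, d = f_hidden.shape
    losses = []
    for i in range(n - 1):            # end of the forward prefix
        for j in range(i + 2, n):     # start of the backward suffix, leaving a gap
            pair = torch.cat((f_hidden[:, i], b_hidden[:, j]), dim=-1)
            logits = text_head(pair)                    # (b, 2 * num_tokens), assumed head
            fwd_logits, bwd_logits = logits.chunk(2, dim=-1)
            losses.append(
                F.cross_entropy(fwd_logits, seq[:, i + 1])    # token following the prefix
                + F.cross_entropy(bwd_logits, seq[:, j - 1])  # token preceding the suffix
            )
    return torch.stack(losses).mean()

Because the number of (prefix, suffix) pairs grows quadratically with sequence length, train_frac_forward_backward_pairs lets training use only a fraction of the valid pairs to bound memory.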

Usage

Import this class when you need to train a transformer that can perform fill-in-the-middle generation, or when you want goal-conditioned text generation where the model generates text to bridge a prefix and a suffix. This is particularly useful for code infilling, constrained text generation, and planning tasks.

Code Reference

Source Location

x_transformers/belief_state_wrapper.py in the lucidrains/x-transformers repository (as indicated by the import path below).

Signature

class BeliefStateWrapper(Module):
    def __init__(
        self,
        forward_decoder: TransformerWrapper,
        backward_decoder: TransformerWrapper | None = None,
        train_frac_forward_backward_pairs: float = 1.,
        text_head: Module | None = None,
        backward_ar_loss_weight: float = 1.,
        pred_distance: bool = False,
        pred_distance_loss_weight: float = 1.,
        cond_on_distance: bool = False,
        cond_on_distance_prob: float = 0.5,
        max_pred_distance: int | None = None
    ):
        """
        Args:
            forward_decoder: TransformerWrapper for causal forward decoding.
            backward_decoder: TransformerWrapper for reverse-causal decoding. If None, reuses forward_decoder.
            train_frac_forward_backward_pairs: Fraction of valid forward-backward pairs to train on (for memory efficiency).
            text_head: Custom prediction head. If None, a default MLP is created.
            backward_ar_loss_weight: Weight for backward autoregressive loss relative to forward loss.
            pred_distance: Whether to predict the distance between forward and backward positions.
            pred_distance_loss_weight: Weight for distance prediction loss.
            cond_on_distance: Whether to condition generation on distance.
            cond_on_distance_prob: Probability of conditioning on distance during training.
            max_pred_distance: Maximum distance to predict (defaults to max_seq_len).
        """

Import

from x_transformers.belief_state_wrapper import BeliefStateWrapper

I/O Contract

Inputs

Name Type Required Description
forward_decoder TransformerWrapper Yes Causal forward decoder model
backward_decoder TransformerWrapper No Reverse-causal decoder (defaults to forward_decoder)
train_frac_forward_backward_pairs float No Fraction of pairs to train on (default 1.0)
backward_ar_loss_weight float No Loss weight for backward direction (default 1.0)
pred_distance bool No Enable distance prediction auxiliary task

forward() Inputs

Name Type Required Description
seq Tensor (b, n) Yes Input token sequence
lens Tensor (b,) No Actual lengths per batch element for variable-length sequences
loss_weight_by_fb_indices callable No Function to weight loss by forward-backward pair indices
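
For variable-length batches, the per-sample lengths can be supplied through lens. A minimal sketch, assuming lens is accepted as a keyword argument so that positions beyond each length do not contribute to the loss, and that model is a BeliefStateWrapper instance as constructed in the Usage Examples below:

import torch

# Two sequences with true lengths 100 and 70, right-padded to 128
# (the padded contents are random here, for illustration only)
seq = torch.randint(0, 256, (2, 128))
lens = torch.tensor([100, 70])

loss = model(seq, lens=lens)  # loss computed only over the valid positions
loss.backward()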

Outputs

Name Type Description
forward() returns Tensor (scalar) Combined cross-entropy loss for the forward and backward predictions (plus the distance prediction loss when pred_distance is enabled)
generate_with_suffix_cond() returns Tensor (b, seq_len) Generated token sequence conditioned on optional suffix

Usage Examples

Basic Training

import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.belief_state_wrapper import BeliefStateWrapper

# Create forward and backward decoders
forward_dec = TransformerWrapper(
    num_tokens=256,
    max_seq_len=512,
    attn_layers=Decoder(dim=256, depth=6, heads=8)
)

backward_dec = TransformerWrapper(
    num_tokens=256,
    max_seq_len=512,
    attn_layers=Decoder(dim=256, depth=6, heads=8)
)

# Wrap with belief state
model = BeliefStateWrapper(
    forward_decoder=forward_dec,
    backward_decoder=backward_dec,
    train_frac_forward_backward_pairs=0.5,
    backward_ar_loss_weight=1.0
)

# Training step
seq = torch.randint(0, 256, (4, 128))
loss = model(seq)
loss.backward()

Fill-in-the-Middle Generation

# Generate with suffix conditioning
prompt = torch.randint(0, 256, (1, 10))
suffix = torch.randint(0, 256, (1, 5))

generated = model.generate_with_suffix_cond(
    prompts=prompt,
    seq_len=50,
    suffix=suffix,
    temperature=1.25,
    cache_kv=True
)
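
Distance Prediction and Conditioning

The distance-related options are set at construction time. A sketch using the constructor flags documented in the signature above; the specific weights and cap are illustrative, and forward_dec / backward_dec are the decoders from the Basic Training example:

# Enable the auxiliary distance objective and distance-conditioned training
model = BeliefStateWrapper(
    forward_decoder=forward_dec,
    backward_decoder=backward_dec,
    pred_distance=True,              # predict the gap between forward and backward positions
    pred_distance_loss_weight=0.5,   # illustrative weight for the auxiliary loss
    cond_on_distance=True,           # condition on the distance between positions
    cond_on_distance_prob=0.5,       # condition on distance for roughly half of training steps
    max_pred_distance=512            # cap on predicted distance, here matching max_seq_len
)

seq = torch.randint(0, 256, (4, 128))
loss = model(seq)   # combined loss now includes the distance prediction term
loss.backward()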
