Principle: Lucidrains x-transformers Belief State Training
| Knowledge Sources | |
|---|---|
| Domains | NLP, Language_Modeling, Generative_Models |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
Mechanism that trains a language model to jointly predict forward and backward tokens from all valid prefix-suffix pairs, enabling fill-in-the-middle generation.
Description
Belief State Training is a bidirectional training objective introduced by Hu et al. (2024). Instead of training only a forward autoregressive model, it simultaneously trains a forward (causal) and backward (reverse-causal) decoder. For every valid pair of forward position i and backward position j (where j - i >= 2), the model combines the forward embedding at position i with the backward embedding at position j to predict both the next forward token (at i+1) and the previous backward token (at j-1). This produces a "belief state" that captures both left and right context, enabling the model to generate text that must bridge between a given prefix and suffix.
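The pair enumeration described above can be sketched concretely. The helper below is purely illustrative (the name `belief_state_pairs` is not from the x-transformers codebase): for a sequence of `n` positions it lists every valid `(i, j)` pair with `j - i >= 2`, together with the forward target `i+1` and backward target `j-1` the model is trained to predict.

```python
def belief_state_pairs(n: int):
    """List (i, j, fwd_target, bwd_target) index tuples for a
    sequence of n token positions (0-indexed)."""
    pairs = []
    for i in range(n - 2):           # forward position
        for j in range(i + 2, n):    # backward position, at least 2 apart
            # targets: next token after i, previous token before j
            pairs.append((i, j, i + 1, j - 1))
    return pairs

pairs = belief_state_pairs(5)
# the forward target never passes the backward target
assert all(ft <= bt for _, _, ft, bt in pairs)
```

Note that the number of training pairs grows quadratically with sequence length, which is why implementations typically subsample pairs rather than enumerate all of them.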
Usage
Use this principle when designing models that need fill-in-the-middle capability, goal-conditioned generation, or infilling tasks. It is appropriate when the model must generate text that is constrained by both preceding and following context, such as code completion between existing blocks or constrained narrative generation.
Theoretical Basis
The belief state objective trains on all valid forward-backward index pairs:
Pseudo-code Logic:
```
# Abstract algorithm (NOT a real implementation)
# For a sequence `seq` of length N (positions 0..N-1):
loss = 0
for i in range(N - 2):              # forward position
    for j in range(i + 2, N):       # backward position (at least 2 apart)
        fwd_embed = forward_decoder_embed[i]
        bwd_embed = backward_decoder_embed[j]
        combined = concat(fwd_embed, bwd_embed)
        pred_fwd_token, pred_bwd_token = text_head(combined)
        loss += cross_entropy(pred_fwd_token, seq[i + 1])  # next forward token
        loss += cross_entropy(pred_bwd_token, seq[j - 1])  # previous backward token
```
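A minimal numeric sketch of the summed loss in the pseudocode above, assuming a tiny vocabulary and hand-written logits. Here `cross_entropy` is the standard negative log-softmax written out by hand, not a library call, and the logits are arbitrary illustrative values.

```python
import math

def cross_entropy(logits, target):
    """Negative log-softmax probability of the target class."""
    log_z = math.log(sum(math.exp(l) for l in logits))
    return log_z - logits[target]

# one (i, j) pair: the shared head emits logits for the next forward
# token and the previous backward token from the combined belief state
fwd_logits = [2.0, 0.0, 0.0]   # predicts token 0 as the next forward token
bwd_logits = [0.0, 1.0, 0.0]   # predicts token 1 as the previous backward token
loss = cross_entropy(fwd_logits, 0) + cross_entropy(bwd_logits, 1)
assert loss > 0.0
```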
The key insight is that the combination of forward prefix state and backward suffix state forms a "belief state" sufficient for predicting the next tokens in both directions. At inference time, the suffix embedding provides a goal signal for fill-in-the-middle generation.
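The inference-time use of the suffix as a goal signal can be sketched as follows. This is a hedged sketch, not the actual x-transformers API: `encode_suffix` and `forward_step` are hypothetical stand-ins for the trained backward and forward decoders, and the stopping condition (an end token signaling the generation has reached the suffix) is one plausible design among several.

```python
def fill_in_the_middle(prefix, suffix, forward_step, encode_suffix,
                       max_len, eos):
    """Generate tokens bridging prefix -> suffix using a goal signal
    derived from the suffix (hypothetical component interfaces)."""
    goal = encode_suffix(suffix)               # backward embedding of the suffix
    generated = list(prefix)
    for _ in range(max_len):
        token = forward_step(generated, goal)  # next token from the belief state
        if token == eos:                       # model signals the suffix is reached
            break
        generated.append(token)
    return generated[len(prefix):]             # the infilled middle only

# toy stand-ins: deterministically emit "b", then "c", then stop
middle = fill_in_the_middle(
    prefix=["a"], suffix=["d"],
    forward_step=lambda gen, goal: {1: "b", 2: "c"}.get(len(gen), "<eos>"),
    encode_suffix=lambda s: s,
    max_len=8, eos="<eos>",
)
assert middle == ["b", "c"]
```

The key design point is that the suffix is encoded once and reused at every forward step, so infilling costs the same per token as ordinary autoregressive decoding.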