Principle: lucidrains x-transformers Segment-Level Recurrence
| Knowledge Sources | |
|---|---|
| Domains | NLP, Language_Modeling, Long_Context |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
Technique that enables transformers to process sequences longer than their maximum context window by segmenting the input and propagating hidden state memories across segment boundaries.
Description
Segment-Level Recurrence, introduced in Transformer-XL, solves the fixed-length context limitation of standard transformers. The input sequence is divided into non-overlapping segments of max_seq_len tokens. Each segment is processed by the transformer, and the hidden states from each layer are cached as "memories" (mems). When processing the next segment, these cached hidden states are prepended to the current segment's hidden states in the attention computation, allowing the model to attend to tokens from previous segments without recomputation. This creates an effective context window much larger than max_seq_len while maintaining constant memory and computation per segment. During training, the loss from each segment is weighted proportionally to its length.
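The core mechanism described above can be sketched in plain NumPy: cached hidden states are concatenated in front of the current segment when forming attention keys and values, while queries come only from the current segment. This is an illustrative single-head sketch, not the x-transformers implementation; `attend_with_mems` and all weights are hypothetical names.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attend_with_mems(h, mems, Wq, Wk, Wv):
    """Single-head attention where the current segment's tokens
    attend over [memories; current segment] keys/values."""
    kv_input = np.concatenate([mems, h], axis=0) if mems is not None else h
    q = h @ Wq           # queries: current segment only
    k = kv_input @ Wk    # keys/values: cached memories + current segment
    v = kv_input @ Wv
    scores = q @ k.T / np.sqrt(q.shape[-1])
    return softmax(scores) @ v

rng = np.random.default_rng(0)
d = 8
Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))
mems = rng.normal(size=(4, d))  # 4 cached token states from the previous segment
h = rng.normal(size=(3, d))     # 3 tokens in the current segment
out = attend_with_mems(h, mems, Wq, Wk, Wv)
print(out.shape)  # (3, 8): one output per current token, informed by 7 positions
```

Note that no positions from the previous segment are recomputed: their hidden states are reused as-is, which is what keeps per-segment cost constant.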
Usage
Use this principle when working with sequences that exceed the transformer's maximum sequence length, such as long documents, books, code repositories, or streaming applications. The approach is complementary to relative positional encodings (such as rotary embeddings), which generalize across segment boundaries.
Theoretical Basis
The recurrence mechanism across segments:
Pseudo-code Logic:
# Abstract algorithm (NOT real implementation)
mems = None
# Split sequence into (segment, labels) pairs of max_seq_len tokens
segments = split(sequence, max_seq_len)
total_len = len(sequence)
total_loss = 0
for segment, labels in segments:
    # Process segment with memories from the previous segment
    logits, new_mems = transformer(segment, mems=mems, return_mems=True)
    # Weight each segment's loss by its share of the full sequence
    loss = cross_entropy(logits, labels) * (len(segment) / total_len)
    total_loss += loss
    # Cache hidden states for the next segment (detached from the graph)
    mems = [h[-max_mem_len:].detach() for h in new_mems]
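The loop above can be exercised end-to-end with a toy stand-in for the transformer. This is a sketch only: `toy_transformer` and its mixing rule are invented for illustration, chosen so that cached memories visibly influence each segment's output.

```python
import numpy as np

max_seq_len, max_mem_len, d = 4, 4, 8
rng = np.random.default_rng(1)
W = rng.normal(size=(d, d)) / np.sqrt(d)

def toy_transformer(segment, mems):
    """Stand-in for a one-layer transformer: mixes each token with the
    mean of [mems; segment] so memories affect the output."""
    context = np.concatenate([mems, segment], axis=0) if mems is not None else segment
    h = segment @ W + context.mean(axis=0)
    return h, [h]  # (hidden states, per-layer memories)

tokens = rng.normal(size=(12, d))  # a "sequence" of 12 token vectors
mems, outputs = None, []
for start in range(0, len(tokens), max_seq_len):
    segment = tokens[start:start + max_seq_len]
    h, new_mems = toy_transformer(segment, mems)
    outputs.append(h)
    # Cache truncated hidden states for the next segment; with autograd
    # you would also .detach() here to stop gradients across segments
    mems = new_mems[0][-max_mem_len:]
print(len(outputs), outputs[0].shape)  # 3 (4, 8)
```

Each of the three segments is processed at constant cost, yet segments two and three see information from earlier tokens through `mems`.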
The effective context length is approximately:

effective_context ≈ N × M + L

where L is the segment length, M is the memory length, and N is the number of layers. Because each layer's memories were themselves computed with access to the previous segment's memories, the receptive field compounds layer by layer, so memories accumulate context across segments.
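As a concrete illustration of this scaling (the numbers are illustrative, not defaults from any library): with N layers, memory length M, and segment length L, the reachable context is

```python
# Illustrative configuration: a 12-layer model with 512-token
# segments and 512-token memories per layer
n_layers, seg_len, mem_len = 12, 512, 512

# Each layer reaches mem_len further back, so the receptive
# field compounds across layers
effective_context = n_layers * mem_len + seg_len
print(effective_context)  # 6656
```

far beyond the 512 tokens a single forward pass sees directly.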