
Principle:Lucidrains X transformers Segment Level Recurrence

From Leeroopedia


Knowledge Sources
Domains NLP, Language_Modeling, Long_Context
Last Updated 2026-02-08 18:00 GMT

Overview

A technique that enables a transformer to process sequences longer than its maximum context window by segmenting the input and propagating cached hidden-state memories across segment boundaries.

Description

Segment-Level Recurrence, introduced in Transformer-XL, solves the fixed-length context limitation of standard transformers. The input sequence is divided into non-overlapping segments of max_seq_len tokens. Each segment is processed by the transformer, and the hidden states from each layer are cached as "memories" (mems). When processing the next segment, these cached hidden states are prepended to the current segment's hidden states in the attention computation, allowing the model to attend to tokens from previous segments without recomputation. This creates an effective context window much larger than max_seq_len while maintaining constant memory and computation per segment. During training, the loss from each segment is weighted proportionally to its length.
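The core attention mechanism, in which cached hidden states are prepended to the current segment's keys and values, can be sketched as follows. This is a minimal single-head, unbatched NumPy sketch; the function name, weight layout, and shapes are illustrative assumptions, not the x-transformers API:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attend_with_memory(x, mem, w_q, w_k, w_v):
    """Single-head causal attention where cached memories extend keys/values.

    x:   (seq_len, dim)  hidden states of the current segment
    mem: (mem_len, dim)  cached hidden states from the previous segment,
                         or None for the first segment
    """
    kv_in = x if mem is None else np.concatenate([mem, x], axis=0)
    q, k, v = x @ w_q, kv_in @ w_k, kv_in @ w_v
    scores = q @ k.T / np.sqrt(k.shape[-1])
    # Causal mask: segment position i attends to all mem_len memory tokens
    # plus segment positions <= i. Memory contributes no queries of its own.
    mem_len = 0 if mem is None else mem.shape[0]
    rows = np.arange(x.shape[0])[:, None]
    cols = np.arange(x.shape[0] + mem_len)[None, :]
    scores = np.where(cols <= rows + mem_len, scores, -np.inf)
    return softmax(scores) @ v

rng = np.random.default_rng(0)
dim = 8
w_q, w_k, w_v = (rng.normal(size=(dim, dim)) for _ in range(3))
seg1 = rng.normal(size=(4, dim))
seg2 = rng.normal(size=(4, dim))

out1 = attend_with_memory(seg1, None, w_q, w_k, w_v)  # first segment: no memory
out2 = attend_with_memory(seg2, seg1, w_q, w_k, w_v)  # second: attends into seg1
```

Note that the memory tokens only enlarge the key/value set; the output still has one row per current-segment position, which is why per-segment compute stays constant.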

Usage

Use this principle when working with sequences that exceed the transformer's maximum sequence length, such as long documents, books, code repositories, or streaming applications. The approach is complementary to relative positional encodings (like rotary embeddings) that generalize across segment boundaries.

Theoretical Basis

The recurrence mechanism across segments:

Pseudo-code Logic:

# Abstract algorithm (NOT a real implementation)
mems = None

# Next-token prediction: inputs and shifted-by-one targets
inputs, targets = sequence[:-1], sequence[1:]
total_len = len(inputs)

# Split both into aligned, non-overlapping segments of max_seq_len
segments = zip(split(inputs, max_seq_len), split(targets, max_seq_len))

total_loss = 0
for segment, labels in segments:
    # Process segment with memories cached from the previous segment
    logits, new_mems = transformer(segment, mems=mems, return_mems=True)

    # Weight each segment's loss by its share of the full sequence
    loss = cross_entropy(logits, labels) * (len(segment) / total_len)
    total_loss += loss

    # Keep the last max_mem_len hidden states per layer, detached so
    # gradients do not flow across segment boundaries
    mems = [h[-max_mem_len:].detach() for h in new_mems]
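The loop above can be made concrete with a toy stand-in model in pure NumPy. Here `toy_transformer`, its weights, and all sizes are invented for illustration; they are not the library's API, and a real model would return one memory tensor per layer:

```python
import numpy as np

rng = np.random.default_rng(0)
max_seq_len, max_mem_len, dim, vocab = 4, 4, 8, 16

emb = rng.normal(size=(vocab, dim)) * 0.1    # token embeddings
w_out = rng.normal(size=(dim, vocab)) * 0.1  # output projection

def toy_transformer(tokens, mems):
    """Stand-in 'transformer': hidden states are embeddings shifted by the
    mean of the cached memories, so the output visibly depends on mems."""
    hidden = emb[tokens]
    if mems is not None:
        hidden = hidden + mems.mean(axis=0)
    return hidden @ w_out, hidden  # (logits, new_mems)

def cross_entropy(logits, labels):
    logp = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    return -logp[np.arange(len(labels)), labels].mean()

sequence = rng.integers(0, vocab, size=17)  # 16 input tokens + final target
inputs, targets = sequence[:-1], sequence[1:]
total_len = len(inputs)

mems, total_loss = None, 0.0
for start in range(0, total_len, max_seq_len):
    seg_in = inputs[start:start + max_seq_len]
    seg_tg = targets[start:start + max_seq_len]

    logits, new_mems = toy_transformer(seg_in, mems)

    # Weight each segment's loss by its share of the full sequence
    total_loss += cross_entropy(logits, seg_tg) * (len(seg_in) / total_len)

    # Keep only the most recent max_mem_len states (in a framework with
    # autograd, new_mems would also be detached here)
    mems = new_mems[-max_mem_len:]
```

After the loop, `mems` holds the last `max_mem_len` hidden states and would seed the next chunk of a longer stream.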

The effective context length is:

L_eff = L_seg + L_mem × N_segments

where L_seg is the segment length, L_mem is the memory length, and N_segments counts the segments whose memories have accumulated.
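With illustrative numbers (segment and memory length of 512 tokens, four segments of accumulated memory), the formula gives:

```python
L_seg, L_mem, n_segments = 512, 512, 4  # illustrative values, not defaults

L_eff = L_seg + L_mem * n_segments
print(L_eff)  # → 2560
```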
