
Implementation:FlagOpen FlagEmbedding LLM Embedder SRLM

From Leeroopedia


Knowledge Sources
Domains Natural Language Processing, Large Language Models, Information Retrieval
Last Updated 2026-02-09 00:00 GMT

Overview

A Self-Retrieval Language Model that lets an LLM retrieve relevant chunks from earlier in its own context during both training and inference.

Description

SelfRetrievalLM (SRLM) extends standard language models with the ability to retrieve and attend to relevant chunks from earlier in the context. Instead of processing a long context sequentially, SRLM:

1. Chunks the context into fixed-size segments (default 64 tokens)
2. Retrieves relevant chunks using dense retrieval, BM25, or oracle selection
3. Integrates retrieved context by concatenation or replacement
4. Generates next tokens conditioned on retrieved context

This enables processing contexts far longer than the model's training window (e.g., 32K tokens) while maintaining bounded memory and compute. The model supports multiple retrieval methods (dense embedder, BM25, random, oracle/optimal), flexible integration strategies (concat retrieved chunks before window, or replace least relevant chunks), and two use cases: chunk-based retrieval for long documents, and history retrieval for multi-turn dialogue.

The architecture is particularly useful for long-context language modeling tasks where full attention over all tokens is computationally prohibitive.
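The chunk/query/key split described above can be illustrated with a minimal sketch (plain Python; `split_into_chunks` and the variable names are illustrative, not part of the SRLM API):

```python
def split_into_chunks(token_ids, chunk_size=64):
    """Split a token sequence into fixed-size chunks (last chunk may be short)."""
    return [token_ids[i:i + chunk_size] for i in range(0, len(token_ids), chunk_size)]

# Toy sequence of 300 "token ids" -> 5 chunks of sizes 64, 64, 64, 64, 44
tokens = list(range(300))
chunks = split_into_chunks(tokens)

# Following the scheme above: the third-from-last chunk acts as the query,
# the next chunk is the prediction target, and all earlier chunks are keys.
query = chunks[-3]
target = chunks[-2]
keys = chunks[:-3]

print(len(chunks), len(keys))  # 5 2
```

Only the target chunk contributes to the loss; the retrieved keys merely extend the context the model conditions on.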

Usage

Use SRLM when working with contexts exceeding the model's position embedding length, when you need efficient processing of very long documents (100K+ tokens), or when building dialogue systems that selectively attend to relevant conversation history rather than all previous turns.

Code Reference

Source Location

Signature

class SelfRetrievalLM(LM):
    def __init__(
        self,
        retriever=None,
        context_window_size: int = 2048,
        chunk_size: int = 64,
        key_num: int = 1,
        chunk_batch_size: int = 2,
        add_key_continuation: bool = False,
        retrieval_method: str = "dense",
        order_method: str = "sequential",
        integrate_method: str = "concat",
        instruction: Dict = None,
        add_sep: Optional[List[int]] = None,
        debug_retrieval: bool = False,
        **kwds
    ):
        """Initialize self-retrieval language model"""

    def forward(self, input_ids, attention_mask, labels) -> SRLMOutput:
        """Forward pass with chunk-based retrieval"""

    def forward_with_history_retrieval(
        self, query: np.ndarray, history: np.ndarray,
        answer: np.ndarray, history_mask: torch.Tensor
    ) -> SRLMOutput:
        """Forward pass with dialogue history retrieval"""

    def compute_perplexity(self, dataloader) -> float:
        """Compute perplexity over long inputs"""

Import

from research.llm_embedder.src.lm.modeling_srlm import SelfRetrievalLM, SRLMOutput

I/O Contract

Inputs

Name                 Type    Required  Description
input_ids            Tensor  Yes       Token IDs (batch_size, seq_len)
attention_mask       Tensor  Yes       Attention mask
labels               Tensor  Yes       Labels for LM loss (-100 for ignored tokens)
retriever            object  No        Dense retriever with encode() method
context_window_size  int     No        Max context length (default: 2048)
chunk_size           int     No        Chunk size in tokens (default: 64)
key_num              int     No        Number of chunks to retrieve (default: 1)

Outputs

Name             Type    Description
loss             Tensor  Language modeling loss
logits           Tensor  Model logits (optional)
past_key_values  Tuple   Cached key-value states (optional)

Architecture

Chunk-Based Retrieval

# Input: [pad] [pad] [chunk0] [chunk1] ... [chunkN-3] [chunkN-2] [chunkN-1] [chunkN]
#                    |<---------- corpus ---------->| |<--------- window --------->|

# Process:
# 1. Split into chunks (64 tokens each)
# 2. Query: chunkN-2 (target: chunkN-1)
# 3. Keys: chunk0...chunkN-3 (earlier chunks)
# 4. Retrieve top-k chunks + continuations (k*2*64 tokens)
# 5. Integrate: [retrieved] [fixed_window]
#              [k*128 tokens] [window_size tokens]
# 6. Compute loss only on chunkN-1

# Retrieval Methods:
# - dense: Encode chunks with retriever, compute similarity
# - bm25: Use BM25 scoring on chunk text
# - random: Random chunk selection
# - oracle: Exhaustive search for chunk minimizing loss (training only)
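At its core, the dense method is a similarity search over chunk embeddings; a minimal sketch (the orthogonal toy embeddings and `dense_retrieve` are stand-ins for the real retriever's encodings):

```python
import numpy as np

def dense_retrieve(query_emb, key_embs, k=1):
    """Return indices of the top-k keys by inner-product similarity (descending)."""
    scores = key_embs @ query_emb        # one score per key chunk
    return np.argsort(-scores)[:k].tolist()

# 10 orthogonal "chunk embeddings"; the query matches key 7 exactly
key_embs = np.eye(10)
query = key_embs[7]

print(dense_retrieve(query, key_embs, k=3)[0])  # 7
```

In SRLM the embeddings come from the retriever's encode() method applied to each chunk's text, with the instruction prefixes prepended.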

Integration Strategies

if integrate_method == "concat":
    # Prepend retrieved chunks before window
    input_ids = [retrieved_chunks] + [fixed_context]
    # Window: [k*128] [2048] = 2048 + k*128 tokens total

elif integrate_method == "replace":
    # Replace earliest chunks in window with retrieved chunks
    fixed_context = window[k*128:]  # Keep last (window - k*128) tokens
    input_ids = [retrieved_chunks] + [fixed_context]
    # Window: [k*128] [(2048 - k*128)] = 2048 tokens total (fits in memory)

Ordering Strategies

if order_method == "sequential":
    # Order by position in document (early to late)
    retrieved_indices = retrieved_indices.sort(-1)[0]

elif order_method == "relevance":
    # Order by ascending relevance, so the most relevant chunk
    # sits closest to the generation window
    retrieved_indices = retrieved_indices.flip(dims=(-1,))
    # FAISS returns indices in descending relevance; flip makes them ascending

Usage Examples

Initialize SRLM

from transformers import AutoTokenizer, AutoModelForCausalLM
from FlagEmbedding import FlagModel
from research.llm_embedder.src.lm.modeling_srlm import SelfRetrievalLM

# Load base LM and retriever
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
base_model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf')
retriever = FlagModel('BAAI/llm-embedder', use_fp16=True)

# Wrap with SRLM
model = SelfRetrievalLM(
    model=base_model,
    tokenizer=tokenizer,
    retriever=retriever,
    context_window_size=2048,  # 2K context window
    chunk_size=64,             # 64-token chunks
    key_num=3,                 # Retrieve 3 chunks (3*2*64 = 384 tokens)
    chunk_batch_size=8,        # Process 8 target chunks per batch
    retrieval_method='dense',
    order_method='sequential',
    integrate_method='concat',
    instruction={
        'query': 'Represent this text chunk for retrieval: ',
        'key': 'Represent this text chunk for retrieval: '
    }
)

Long Document Modeling

# Prepare long document (e.g., 32K tokens)
document = "The quick brown fox..." * 5000  # Very long text
inputs = tokenizer(document, return_tensors='pt', truncation=False)

# SRLM automatically chunks and retrieves
# Context window: 2048, but can process 32K+ tokens
outputs = model(
    input_ids=inputs['input_ids'],
    attention_mask=inputs['attention_mask'],
    labels=inputs['input_ids']  # Labels for language modeling
)

print(f"Loss: {outputs.loss.item():.4f}")
# Only computes loss on target chunks, not retrieved context

# Compute perplexity
from torch.utils.data import DataLoader
dataloader = DataLoader(dataset, batch_size=1, collate_fn=collator)
perplexity = model.compute_perplexity(dataloader)
print(f"Perplexity: {perplexity:.2f}")
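The perplexity reported above is the exponential of the token-weighted mean NLL over the loader; a minimal sketch of that reduction (losses and token counts here are made up, and `perplexity_from_losses` is illustrative, not the SRLM method):

```python
import math

def perplexity_from_losses(losses, token_counts):
    """Token-weighted perplexity: exp of total NLL divided by total tokens."""
    total_nll = sum(loss * n for loss, n in zip(losses, token_counts))
    total_tokens = sum(token_counts)
    return math.exp(total_nll / total_tokens)

# Two batches with per-token mean losses 2.0 and 3.0, 100 tokens each
print(round(perplexity_from_losses([2.0, 3.0], [100, 100]), 2))  # exp(2.5) ~ 12.18
```

Weighting by token count matters because SRLM batches cover different numbers of target tokens; averaging the batch losses directly would skew the result.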

Multi-Turn Dialogue with History Retrieval

import numpy as np
import torch

# Dialogue history (list of turns)
# Dialogue history: shape (batch, num_turns), one string per turn
history = np.array([[
    "Speaker 1: Hello, how are you?\nSpeaker 2: I'm good, thanks!",
    "Speaker 1: What's the weather?\nSpeaker 2: It's sunny today.",
    "Speaker 1: Do you like sports?\nSpeaker 2: Yes, I love basketball.",
    # ... many more turns
]])

query = np.array(["What did we talk about earlier?"])
answer = np.array(["We discussed the weather and sports."])

# Create history mask (1 = valid turn, 0 = padding)
history_mask = torch.ones(1, history.shape[1])

# SRLM retrieves relevant history turns
outputs = model.forward_with_history_retrieval(
    query=query,
    history=history,
    answer=answer,
    history_mask=history_mask
)

# Model retrieves turns about "weather" and "sports"
# Generates answer conditioned on retrieved history

Different Retrieval Methods

# Dense retrieval (default)
model_dense = SelfRetrievalLM(retriever=retriever, retrieval_method='dense', ...)

# BM25 retrieval
from pyserini.search import LuceneSearcher
bm25_searcher = LuceneSearcher(...)
model_bm25 = SelfRetrievalLM(retriever=bm25_searcher, retrieval_method='bm25', ...)

# Random retrieval (baseline)
model_random = SelfRetrievalLM(retrieval_method='random', ...)

# No retrieval (full context)
model_no_retrieval = SelfRetrievalLM(retrieval_method='no', ...)

# Oracle retrieval (exhaustive search, training only)
model_oracle = SelfRetrievalLM(
    retrieval_method='oracle',
    key_num=1,  # Must be 1 for oracle
    debug_retrieval=True  # Print selected chunks
)

Concat vs Replace Integration

# Concat: Prepend retrieved chunks (longer context)
model_concat = SelfRetrievalLM(
    context_window_size=2048,
    key_num=3,
    integrate_method='concat'
)
# Total context: 2048 + 3*2*64 = 2432 tokens

# Replace: Replace early chunks (fixed context length)
model_replace = SelfRetrievalLM(
    context_window_size=2048,
    key_num=3,
    integrate_method='replace'
)
# Total context: 2048 tokens (3*2*64 replaced, rest kept)
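The context-length arithmetic in the comments above can be checked with a small helper (`total_context` is illustrative, not part of the SRLM API; the doubling assumes each retrieved chunk brings its continuation chunk along, as in the examples above):

```python
def total_context(window, chunk_size, key_num, integrate_method):
    """Tokens fed to the LM for one forward pass under each integration strategy."""
    retrieved = key_num * 2 * chunk_size     # each key chunk + its continuation
    if integrate_method == "concat":
        return window + retrieved            # retrieved chunks prepended to the window
    elif integrate_method == "replace":
        return window                        # retrieved chunks overwrite the earliest ones
    raise ValueError(f"unknown integrate_method: {integrate_method}")

print(total_context(2048, 64, 3, "concat"))   # 2432
print(total_context(2048, 64, 3, "replace"))  # 2048
```

Concat therefore trades extra memory for keeping the full window intact, while replace stays within the base model's position-embedding budget.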

Debug Retrieval Decisions

# Enable debug mode to see what's retrieved
model = SelfRetrievalLM(
    retriever=retriever,
    debug_retrieval=True,
    ...
)

# During training/inference, prints:
# ***Indices***
# [42, 15, 89]  # Retrieved chunk indices
# ***Query***
# "The fox jumped over..."
# ***Retrieved***
# ["The quick brown fox...", "Lazy dogs often...", "Foxes are mammals..."]
# ***Target***
# "the lazy dog."
