Implementation: FlagOpen FlagEmbedding LLM Embedder SRLM
| Knowledge Sources | |
|---|---|
| Domains | Natural Language Processing, Large Language Models, Information Retrieval |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
A Self-Retrieval Language Model that lets an LLM retrieve and attend to relevant chunks from earlier in its own context during both training and inference.
Description
SelfRetrievalLM (SRLM) extends standard language models with the ability to retrieve and attend to relevant chunks from earlier in the context. Instead of processing a long context sequentially, SRLM:
1. Chunks the context into fixed-size segments (default 64 tokens)
2. Retrieves relevant chunks using dense retrieval, BM25, or oracle selection
3. Integrates retrieved context by concatenation or replacement
4. Generates next tokens conditioned on retrieved context
This enables processing contexts far longer than the model's training window (e.g., 32K tokens) while maintaining bounded memory and compute. The model supports multiple retrieval methods (dense embedder, BM25, random, oracle/optimal), flexible integration strategies (concat retrieved chunks before window, or replace least relevant chunks), and two use cases: chunk-based retrieval for long documents, and history retrieval for multi-turn dialogue.
The architecture is particularly useful for long-context language modeling tasks where full attention over all tokens is computationally prohibitive.
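The four steps above can be sketched end to end with toy values. A bag-of-words embedder stands in for the dense retriever, and the 64-token chunk size is shrunk for readability; the helper names here are illustrative, not the repository's API:

```python
import numpy as np

CHUNK_SIZE = 4   # 64 in the real model; small here for illustration
KEY_NUM = 2

def chunk_tokens(tokens, size=CHUNK_SIZE):
    """Step 1: split a token sequence into fixed-size chunks."""
    return [tokens[i:i + size] for i in range(0, len(tokens), size)]

def toy_embed(chunk, vocab_size=50):
    """Stand-in for the dense retriever: normalized bag-of-words vector."""
    vec = np.zeros(vocab_size)
    for t in chunk:
        vec[t % vocab_size] += 1
    return vec / (np.linalg.norm(vec) + 1e-9)

def retrieve(query_chunk, key_chunks, k=KEY_NUM):
    """Step 2: score earlier chunks against the query, keep the top-k."""
    q = toy_embed(query_chunk)
    scores = [float(q @ toy_embed(c)) for c in key_chunks]
    top = np.argsort(scores)[::-1][:k]
    return sorted(top.tolist())  # 'sequential' ordering: document order

tokens = list(range(40))     # a toy 40-token document
chunks = chunk_tokens(tokens)
query = chunks[-2]           # query chunk (used to predict the target chunk)
keys = chunks[:-3]           # earlier chunks serve as retrieval keys
picked = retrieve(query, keys)
# Step 3: integrate by concatenating retrieved chunks before the window
context = sum((chunks[i] for i in picked), []) + query
```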
Usage
Use SRLM when working with contexts exceeding the model's position embedding length, when you need efficient processing of very long documents (100K+ tokens), or when building dialogue systems that selectively attend to relevant conversation history rather than all previous turns.
Code Reference
Source Location
- Repository: FlagOpen_FlagEmbedding
- File: research/llm_embedder/src/lm/modeling_srlm.py
- Lines: 1-554
Signature
class SelfRetrievalLM(LM):
def __init__(
self,
retriever=None,
context_window_size: int = 2048,
chunk_size: int = 64,
key_num: int = 1,
chunk_batch_size: int = 2,
add_key_continuation: bool = False,
retrieval_method: str = "dense",
order_method: str = "sequential",
integrate_method: str = "concat",
instruction: Dict = None,
add_sep: Optional[List[int]] = None,
debug_retrieval: bool = False,
**kwds
):
"""Initialize self-retrieval language model"""
def forward(self, input_ids, attention_mask, labels) -> SRLMOutput:
"""Forward pass with chunk-based retrieval"""
def forward_with_history_retrieval(
self, query: np.ndarray, history: np.ndarray,
answer: np.ndarray, history_mask: torch.Tensor
) -> SRLMOutput:
"""Forward pass with dialogue history retrieval"""
def compute_perplexity(self, dataloader) -> float:
"""Compute perplexity over long inputs"""
Import
from research.llm_embedder.src.lm.modeling_srlm import SelfRetrievalLM, SRLMOutput
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | Tensor | Yes | Token IDs (batch_size, seq_len) |
| attention_mask | Tensor | Yes | Attention mask |
| labels | Tensor | Yes | Labels for LM loss (-100 for ignored tokens) |
| retriever | object | No | Dense retriever with encode() method |
| context_window_size | int | No | Max context length (default: 2048) |
| chunk_size | int | No | Chunk size in tokens (default: 64) |
| key_num | int | No | Number of chunks to retrieve (default: 1) |
Outputs
| Name | Type | Description |
|---|---|---|
| loss | Tensor | Language modeling loss |
| logits | Tensor | Model logits (optional) |
| past_key_values | Tuple | Cached key-value states (optional) |
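The output fields map naturally onto a small dataclass. The real SRLMOutput in modeling_srlm.py holds torch tensors; this is only a minimal illustrative stand-in mirroring the table above:

```python
from dataclasses import dataclass
from typing import Any, Optional, Tuple

@dataclass
class SRLMOutputSketch:
    """Illustrative stand-in for SRLMOutput (not the repository class)."""
    loss: Any                                # language modeling loss
    logits: Optional[Any] = None             # model logits, if requested
    past_key_values: Optional[Tuple] = None  # cached KV states, if requested

out = SRLMOutputSketch(loss=1.234)
```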
Architecture
Chunk-Based Retrieval
# Input: [pad] [pad] [chunk0] [chunk1] ... [chunkN-3] [chunkN-2] [chunkN-1] [chunkN]
# |<------- corpus -------->| |<--- window --->|
# Process:
# 1. Split into chunks (64 tokens each)
# 2. Query: chunkN-2 (target: chunkN-1)
# 3. Keys: chunk0...chunkN-3 (earlier chunks)
# 4. Retrieve top-k chunks + continuations (k*2*64 tokens)
# 5. Integrate: [retrieved] [fixed_window]
# [k*128 tokens] [window_size tokens]
# 6. Compute loss only on chunkN-1
# Retrieval Methods:
# - dense: Encode chunks with retriever, compute similarity
# - bm25: Use BM25 scoring on chunk text
# - random: Random chunk selection
# - oracle: Exhaustive search for chunk minimizing loss (training only)
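The oracle method's exhaustive search can be illustrated with a toy loss function. The real oracle evaluates the actual LM loss on the target once per candidate chunk and keeps the minimizer; `toy_loss` below is purely illustrative:

```python
def oracle_select(candidates, target, loss_fn):
    """Exhaustive oracle: evaluate every candidate chunk and return the
    index whose inclusion yields the lowest loss on the target chunk."""
    losses = [loss_fn(c, target) for c in candidates]
    return min(range(len(candidates)), key=lambda i: losses[i])

# Toy loss: fewer tokens shared with the target -> higher loss
def toy_loss(chunk, target):
    return -len(set(chunk) & set(target))

candidates = [[1, 2, 3], [4, 5, 6], [6, 7, 9]]
target = [4, 5, 6]
best = oracle_select(candidates, target, toy_loss)  # picks [4, 5, 6]
```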
Integration Strategies
if integrate_method == "concat":
# Prepend retrieved chunks before window
input_ids = [retrieved_chunks] + [fixed_context]
# Window: [k*128] [2048] = 2048 + k*128 tokens total
elif integrate_method == "replace":
# Replace earliest chunks in window with retrieved chunks
fixed_context = window[k*128:] # Keep last (window - k*128) tokens
input_ids = [retrieved_chunks] + [fixed_context]
# Window: [k*128] [(2048 - k*128)] = 2048 tokens total (fits in memory)
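A quick sketch of the token budget implied by each strategy. The counts follow the comments above (retrieved chunks plus their continuations cost `key_num * 2 * chunk_size` tokens); the function itself is illustrative, not part of the module:

```python
def total_context(window, chunk_size, key_num, integrate_method):
    """Tokens fed to the LM per step under each integration strategy."""
    retrieved = key_num * 2 * chunk_size  # retrieved chunks + continuations
    if integrate_method == "concat":
        return window + retrieved         # window grows by retrieved tokens
    elif integrate_method == "replace":
        return window                     # retrieved tokens displace early chunks
    raise ValueError(f"unknown integrate_method: {integrate_method}")

print(total_context(2048, 64, 3, "concat"))   # 2432
print(total_context(2048, 64, 3, "replace"))  # 2048
```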
Ordering Strategies
if order_method == "sequential":
    # Order by position in the document (early to late)
    retrieved_indices = retrieved_indices.sort(-1)[0]
elif order_method == "relevance":
    # Ascending relevance: the most relevant chunk is placed last, closest to the window
    retrieved_indices = retrieved_indices.flip(dims=(-1,))
    # FAISS returns indices in descending relevance; flipping yields ascending order
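The two orderings, shown with NumPy arrays rather than torch tensors for brevity (the indices are hypothetical):

```python
import numpy as np

# Suppose retrieval returned chunk indices in descending relevance:
retrieved = np.array([42, 15, 89])  # chunk 42 is the most relevant

sequential = np.sort(retrieved)     # document order: [15, 42, 89]
relevance = retrieved[::-1]         # ascending relevance: [89, 15, 42]
# With 'relevance' ordering, the most relevant chunk (42) ends up last,
# adjacent to the fixed window.
```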
Usage Examples
Initialize SRLM
from transformers import AutoTokenizer, AutoModelForCausalLM
from FlagEmbedding import FlagModel
from research.llm_embedder.src.lm.modeling_srlm import SelfRetrievalLM
# Load base LM and retriever
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
base_model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf')
retriever = FlagModel('BAAI/llm-embedder', use_fp16=True)
# Wrap with SRLM
model = SelfRetrievalLM(
model=base_model,
tokenizer=tokenizer,
retriever=retriever,
context_window_size=2048, # 2K context window
chunk_size=64, # 64-token chunks
key_num=3, # Retrieve 3 chunks (3*2*64 = 384 tokens)
chunk_batch_size=8, # Process 8 target chunks per batch
retrieval_method='dense',
order_method='sequential',
integrate_method='concat',
instruction={
'query': 'Represent this text chunk for retrieval: ',
'key': 'Represent this text chunk for retrieval: '
}
)
Long Document Modeling
# Prepare long document (e.g., 32K tokens)
document = "The quick brown fox..." * 5000 # Very long text
inputs = tokenizer(document, return_tensors='pt', truncation=False)
# SRLM automatically chunks and retrieves
# Context window: 2048, but can process 32K+ tokens
outputs = model(
input_ids=inputs['input_ids'],
attention_mask=inputs['attention_mask'],
labels=inputs['input_ids'] # Labels for language modeling
)
print(f"Loss: {outputs.loss.item():.4f}")
# Only computes loss on target chunks, not retrieved context
# Compute perplexity
from torch.utils.data import DataLoader
dataloader = DataLoader(dataset, batch_size=1, collate_fn=collator)
perplexity = model.compute_perplexity(dataloader)
print(f"Perplexity: {perplexity:.2f}")
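Perplexity is the exponential of the token-weighted mean negative log-likelihood. A sketch of the standard aggregation (the exact weighting inside compute_perplexity may differ):

```python
import math

def perplexity_from_losses(losses, token_counts):
    """Corpus perplexity from per-batch mean losses and token counts:
    exp of the token-weighted mean negative log-likelihood."""
    total_nll = sum(l * n for l, n in zip(losses, token_counts))
    total_tokens = sum(token_counts)
    return math.exp(total_nll / total_tokens)

ppl = perplexity_from_losses([2.0, 2.5], [100, 300])
```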
Multi-Turn Dialogue with History Retrieval
import numpy as np
import torch
# Dialogue history (list of turns)
# Dialogue history: one batch row of past turns, shape (batch_size, num_turns)
history = np.array([[
    "Speaker 1: Hello, how are you?\nSpeaker 2: I'm good, thanks!",
    "Speaker 1: What's the weather?\nSpeaker 2: It's sunny today.",
    "Speaker 1: Do you like sports?\nSpeaker 2: Yes, I love basketball.",
    # ... many more turns
]])
query = np.array(["What did we talk about earlier?"])
answer = np.array(["We discussed the weather and sports."])
# History mask over turns (1 = valid turn, 0 = padding)
history_mask = torch.ones(1, history.shape[1])
# SRLM retrieves relevant history turns
outputs = model.forward_with_history_retrieval(
query=query,
history=history,
answer=answer,
history_mask=history_mask
)
# Model retrieves turns about "weather" and "sports"
# Generates answer conditioned on retrieved history
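A toy sketch of the history-selection step: score each valid turn against the query embedding and keep the top-k. Two-dimensional vectors stand in for real llm-embedder encodings, and `select_history` is illustrative, not the module's API:

```python
import numpy as np

def select_history(query_vec, turn_vecs, mask, k=2):
    """Pick the k history turns most similar to the query, skipping
    padded turns (toy stand-in for dense history retrieval)."""
    scores = turn_vecs @ query_vec
    scores = np.where(mask.astype(bool), scores, -np.inf)  # mask padding
    return np.argsort(scores)[::-1][:k]                    # descending relevance

query_vec = np.array([1.0, 0.0])
turn_vecs = np.array([[0.1, 0.9],   # greeting turn
                      [0.8, 0.2],   # weather turn
                      [0.9, 0.1],   # sports turn
                      [0.0, 0.0]])  # padding
mask = np.array([1, 1, 1, 0])
top = select_history(query_vec, turn_vecs, mask)  # weather + sports turns
```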
Different Retrieval Methods
# Dense retrieval (default)
model_dense = SelfRetrievalLM(retriever=retriever, retrieval_method='dense', ...)
# BM25 retrieval
from pyserini.search import LuceneSearcher
bm25_searcher = LuceneSearcher(...)
model_bm25 = SelfRetrievalLM(retriever=bm25_searcher, retrieval_method='bm25', ...)
# Random retrieval (baseline)
model_random = SelfRetrievalLM(retrieval_method='random', ...)
# No retrieval (full context)
model_no_retrieval = SelfRetrievalLM(retrieval_method='no', ...)
# Oracle retrieval (exhaustive search, training only)
model_oracle = SelfRetrievalLM(
retrieval_method='oracle',
key_num=1, # Must be 1 for oracle
debug_retrieval=True # Print selected chunks
)
Concat vs Replace Integration
# Concat: Prepend retrieved chunks (longer context)
model_concat = SelfRetrievalLM(
context_window_size=2048,
key_num=3,
integrate_method='concat'
)
# Total context: 2048 + 3*2*64 = 2432 tokens
# Replace: Replace early chunks (fixed context length)
model_replace = SelfRetrievalLM(
context_window_size=2048,
key_num=3,
integrate_method='replace'
)
# Total context: 2048 tokens (3*2*64 replaced, rest kept)
Debug Retrieval Decisions
# Enable debug mode to see what's retrieved
model = SelfRetrievalLM(
retriever=retriever,
debug_retrieval=True,
...
)
# During training/inference, prints:
# ***Indices***
# [42, 15, 89] # Retrieved chunk indices
# ***Query***
# "The fox jumped over..."
# ***Retrieved***
# ["The quick brown fox...", "Lazy dogs often...", "Foxes are mammals..."]
# ***Target***
# "the lazy dog."