Implementation:FlagOpen FlagEmbedding LLM Embedder Dense Model
| Knowledge Sources | |
|---|---|
| Domains | Machine Learning, Information Retrieval, Dense Retrieval |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
A dense retriever built on a dual-encoder architecture, with support for contrastive learning and knowledge distillation.
Description
The DenseRetriever module implements a bi-encoder architecture for dense retrieval with support for asymmetric query-document encoding. It provides full training, indexing, and search capabilities using FAISS for efficient similarity search. The model supports multiple pooling strategies (CLS, mean, dense), various similarity metrics (cosine, inner product, L2), and advanced training techniques including cross-device negatives and stable distillation.
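As a rough illustration of the pooling strategies and similarity metrics mentioned above, the sketch below works on plain Python lists so that it runs without dependencies; the module itself operates on PyTorch tensors. The function names and the metric strings `ip` and `l2` are assumptions made for this example (only `cos` appears in the signature), as is the convention of negating the L2 distance so that a larger score always means a closer match. The dense pooling variant is not shown.

```python
import math

def cls_pool(token_embeddings, attention_mask=None):
    """Use the first token's vector as the sequence embedding.
    The mask is unused; it is accepted only for a uniform signature."""
    return list(token_embeddings[0])

def mean_pool(token_embeddings, attention_mask):
    """Average token vectors, skipping padded positions (mask == 0)."""
    dim = len(token_embeddings[0])
    kept = [vec for vec, m in zip(token_embeddings, attention_mask) if m]
    return [sum(vec[d] for vec in kept) / len(kept) for d in range(dim)]

def score(query, key, metric="cos"):
    """Similarity between two vectors; larger means more similar."""
    dot = sum(q * k for q, k in zip(query, key))
    if metric == "ip":   # assumed name for inner product
        return dot
    if metric == "cos":
        return dot / (math.hypot(*query) * math.hypot(*key))
    if metric == "l2":   # assumed name; distance is negated (assumption)
        return -math.dist(query, key)
    raise ValueError(f"unknown metric: {metric}")

# Three token vectors, the last one being padding.
tokens = [[1.0, 0.0], [3.0, 4.0], [9.0, 9.0]]
mask = [1, 1, 0]
cls_vec = cls_pool(tokens, mask)
mean_vec = mean_pool(tokens, mask)
```

With cosine similarity the embedding norms cancel out, whereas inner product and L2 are sensitive to vector length, which is one reason the metric is a constructor argument rather than a fixed choice.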
Key features include:
- Tied or separate query/key encoders
- FAISS-based efficient indexing and search
- Cross-device negative sampling for distributed training
- Multiple loss functions (contrastive, distillation, stable distillation)
- Support for gradient checkpointing
- Memmap-based embedding storage for large corpora
- Task-specific training configurations
The model can be used for both training retrieval systems and performing inference (encoding, indexing, searching, reranking).
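The loss functions listed among the key features can be sketched for a single query as follows. This is a minimal sketch under stated assumptions, not the module's code: it assumes the positive key sits at index 0, that the contrastive term is a temperature-scaled cross-entropy, that the distillation term is a KL divergence from teacher to student, and that the two are combined as a weighted sum using the default weights from the signature. The helper names are hypothetical, and stable distillation and cross-device negatives are not covered.

```python
import math

def log_softmax(scores, temperature=1.0):
    """Numerically stable log-softmax over one query's candidate scores."""
    scaled = [s / temperature for s in scores]
    peak = max(scaled)
    log_norm = peak + math.log(sum(math.exp(s - peak) for s in scaled))
    return [s - log_norm for s in scaled]

def contrastive_loss(scores, positive=0, temperature=0.01):
    """Cross-entropy that pushes the positive key's score above the negatives."""
    return -log_softmax(scores, temperature)[positive]

def distill_loss(student_scores, teacher_scores,
                 student_temperature=1.0, teacher_temperature=1.0):
    """KL(teacher || student) over one query's candidate keys."""
    student_logp = log_softmax(student_scores, student_temperature)
    teacher_logp = log_softmax(teacher_scores, teacher_temperature)
    return sum(math.exp(t) * (t - s) for t, s in zip(teacher_logp, student_logp))

def combined_loss(student_scores, teacher_scores=None,
                  contrastive_weight=0.2, distill_weight=1.0):
    """Weighted sum of both terms (assumed combination rule)."""
    loss = contrastive_weight * contrastive_loss(student_scores)
    if teacher_scores is not None:
        loss += distill_weight * distill_loss(student_scores, teacher_scores)
    return loss

# Cosine scores for 1 positive followed by 2 negatives.
student = [0.9, 0.1, 0.1]
teacher = [2.0, 0.5, 0.1]
total = combined_loss(student, teacher)
```

A low temperature such as 0.01 stretches cosine scores, which are bounded in [-1, 1], into a range where the softmax can separate the positive from the negatives sharply.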
Usage
Use this model for training dense retrieval systems on custom data, building retrieval indexes over large document collections, or performing semantic search and reranking.
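For large document collections, memory-mapped storage lets a single embedding be read from disk without loading the whole file. The sketch below illustrates the idea with the standard library only; in practice NumPy's `np.memmap` would be the usual choice. The file layout (a flat float32 buffer, row-major) and the function names are assumptions made for this example and are not taken from the module.

```python
import mmap
import os
import tempfile
from array import array

ITEM_SIZE = array("f").itemsize  # bytes per float32 value

def write_embeddings(path, rows):
    """Write rows of floats to `path` as a flat float32 buffer."""
    with open(path, "wb") as f:
        for row in rows:
            array("f", row).tofile(f)

def read_row(path, index, dim):
    """Read one embedding by memory-mapping the file; only the requested
    slice is copied into Python memory."""
    start = index * dim * ITEM_SIZE
    stop = start + dim * ITEM_SIZE
    with open(path, "rb") as f:
        with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
            row = array("f")
            row.frombytes(mm[start:stop])
            return list(row)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "embeddings.bin")  # hypothetical file name
    write_embeddings(path, [[0.5, 1.0], [0.25, -2.0], [3.0, 4.0]])
    second = read_row(path, index=1, dim=2)
```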
Code Reference
Source Location
- Repository: FlagOpen_FlagEmbedding
- File: research/llm_embedder/src/retrieval/modeling_dense.py
Signature
class DenseRetriever(torch.nn.Module):
    def __init__(self, query_encoder:str='BAAI/bge-base-en',
                 key_encoder:str='BAAI/bge-base-en',
                 pooling_method:List[str]=["cls"],
                 dense_metric:str="cos",
                 query_max_length:int=512,
                 key_max_length:int=512,
                 tie_encoders:bool=True,
                 truncation_side:str="right",
                 dtype:str="fp16",
                 cache_dir:Optional[str]=None,
                 cos_temperature:float=0.01,
                 contrastive_weight:float=0.2,
                 distill_weight:float=1.0,
                 teacher_temperature:float=1.0,
                 student_temperature:float=1.0,
                 negative_cross_device:bool=True,
                 stable_distill:bool=False,
                 accelerator:Accelerator=None)
    def encode(self, inputs, field:str="key", with_grad:bool=False)
    def forward(self, query, key, task, teacher_scores=None, **kwds)
    def index(self, corpus: Dataset, output_dir="data/outputs",
              embedding_name=None, index_factory:str="Flat",
              save_index=False, load_encode=False, save_encode=False,
              load_index=False, batch_size=500, metric=None)
    def search(self, inputs, hits:int=10, **kwds)
    def rerank(self, query, key, key_mask=None, **kwds)
    def save_pretrained(self, output_dir: str, *args, **kwargs)
class FaissIndex:
    def __init__(self, device)
    def build(self, encoded_corpus, index_factory, metric)
    def load(self, index_path)
    def save(self, index_path)
    def search(self, query, hits)
Import
from retrieval.modeling_dense import DenseRetriever
from accelerate import Accelerator
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| query | Dict or str | Yes | Tokenized query inputs or raw text |
| key | Dict or str | Yes | Tokenized document inputs or raw text |
| task | str | Yes | Task identifier (qa, chat, lrlm, etc.) |
| teacher_scores | List[float] | No | Teacher scores for distillation |
| corpus | Dataset | No | Document corpus for indexing |
| inputs | Dict or str | No | Inputs to encode (queries or keys) or queries for search |
Outputs
| Name | Type | Description |
|---|---|---|
| loss | torch.Tensor | Training loss (contrastive + distillation) |
| embedding | torch.Tensor | Dense embeddings of shape (batch_size, hidden_dim) |
| scores | torch.Tensor | Retrieval/reranking scores |
| indices | torch.Tensor | Top-k document indices |
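To make the scores and indices outputs concrete, the sketch below performs an exhaustive inner-product search over a handful of vectors, which is the computation a FAISS `Flat` index carries out at scale. It is a dependency-free stand-in rather than FAISS itself, and the function name `flat_search` is hypothetical. Both return values have one row per query and `hits` columns, ordered from best to worst match.

```python
import heapq

def flat_search(queries, keys, hits=10):
    """Exhaustive inner-product search: score every key, keep the top `hits`."""
    all_scores, all_indices = [], []
    for query in queries:
        scored = [
            (sum(q * k for q, k in zip(query, key)), idx)
            for idx, key in enumerate(keys)
        ]
        # Highest scores first; at most `hits` entries per query.
        top = heapq.nlargest(hits, scored, key=lambda pair: pair[0])
        all_scores.append([s for s, _ in top])
        all_indices.append([i for _, i in top])
    return all_scores, all_indices

keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
queries = [[2.0, 1.0], [-1.0, 3.0]]
scores, indices = flat_search(queries, keys, hits=2)
```

An approximate index such as the `IVF1024,Flat` factory string trades some of this exactness for speed by scanning only part of the corpus.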
Usage Examples
from retrieval.modeling_dense import DenseRetriever
from accelerate import Accelerator
from datasets import load_dataset
accelerator = Accelerator()
# Initialize dense retriever
model = DenseRetriever(
    query_encoder="BAAI/llm-embedder",
    key_encoder="BAAI/llm-embedder",
    pooling_method=["cls"],
    dense_metric="cos",
    query_max_length=512,
    key_max_length=512,
    tie_encoders=True,
    cos_temperature=0.01,
    contrastive_weight=0.2,
    distill_weight=1.0,
    negative_cross_device=True,
    accelerator=accelerator,
)
# Training forward pass
# tokenized_queries, tokenized_keys, and teacher_scores are placeholders
# that must be prepared by your own tokenizer or data collator.
batch = {
    "query": tokenized_queries,  # {"input_ids": ..., "attention_mask": ...}
    "key": tokenized_keys,  # Batch of 1 positive + N negatives per query
    "task": "qa",
    "teacher_scores": teacher_scores,  # Optional
}
outputs = model(**batch)
loss = outputs["loss"]
# Encode queries and documents
query_embeddings = model.encode("What is deep learning?", field="query")
doc_embeddings = model.encode("Deep learning is...", field="key")
# Index a corpus
corpus = load_dataset("json", data_files="corpus.json", split="train")
model.index(
    corpus=corpus,
    output_dir="indexes/",
    index_factory="IVF1024,Flat",
    save_index=True,
    save_encode=True,
    batch_size=512,
)
# Search the index
queries = ["What is machine learning?", "How does NLP work?"]
scores, indices = model.search(queries, hits=10)
print(f"Top-10 documents: {indices}")
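# Rerank retrieved candidates
# query_inputs, candidate_inputs, and key_masks are placeholders
# that must be prepared by the caller.
rerank_scores, rerank_indices = model.rerank(
    query=query_inputs,
    key=candidate_inputs,
    key_mask=key_masks,
)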
# Save model
model.save_pretrained("output/llm-embedder-finetuned")