

Implementation:FlagOpen FlagEmbedding LLM Embedder Dense Model

From Leeroopedia


Knowledge Sources
Domains Machine Learning, Information Retrieval, Dense Retrieval
Last Updated 2026-02-09 00:00 GMT

Overview

Dense retriever built on a dual-encoder architecture and trained with contrastive learning and knowledge distillation.

Description

The DenseRetriever module implements a bi-encoder architecture for dense retrieval, with support for asymmetric query-document encoding. It covers training as well as indexing and search, using FAISS for efficient similarity search. The model supports multiple pooling strategies (CLS, mean, dense), several similarity metrics (cosine, inner product, L2), and training techniques such as cross-device negatives and stable distillation.
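To make the pooling and metric options concrete, here is a minimal pure-Python sketch of CLS pooling, mean pooling, and the three similarity metrics. The helper names are hypothetical and not part of the module. The "dense" pooling option is not sketched because this page does not describe it, and returning the negated squared distance for `l2` is our own convention so that "higher is better" holds for every metric.

```python
import math
from typing import List, Sequence

Vector = List[float]


def cls_pooling(token_embeddings: Sequence[Vector], attention_mask: Sequence[int]) -> Vector:
    """Use the first token's hidden state (the [CLS] position) as the sequence embedding."""
    return list(token_embeddings[0])


def mean_pooling(token_embeddings: Sequence[Vector], attention_mask: Sequence[int]) -> Vector:
    """Average the hidden states of real tokens (mask == 1), ignoring padding."""
    dim = len(token_embeddings[0])
    total = [0.0] * dim
    count = 0
    for vec, keep in zip(token_embeddings, attention_mask):
        if keep:
            for j in range(dim):
                total[j] += vec[j]
            count += 1
    return [t / max(count, 1) for t in total]


def similarity(q: Vector, k: Vector, metric: str = "cos") -> float:
    """Score a query/key pair; larger always means more similar."""
    ip = sum(a * b for a, b in zip(q, k))
    if metric == "ip":
        return ip
    if metric == "cos":
        norm = math.sqrt(sum(a * a for a in q)) * math.sqrt(sum(b * b for b in k))
        return ip / norm if norm > 0 else 0.0
    if metric == "l2":
        # Negated squared distance (our convention) so ranking stays "higher is better".
        return -sum((a - b) ** 2 for a, b in zip(q, k))
    raise ValueError(f"unknown metric: {metric}")
```

Mean pooling needs the attention mask so that padding tokens do not dilute the average. CLS pooling ignores the mask entirely.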

Key features include:

  • Tied or separate query/key encoders
  • FAISS-based efficient indexing and search
  • Cross-device negative sampling for distributed training
  • Multiple loss functions (contrastive, distillation, stable distillation)
  • Support for gradient checkpointing
  • Memmap-based embedding storage for large corpora
  • Task-specific training configurations
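Cross-device negatives and the contrastive loss work together. The sketch below shows one common way to combine them, under assumptions that are ours rather than this page's. Keys arrive as one positive followed by N negatives per query (as in the usage example's batch comment). Embeddings are already L2-normalized, so the dot product equals the cosine score. The loss is a temperature-scaled InfoNCE cross-entropy. `gather_across_devices` is a hypothetical pure-Python stand-in for a distributed all-gather.

```python
import math
from typing import List, Sequence

Vector = List[float]


def gather_across_devices(per_device: Sequence[Sequence[Vector]]) -> List[Vector]:
    """Stand-in for an all-gather: concatenate every device's embeddings in rank order."""
    gathered: List[Vector] = []
    for chunk in per_device:
        gathered.extend(chunk)
    return gathered


def log_softmax(xs: Sequence[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]


def contrastive_loss(queries: Sequence[Vector], keys: Sequence[Vector],
                     group_size: int, temperature: float = 0.01) -> float:
    """InfoNCE where each query is scored against every key in the batch.

    keys holds len(queries) consecutive groups of `group_size` entries; the first
    entry of group i is the positive for query i, and all other keys act as negatives.
    """
    losses = []
    for i, q in enumerate(queries):
        logits = [sum(a * b for a, b in zip(q, k)) / temperature for k in keys]
        losses.append(-log_softmax(logits)[i * group_size])
    return sum(losses) / len(losses)
```

Gathering keys from every device gives each query extra negatives without encoding anything extra. This usually makes the contrastive task harder and the training signal richer.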

The model can be used both to train retrieval systems and to run inference (encoding, indexing, searching, reranking).

Usage

Use this model for training dense retrieval systems on custom data, building retrieval indexes over large document collections, or performing semantic search and reranking.

Code Reference

Source Location

Signature

class DenseRetriever(torch.nn.Module):
    def __init__(self, query_encoder:str='BAAI/bge-base-en',
                 key_encoder:str='BAAI/bge-base-en',
                 pooling_method:List[str]=["cls"],
                 dense_metric:str="cos",
                 query_max_length:int=512,
                 key_max_length:int=512,
                 tie_encoders:bool=True,
                 truncation_side:str="right",
                 dtype:str="fp16",
                 cache_dir:Optional[str]=None,
                 cos_temperature:float=0.01,
                 contrastive_weight:float=0.2,
                 distill_weight:float=1.0,
                 teacher_temperature:float=1.0,
                 student_temperature:float=1.0,
                 negative_cross_device:bool=True,
                 stable_distill:bool=False,
                 accelerator:Accelerator=None)

    def encode(self, inputs, field:str="key", with_grad:bool=False)

    def forward(self, query, key, task, teacher_scores=None, **kwds)

    def index(self, corpus: Dataset, output_dir="data/outputs",
              embedding_name=None, index_factory:str="Flat",
              save_index=False, load_encode=False, save_encode=False,
              load_index=False, batch_size=500, metric=None)

    def search(self, inputs, hits:int=10, **kwds)

    def rerank(self, query, key, key_mask=None, **kwds)

    def save_pretrained(self, output_dir: str, *args, **kwargs)

class FaissIndex:
    def __init__(self, device)
    def build(self, encoded_corpus, index_factory, metric)
    def load(self, index_path)
    def save(self, index_path)
    def search(self, query, hits)
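The constructor exposes `cos_temperature`, `contrastive_weight`, `distill_weight`, `teacher_temperature`, and `student_temperature`. This is a sketch only. It assumes the distillation term is a KL divergence between temperature-scaled teacher and student score distributions, and that the total loss is a weighted sum of the two terms with the defaults 0.2 and 1.0 taken from the signature. The helper names are hypothetical. Stable distillation is not modeled, because this page does not describe it.

```python
import math
from typing import List, Sequence


def softmax(xs: Sequence[float], temperature: float = 1.0) -> List[float]:
    scaled = [x / temperature for x in xs]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    z = sum(exps)
    return [e / z for e in exps]


def distill_loss(student_scores: Sequence[Sequence[float]],
                 teacher_scores: Sequence[Sequence[float]],
                 student_temperature: float = 1.0,
                 teacher_temperature: float = 1.0) -> float:
    """Mean KL(teacher || student) over queries; each row scores one query's candidates."""
    total = 0.0
    for s_row, t_row in zip(student_scores, teacher_scores):
        p_t = softmax(t_row, teacher_temperature)
        p_s = softmax(s_row, student_temperature)
        total += sum(pt * (math.log(pt) - math.log(ps)) for pt, ps in zip(p_t, p_s) if pt > 0)
    return total / len(student_scores)


def combined_loss(contrastive: float, distill: float,
                  contrastive_weight: float = 0.2, distill_weight: float = 1.0) -> float:
    """Weighted sum of the two objectives (assumed combination rule)."""
    return contrastive_weight * contrastive + distill_weight * distill
```

A higher teacher temperature flattens the teacher's distribution, so the student is pushed to match the relative ordering of candidates rather than only the top one.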

Import

from retrieval.modeling_dense import DenseRetriever
from accelerate import Accelerator

I/O Contract

Inputs

Name | Type | Required | Description
query | Dict or str | Yes | Tokenized query inputs or raw text
key | Dict or str | Yes | Tokenized document inputs or raw text
task | str | Yes | Task identifier (qa, chat, lrlm, etc.)
teacher_scores | List[float] | No | Teacher scores for distillation
corpus | Dataset | No | Document corpus for indexing
inputs | Dict or str | No | Queries passed to search(); rerank() takes query and key instead

Outputs

Name | Type | Description
loss | torch.Tensor | Training loss (contrastive + distillation)
embedding | torch.Tensor | Dense embeddings (batch_size, hidden_dim)
scores | torch.Tensor | Retrieval/reranking scores
indices | torch.Tensor | Top-k document indices
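With `index_factory="Flat"`, FAISS performs exhaustive (exact) search. The pure-Python sketch below shows what such a search computes for the `scores` and `indices` outputs. It assumes inner-product scoring over embeddings that are already normalized, so inner product matches cosine. `flat_search` is a hypothetical helper, not part of `FaissIndex`. Approximate factories such as `IVF1024,Flat` trade some recall for speed and may not return exactly these results.

```python
import heapq
from typing import List, Sequence, Tuple

Vector = List[float]


def flat_search(queries: Sequence[Vector], corpus: Sequence[Vector],
                hits: int = 10) -> Tuple[List[List[float]], List[List[int]]]:
    """Exact top-k search by inner product; returns (scores, indices) per query."""
    all_scores: List[List[float]] = []
    all_indices: List[List[int]] = []
    for q in queries:
        # Score every document, then keep the `hits` highest-scoring ones.
        scored = [(sum(a * b for a, b in zip(q, d)), idx) for idx, d in enumerate(corpus)]
        top = heapq.nlargest(hits, scored, key=lambda pair: pair[0])
        all_scores.append([s for s, _ in top])
        all_indices.append([i for _, i in top])
    return all_scores, all_indices
```

Unlike FAISS, which pads missing results with index -1 when the corpus has fewer than `hits` entries, this helper simply returns fewer hits.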

Usage Examples

from retrieval.modeling_dense import DenseRetriever
from accelerate import Accelerator
from datasets import load_dataset

accelerator = Accelerator()

# Initialize dense retriever
model = DenseRetriever(
    query_encoder="BAAI/llm-embedder",
    key_encoder="BAAI/llm-embedder",
    pooling_method=["cls"],
    dense_metric="cos",
    query_max_length=512,
    key_max_length=512,
    tie_encoders=True,
    cos_temperature=0.01,
    contrastive_weight=0.2,
    distill_weight=1.0,
    negative_cross_device=True,
    accelerator=accelerator
)

# Training forward pass
# tokenized_queries, tokenized_keys and teacher_scores are placeholders for your own batch
batch = {
    "query": tokenized_queries,  # {"input_ids": ..., "attention_mask": ...}
    "key": tokenized_keys,       # per query: 1 positive followed by N negatives
    "task": "qa",
    "teacher_scores": teacher_scores  # optional; only needed for distillation
}
outputs = model(**batch)
loss = outputs["loss"]

# Encode queries and documents
query_embeddings = model.encode("What is deep learning?", field="query")
doc_embeddings = model.encode("Deep learning is...", field="key")

# Index a corpus
corpus = load_dataset("json", data_files="corpus.json", split="train")
model.index(
    corpus=corpus,
    output_dir="indexes/",
    index_factory="IVF1024,Flat",
    save_index=True,
    save_encode=True,
    batch_size=512
)

# Search the index (model.index() must have been called first)
queries = ["What is machine learning?", "How does NLP work?"]
scores, indices = model.search(queries, hits=10)
print(f"Top-10 documents: {indices}")

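# Rerank retrieved candidates
# query_inputs, candidate_inputs and key_masks are placeholders for your own
# queries, their candidate documents, and a mask over the candidate slots
rerank_scores, rerank_indices = model.rerank(
    query=query_inputs,
    key=candidate_inputs,
    key_mask=key_masks
)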

# Save model
model.save_pretrained("output/llm-embedder-finetuned")
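The example's `rerank` call passes a `key_mask` alongside the candidates. Suppose (this page does not say so) that the mask marks padded candidate slots with 0, so different queries can have different numbers of candidates. Then reranking precomputed scores reduces to sorting each query's candidates with masked slots pushed to the end. `rerank_with_mask` is a hypothetical helper working on scores only; the real method presumably also handles encoding the query and candidates.

```python
import math
from typing import List, Sequence, Tuple


def rerank_with_mask(scores: Sequence[Sequence[float]],
                     key_mask: Sequence[Sequence[int]]) -> Tuple[List[List[float]], List[List[int]]]:
    """Sort each query's candidates by score; slots with mask 0 get -inf and sink to the end."""
    out_scores: List[List[float]] = []
    out_indices: List[List[int]] = []
    for row, mask in zip(scores, key_mask):
        masked = [s if m else -math.inf for s, m in zip(row, mask)]
        order = sorted(range(len(masked)), key=lambda j: masked[j], reverse=True)
        out_scores.append([masked[j] for j in order])
        out_indices.append(order)
    return out_scores, out_indices
```

Using -inf rather than dropping masked slots keeps every output row the same length, which matches a padded tensor layout.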
