

Implementation:FlagOpen FlagEmbedding LLM Dense Retriever Modeling

From Leeroopedia
Revision as of 14:59, 16 February 2026 by Admin (auto-imported from implementations/FlagOpen_FlagEmbedding_LLM_Dense_Retriever_Modeling.md)


Knowledge Sources
Domains: Machine_Learning, Large_Language_Models, Information_Retrieval
Last Updated: 2026-02-09 00:00 GMT

Overview

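A bi-encoder model for training large language models as dense retrievers with contrastive learning. It supports in-batch and cross-device negatives, memory-saving sub-batch encoding, and optional distillation from teacher scores.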

Description

BiEncoderModel for LLM-based dense retrieval extends the standard bi-encoder with features tailored for large language models:

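Architecture: Uses LLMs (e.g., LLaMA, GPT) as encoders and takes each sequence's embedding from the hidden state of its last token rather than from a [CLS] token. Decoder-only LLMs have no [CLS] token, and under causal attention only the last token attends to the entire input. Supports sub-batch processing to handle memory constraints when encoding many or long passages through an LLM. Optional L2 normalization (normlized=True) makes the dot product of two embeddings equal to their cosine similarity.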
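The pooling and normalization steps can be sketched in PyTorch as follows. This is a minimal sketch, not the module's own code: it assumes the LLM returns final hidden states of shape [batch, seq_len, hidden] together with a 0/1 attention mask, and the helper names `last_token_pool` and `embed` are hypothetical. It handles both left- and right-padded batches.

```python
import torch
import torch.nn.functional as F


def last_token_pool(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Return the hidden state of the last non-padding token of each sequence."""
    # If every sequence has a real token in the final position, the batch is left-padded
    # (or unpadded), so the last column already holds each sequence's last token.
    left_padded = bool(attention_mask[:, -1].sum() == attention_mask.size(0))
    if left_padded:
        return hidden_states[:, -1]
    # Right padding: the last real token sits at index (number of real tokens - 1).
    last_idx = attention_mask.sum(dim=1) - 1
    batch_idx = torch.arange(hidden_states.size(0), device=hidden_states.device)
    return hidden_states[batch_idx, last_idx]


def embed(hidden_states: torch.Tensor, attention_mask: torch.Tensor, normalize: bool = True) -> torch.Tensor:
    """Last-token pooling, optionally followed by L2 normalization (hypothetical helper)."""
    reps = last_token_pool(hidden_states, attention_mask)
    return F.normalize(reps, p=2, dim=-1) if normalize else reps
```

After normalization every embedding has unit length, so `q @ p.T` yields cosine similarities directly.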

Training modes:

  • Normal mode: Standard in-batch negatives, with cross-device negative sharing in distributed training
  • Custom mode: Computes similarity only between each query and its own associated passages (no in-batch negatives)
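The two modes differ only in which passages each query is scored against. Below is a minimal PyTorch sketch, assuming passages arrive in contiguous groups of `group_size` per query with the positive first, and omitting the cross-device gather. `contrastive_loss` is a hypothetical helper, not the module's API.

```python
import torch
import torch.nn.functional as F


def contrastive_loss(q_reps, p_reps, group_size, temperature=0.02, mode="normal"):
    """q_reps: [B, D]; p_reps: [B * group_size, D], positive first in each group."""
    num_queries = q_reps.size(0)
    if mode == "normal":
        # Every query is scored against every passage in the batch (in-batch negatives).
        scores = q_reps @ p_reps.T / temperature                        # [B, B * group_size]
        # Query i's positive is the first passage of group i.
        target = torch.arange(num_queries, device=q_reps.device) * group_size
    elif mode == "custom":
        # Each query is scored only against its own group of passages.
        groups = p_reps.reshape(num_queries, group_size, -1)            # [B, G, D]
        scores = torch.einsum("bd,bgd->bg", q_reps, groups) / temperature  # [B, G]
        target = torch.zeros(num_queries, dtype=torch.long, device=q_reps.device)
    else:
        raise ValueError(f"unknown mode: {mode}")
    return F.cross_entropy(scores, target), scores
```

With `negatives_cross_device=True`, normal mode would first all-gather `q_reps` and `p_reps` from every rank, so each query also treats passages from other devices as negatives. Custom-mode scores are exactly the diagonal blocks of the normal-mode score matrix.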

Knowledge distillation: Supports distillation from teacher model scores. When teacher scores are provided, the model learns to match the teacher's ranking distribution over positives and negatives using KL divergence, in addition to the standard contrastive loss.
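A minimal sketch of the distillation term follows. It assumes student and teacher scores are both laid out as [num_queries, group_size] with the positive first, and that the two losses are simply summed. The helper name and the unweighted sum are assumptions, not confirmed details of the module.

```python
import torch
import torch.nn.functional as F


def distillation_loss(student_scores: torch.Tensor, teacher_scores: torch.Tensor) -> torch.Tensor:
    """KL(teacher || student) between the ranking distributions over each query's group."""
    teacher_probs = F.softmax(teacher_scores, dim=-1)
    student_log_probs = F.log_softmax(student_scores, dim=-1)
    # F.kl_div expects log-probabilities as input and probabilities as target.
    return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")


# Example: one query whose group holds a positive followed by three negatives.
student = torch.tensor([[2.0, 0.5, 0.1, -1.0]])
teacher = torch.tensor([[0.95, 0.3, 0.2, 0.1]])
contrastive = F.cross_entropy(student, torch.zeros(1, dtype=torch.long))  # positive at index 0
total = contrastive + distillation_loss(student, teacher)
```

The KL term is zero only when the student's softmax over the group matches the teacher's, so it pushes the student to reproduce the teacher's relative ordering of the negatives, not just to rank the positive first.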

Memory optimization: Sub-batch encoding processes passages in smaller chunks to avoid OOM errors with large LLMs, automatically concatenating results while maintaining differentiability.
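The sub-batching idea can be sketched as a wrapper that slices every tensor in the feature dict, encodes each slice, and concatenates the results. This is a sketch under assumptions: `encode_in_sub_batches` and the toy encoder are hypothetical stand-ins, and the features are assumed to be a dict of tensors sharing the same first (batch) dimension.

```python
import torch


def encode_in_sub_batches(encode_fn, features: dict, sub_batch_size: int) -> torch.Tensor:
    """Encode `features` in chunks of `sub_batch_size` rows and concatenate the results."""
    batch_size = features["input_ids"].size(0)
    if sub_batch_size <= 0 or sub_batch_size >= batch_size:
        return encode_fn(features)  # -1 disables sub-batching
    chunks = []
    for start in range(0, batch_size, sub_batch_size):
        sub_features = {k: v[start:start + sub_batch_size] for k, v in features.items()}
        chunks.append(encode_fn(sub_features))
    # torch.cat is differentiable, so gradients flow back into every chunk.
    return torch.cat(chunks, dim=0)


# Usage with a toy encoder standing in for the LLM (hypothetical).
emb = torch.nn.Embedding(10, 4)


def toy_encode(f: dict) -> torch.Tensor:
    return emb(f["input_ids"]).mean(dim=1)


features = {
    "input_ids": torch.randint(0, 10, (5, 3)),
    "attention_mask": torch.ones(5, 3, dtype=torch.long),
}
reps = encode_in_sub_batches(toy_encode, features, sub_batch_size=2)  # chunks of 2, 2, 1 rows
```

Keep in mind that during training, autograd still holds each chunk's activations until the backward pass. Sub-batching alone therefore mainly lowers peak transient memory. The savings are larger at inference under `torch.no_grad()` or when combined with gradient checkpointing.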

Usage

Use this for fine-tuning large language models as embedding models for retrieval tasks, especially when dealing with memory constraints or when distilling from stronger teacher models.

Code Reference

Source Location

Signature

class BiEncoderModel(nn.Module):
    def __init__(self, model: AutoModel = None, tokenizer: AutoTokenizer = None,
                 normlized: bool = False, negatives_cross_device: bool = False,
                 temperature: float = 1.0, sub_batch_size: int = -1)

    def encode(self, features)
    def forward(self, query, passage, messages, teacher_scores)

Import

from research.llm_dense_retriever.finetune.modeling import BiEncoderModel

I/O Contract

Inputs

Name                    Type             Required  Description
model                   AutoModel        Yes       Pre-trained LLM (e.g. LLaMA, GPT)
tokenizer               AutoTokenizer    Yes       Tokenizer matching the model
normlized               bool             No        Whether to L2-normalize embeddings (default False; the parameter is spelled "normlized" in the code)
negatives_cross_device  bool             No        Share in-batch negatives across devices in distributed training (default False)
temperature             float            No        Temperature used to scale similarity scores (default 1.0)
sub_batch_size          int              No        Sub-batch size for memory-efficient encoding (default -1, which disables sub-batching)
query                   Dict/List[Dict]  Yes       Query inputs (a single dict, or a list of dicts for multiple inputs)
passage                 Dict/List[Dict]  Yes       Passage inputs
messages                List             Yes       ['normal'] or ['custom'] to select the training mode
teacher_scores          List             No        Teacher model scores for distillation

Outputs

Name    Type    Description
loss    Tensor  Contrastive loss, plus the distillation loss when teacher scores are provided
scores  Tensor  Similarity scores between queries and passages
q_reps  Tensor  Query embeddings (last-token hidden states)
p_reps  Tensor  Passage embeddings (last-token hidden states)

Usage Examples

import torch
from transformers import AutoModel, AutoTokenizer
from research.llm_dense_retriever.finetune.modeling import BiEncoderModel

# Initialize the base LLM and its tokenizer
base_model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token = tokenizer.eos_token  # LLaMA tokenizers define no pad token by default

model = BiEncoderModel(
    model=base_model,
    tokenizer=tokenizer,
    normlized=True,               # spelled "normlized" in the source code
    negatives_cross_device=True,  # cross-device negatives need a distributed launch (e.g. torchrun)
    temperature=0.02,
    sub_batch_size=8,             # encode 8 sequences at a time to save memory
)

# Training with distillation
# query_ids, query_mask, passage_ids and passage_mask are tokenizer outputs (placeholders here)
query_inputs = {"input_ids": query_ids, "attention_mask": query_mask}
passage_inputs = {"input_ids": passage_ids, "attention_mask": passage_mask}
teacher_scores = [0.95, 0.3, 0.2, 0.1, ...]  # teacher scores for the positive, then the negatives

outputs = model(
    query=query_inputs,
    passage=passage_inputs,
    messages=['normal'],
    teacher_scores=teacher_scores,
)
loss = outputs.loss  # contrastive loss + distillation loss

# Inference
model.eval()
with torch.no_grad():
    q_emb = model.encode(query_inputs)    # [batch, hidden_dim]
    p_emb = model.encode(passage_inputs)  # [num_passages, hidden_dim]
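Once embeddings are computed, retrieval reduces to a matrix product followed by top-k selection. A minimal sketch on stand-in tensors, where random vectors replace `q_emb` and `p_emb` and `rank_passages` is a hypothetical helper:

```python
import torch
import torch.nn.functional as F


def rank_passages(q_emb: torch.Tensor, p_emb: torch.Tensor, k: int = 3):
    """Return the top-k cosine scores and passage indices for each query."""
    q = F.normalize(q_emb, dim=-1)  # effectively a no-op if the model already normalized
    p = F.normalize(p_emb, dim=-1)
    scores = q @ p.T                # [num_queries, num_passages] cosine similarities
    return scores.topk(k=min(k, p.size(0)), dim=-1)


q_emb = torch.randn(2, 8)
p_emb = torch.randn(5, 8)
top_scores, top_idx = rank_passages(q_emb, p_emb, k=3)  # each of shape [2, 3]
```

For large corpora the same scoring is usually delegated to an approximate nearest-neighbor index, but the dense matrix product is the reference behavior.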
