Implementation:FlagOpen FlagEmbedding LLM Dense Retriever Modeling
| Knowledge Sources | |
|---|---|
| Domains | Machine_Learning, Large_Language_Models, Information_Retrieval |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Bi-encoder model for training large language models as dense retrievers with advanced contrastive learning features.
Description
BiEncoderModel for LLM-based dense retrieval extends the standard bi-encoder with features tailored for large language models:
Architecture: Uses LLMs (e.g., LLaMA, GPT) as encoders and takes the embedding from the hidden state at the last token position rather than from a [CLS] token; under causal attention, only the final position attends to the whole input. Supports sub-batch processing to handle memory constraints when encoding large passages through LLMs. Optional L2 normalization makes the dot product between embeddings equal to cosine similarity.
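The last-token pooling and optional L2 normalization described above can be sketched without any framework. This is an illustration on nested Python lists, not the repository's code: the helper names (`last_token_index`, `last_token_pool`, `l2_normalize`) are hypothetical, and the real model operates on tensors.

```python
import math

def last_token_index(attention_mask):
    """Index of the last real (non-padding) token in one sequence.

    Handles right-padded masks such as [1, 1, 0] and left-padded
    masks such as [0, 1, 1].
    """
    for pos in range(len(attention_mask) - 1, -1, -1):
        if attention_mask[pos] == 1:
            return pos
    raise ValueError("attention mask contains no real tokens")

def last_token_pool(hidden_states, attention_mask):
    """Select the hidden vector at the last real token of each sequence.

    hidden_states:  [batch][seq_len][hidden_dim] nested lists
    attention_mask: [batch][seq_len] nested lists of 0/1
    """
    return [seq[last_token_index(mask)]
            for seq, mask in zip(hidden_states, attention_mask)]

def l2_normalize(vec, eps=1e-12):
    """Scale a vector to unit length so dot products become cosine similarities."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / max(norm, eps) for x in vec]

# Toy batch: first sequence is right-padded, second is left-padded.
hidden = [
    [[1.0, 0.0], [3.0, 4.0], [9.0, 9.0]],
    [[0.0, 0.0], [0.0, 2.0], [6.0, 8.0]],
]
mask = [[1, 1, 0], [0, 1, 1]]

pooled = last_token_pool(hidden, mask)
embeddings = [l2_normalize(v) for v in pooled]
```

The padded position `[9.0, 9.0]` in the first sequence is ignored because its mask entry is 0, which is why the mask, not the sequence length, decides which position is pooled.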
Training modes:
- Normal mode: Standard in-batch negatives with cross-device negative sharing for distributed training
- Custom mode: Computes local similarity only between query and its associated passages (no in-batch negatives)
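The difference between the two modes comes down to which passages each query is scored against. The sketch below is a plain-Python illustration under stated assumptions: passages are laid out as contiguous groups of `group_size` per query with the positive first, and the names `contrastive_loss` and `group_size` are hypothetical rather than taken from the repository.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def cross_entropy(logits, target):
    """-log softmax(logits)[target], computed with the max-shift trick."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_z - logits[target]

def contrastive_loss(q_reps, p_reps, group_size, temperature=1.0, mode="normal"):
    """Mean cross-entropy over queries.

    Assumed layout: p_reps[i * group_size] is the positive for query i,
    followed by that query's own negatives.
    """
    losses = []
    for i, q in enumerate(q_reps):
        if mode == "normal":
            # In-batch negatives: every passage in the batch is a candidate.
            candidates = p_reps
            target = i * group_size
        else:
            # Custom mode: only this query's own group is scored.
            candidates = p_reps[i * group_size:(i + 1) * group_size]
            target = 0
        logits = [dot(q, p) / temperature for p in candidates]
        losses.append(cross_entropy(logits, target))
    return sum(losses) / len(losses)

q_reps = [[1.0, 0.0], [0.0, 1.0]]
p_reps = [[1.0, 0.0], [0.0, 0.0],   # group for query 0 (positive first)
          [0.0, 1.0], [0.0, 0.0]]   # group for query 1 (positive first)

normal_loss = contrastive_loss(q_reps, p_reps, group_size=2, mode="normal")
custom_loss = contrastive_loss(q_reps, p_reps, group_size=2, mode="custom")
```

In this toy batch the normal-mode loss is larger because each query competes against four candidates instead of two. With cross-device sharing, the candidate list in normal mode would additionally include passages gathered from other devices.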
Knowledge distillation: Supports distillation from teacher model scores. When teacher scores are provided, the model learns to match the teacher's ranking distribution over positives and negatives using KL divergence, in addition to the standard contrastive loss.
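As a rough illustration of the distillation term, the sketch below computes KL(teacher || student) over one query's candidate list in plain Python. It is not the repository's implementation: the function names are hypothetical, and applying the temperature only to the student scores is an assumption made for this example.

```python
import math

def softmax(scores, temperature=1.0):
    """Numerically stable softmax over a list of scores."""
    scaled = [s / temperature for s in scores]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_distill_loss(teacher_scores, student_scores, temperature=1.0):
    """KL(teacher || student) for one query's positive + negatives.

    Assumption: temperature scales the student scores only.
    """
    p = softmax(teacher_scores)
    q = softmax(student_scores, temperature)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

teacher = [0.95, 0.3, 0.2, 0.1]          # positive first, then negatives
student_same = [0.95, 0.3, 0.2, 0.1]     # student already matches the teacher
student_reversed = [0.1, 0.2, 0.3, 0.95]  # student ranks the positive last

loss_same = kl_distill_loss(teacher, student_same)
loss_reversed = kl_distill_loss(teacher, student_reversed)
```

The loss is zero when the student's distribution matches the teacher's and grows as the rankings diverge. How this term is weighted against the contrastive loss is not stated in this page.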
Memory optimization: Sub-batch encoding processes passages in smaller chunks to avoid out-of-memory (OOM) errors with large models, then concatenates the chunk outputs so that gradients still flow through every chunk.
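The chunk-then-concatenate pattern can be sketched as follows. This is a framework-free illustration: `encode_in_sub_batches` and `toy_encoder` are hypothetical names, and plain lists stand in for tensors. With PyTorch tensors, the per-chunk outputs would typically be joined with `torch.cat`, which keeps the result differentiable.

```python
def encode_in_sub_batches(encode_fn, items, sub_batch_size=-1):
    """Encode items in chunks of sub_batch_size and join the outputs in order.

    A non-positive sub_batch_size (such as the default -1) disables chunking.
    """
    if sub_batch_size is None or sub_batch_size <= 0 or sub_batch_size >= len(items):
        return encode_fn(items)
    outputs = []
    for start in range(0, len(items), sub_batch_size):
        chunk = items[start:start + sub_batch_size]
        outputs.extend(encode_fn(chunk))
    return outputs

# Toy encoder that records how many items it sees per call.
chunk_sizes = []

def toy_encoder(batch):
    chunk_sizes.append(len(batch))
    return [x * 2 for x in batch]

encoded = encode_in_sub_batches(toy_encoder, list(range(10)), sub_batch_size=4)
```

Peak memory is bounded by the largest chunk rather than the full batch, at the cost of more forward passes.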
Usage
Use this for fine-tuning large language models as embedding models for retrieval tasks, especially when dealing with memory constraints or when distilling from stronger teacher models.
Code Reference
Source Location
- Repository: FlagOpen_FlagEmbedding
- File: research/llm_dense_retriever/finetune/modeling.py
- Lines: 1-203
Signature
class BiEncoderModel(nn.Module):
    def __init__(self, model: AutoModel = None, tokenizer: AutoTokenizer = None,
                 normlized: bool = False, negatives_cross_device: bool = False,
                 temperature: float = 1.0, sub_batch_size: int = -1)
    def encode(self, features)
    def forward(self, query, passage, messages, teacher_scores)
Import
from research.llm_dense_retriever.finetune.modeling import BiEncoderModel
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | AutoModel | Yes | Pre-trained LLM (LLaMA, GPT, etc.) |
| tokenizer | AutoTokenizer | Yes | Tokenizer matching the model |
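| normlized | bool | No | Whether to L2-normalize embeddings (the parameter is spelled `normlized` in the signature) |
| negatives_cross_device | bool | No | Whether to share negatives across devices in distributed training |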
| temperature | float | No | Temperature for similarity scaling |
| sub_batch_size | int | No | Sub-batch size for memory-efficient encoding (-1 disables) |
| query | Dict/List[Dict] | Yes | Query inputs (single dict or list for multiple) |
| passage | Dict/List[Dict] | Yes | Passage inputs |
| messages | List | Yes | ['normal'] or ['custom'] to select training mode |
| teacher_scores | List | No | Teacher model scores for distillation |
Outputs
| Name | Type | Description |
|---|---|---|
| loss | Tensor | Contrastive loss with optional distillation loss |
| scores | Tensor | Similarity scores between queries and passages |
| q_reps | Tensor | Query embeddings from last token |
| p_reps | Tensor | Passage embeddings from last token |
Usage Examples
import torch
from transformers import AutoModel, AutoTokenizer
from research.llm_dense_retriever.finetune.modeling import BiEncoderModel

# Initialize model
base_model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = BiEncoderModel(
    model=base_model,
    tokenizer=tokenizer,
    normlized=True,               # spelled as in the class signature
    negatives_cross_device=True,  # share negatives across devices (distributed training)
    temperature=0.02,
    sub_batch_size=8,             # Process 8 passages at a time to save memory
)

# Training with distillation
# query_ids, query_mask, passage_ids, passage_mask are pre-tokenized tensors (not shown)
query_inputs = {"input_ids": query_ids, "attention_mask": query_mask}
passage_inputs = {"input_ids": passage_ids, "attention_mask": passage_mask}
teacher_scores = [0.95, 0.3, 0.2, 0.1, ...]  # Scores for pos + negs
outputs = model(
    query=query_inputs,
    passage=passage_inputs,
    messages=['normal'],
    teacher_scores=teacher_scores,
)
loss = outputs.loss  # Contrastive + distillation loss

# Inference
model.eval()
with torch.no_grad():
    q_emb = model.encode(query_inputs)  # [batch, hidden_dim]
    p_emb = model.encode(passage_inputs)
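Once query and passage embeddings are available, retrieval reduces to ranking passages by similarity. The sketch below shows that step on plain Python lists; `rank_passages` is a hypothetical helper, not part of the repository, and it assumes the embeddings were L2-normalized so that the dot product equals cosine similarity.

```python
def rank_passages(q_emb, p_embs, top_k=3):
    """Return (passage_index, score) pairs sorted by descending dot product."""
    scores = [sum(a * b for a, b in zip(q_emb, p)) for p in p_embs]
    order = sorted(range(len(p_embs)), key=lambda i: scores[i], reverse=True)
    return [(i, scores[i]) for i in order[:top_k]]

# Unit-length toy embeddings: one query, three passages.
query_embedding = [0.6, 0.8]
passage_embeddings = [
    [1.0, 0.0],
    [0.6, 0.8],
    [0.0, 1.0],
]

ranked = rank_passages(query_embedding, passage_embeddings)
ranked_ids = [i for i, _ in ranked]
```

Here the passage identical to the query ranks first, followed by the one closer to it in direction.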