Implementation:FlagOpen FlagEmbedding LLARA Finetune Modeling
| Knowledge Sources | |
|---|---|
| Domains | LLM Embedding, Bi-Encoder, Contrastive Learning |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Bi-encoder model architecture for LLARA fine-tuning with contrastive learning and cross-device negatives.
Description
This module implements the bi-encoder architecture used to fine-tune LLARA embeddings. It provides:
- BiEncoderModel, which extracts embeddings from the last 8 hidden states of an LLM
- In-batch negative sampling, with cross-device gathering for distributed training
- A contrastive loss: cross-entropy over temperature-scaled similarity scores
- Support for normalized embeddings and inner-product similarity
- Optional sub-batch processing to handle memory constraints

The model reads the hidden states at the special-token positions (<s1>-<s16>) and averages them to form the final representation.
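The pooling step described above can be sketched in a few lines. This is a minimal illustration, not the module's code: it assumes that "last 8 hidden states" means the final-layer hidden states at the last 8 token positions, and that the inputs are arranged so the special tokens occupy those positions. The helper name `pool_last_positions` and the random tensor standing in for the LLM output are hypothetical.

```python
import torch
import torch.nn.functional as F


def pool_last_positions(last_hidden_state: torch.Tensor,
                        num_positions: int = 8,
                        normalize: bool = False) -> torch.Tensor:
    """Average the hidden states at the final `num_positions` token positions.

    Assumption: the special tokens sit at the end of every sequence,
    so slicing from the right selects exactly their hidden states.
    """
    # (batch, seq_len, hidden) -> (batch, num_positions, hidden)
    tail = last_hidden_state[:, -num_positions:, :]
    # Mean over the selected positions -> (batch, hidden)
    reps = tail.mean(dim=1)
    if normalize:
        # Unit L2 norm, so inner product equals cosine similarity
        reps = F.normalize(reps, dim=-1)
    return reps.contiguous()


torch.manual_seed(0)
hidden = torch.randn(4, 32, 16)  # stand-in for the LLM's final hidden states
raw_reps = pool_last_positions(hidden)
unit_reps = pool_last_positions(hidden, normalize=True)
print(raw_reps.shape, unit_reps.norm(dim=-1))
```

With normalization enabled, the inner-product similarity used by the loss reduces to cosine similarity, which is why a small temperature is typically paired with it.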
Usage
Use this module to fine-tune LLM-based embedding models such as LLARA, to implement a bi-encoder architecture with contrastive learning, or to train with distributed cross-device negatives for better representation learning. The model is designed to work with the LLARA data loading infrastructure.
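Cross-device negatives are usually obtained by gathering every rank's embeddings before scoring. The sketch below shows one common pattern and is an assumption about how such a gather can be written, not a copy of this module's code; the function name `gather_with_local_grad` is hypothetical. Outside a distributed run it simply returns its input, so it can be tried in a single process.

```python
import torch
import torch.distributed as dist


def gather_with_local_grad(t: torch.Tensor) -> torch.Tensor:
    """Concatenate `t` from all ranks while keeping gradients for the local slice."""
    # Single-process fallback: nothing to gather
    if not (dist.is_available() and dist.is_initialized()):
        return t

    t = t.contiguous()
    gathered = [torch.empty_like(t) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, t)
    # Tensors filled by all_gather carry no autograd history, so put the
    # original local tensor back in its slot to keep its gradient path.
    gathered[dist.get_rank()] = t
    return torch.cat(gathered, dim=0)


local_reps = torch.randn(4, 16, requires_grad=True)
all_reps = gather_with_local_grad(local_reps)
print(all_reps.shape)
```

With N ranks, each query is then scored against N times as many passages, which enlarges the pool of in-batch negatives without increasing the per-device batch size.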
Code Reference
Source Location
- Repository: FlagOpen_FlagEmbedding
- File: research/LLARA/finetune/modeling.py
- Lines: 1-165
Signature
@dataclass
class EncoderOutput(ModelOutput):
    q_reps: Optional[Tensor] = None
    p_reps: Optional[Tensor] = None
    loss: Optional[Tensor] = None
    scores: Optional[Tensor] = None

class BiEncoderModel(nn.Module):
    def __init__(
        self,
        model: AutoModel = None,
        tokenizer: AutoTokenizer = None,
        normlized: bool = False,
        negatives_cross_device: bool = False,
        temperature: float = 1.0,
        sub_batch_size: int = -1
    ):
        pass

    def encode(self, features):
        """Encode text into embeddings using last 8 hidden states"""

    def forward(self, query=None, passage=None):
        """Forward pass with contrastive loss"""
Import
from modeling import BiEncoderModel, EncoderOutput
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | AutoModel | Yes | Pre-trained LLM backbone |
| tokenizer | AutoTokenizer | Yes | Tokenizer for the LLM |
| normlized | bool | No | Normalize embeddings (default: False); the name is spelled this way in the signature |
| negatives_cross_device | bool | No | Cross-device negative sampling (default: False) |
| temperature | float | No | Temperature for contrastive loss (default: 1.0) |
| sub_batch_size | int | No | Sub-batch size for memory efficiency (default: -1, disabled) |
| query | Dict or List[Dict] | Yes | Query input features |
| passage | Dict or List[Dict] | Yes | Passage input features |
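To illustrate what sub_batch_size is for, here is a minimal sketch of chunked encoding. It is an assumption about the general technique, not the module's implementation: the table above suggests the real model may instead receive inputs already split into a list of dicts. The names `encode_in_sub_batches` and `toy_encode` are hypothetical, and `toy_encode` is only a stand-in for a real encoder.

```python
from typing import Callable, Dict

import torch

Features = Dict[str, torch.Tensor]


def encode_in_sub_batches(encode_fn: Callable[[Features], torch.Tensor],
                          features: Features,
                          sub_batch_size: int = -1) -> torch.Tensor:
    """Encode `features` in chunks of `sub_batch_size` rows to bound peak memory."""
    # A non-positive size disables chunking, mirroring the -1 default
    if sub_batch_size is None or sub_batch_size <= 0:
        return encode_fn(features)

    batch_size = features["input_ids"].size(0)
    chunks = []
    for start in range(0, batch_size, sub_batch_size):
        end = start + sub_batch_size
        # Slice every tensor in the feature dict along the batch dimension
        sub_features = {name: value[start:end] for name, value in features.items()}
        chunks.append(encode_fn(sub_features))
    return torch.cat(chunks, dim=0)


def toy_encode(features: Features) -> torch.Tensor:
    """Stand-in encoder: masked mean of the token ids, shape (batch, 1)."""
    mask = features["attention_mask"].unsqueeze(-1).float()
    ids = features["input_ids"].unsqueeze(-1).float()
    return (ids * mask).sum(dim=1) / mask.sum(dim=1)


torch.manual_seed(0)
features = {
    "input_ids": torch.randint(0, 100, (10, 6)),
    "attention_mask": torch.ones(10, 6, dtype=torch.long),
}
full = encode_in_sub_batches(toy_encode, features, sub_batch_size=-1)
chunked = encode_in_sub_batches(toy_encode, features, sub_batch_size=4)
print(full.shape, torch.allclose(full, chunked))
```

Chunking trades throughput for a lower activation footprint; the concatenated result has the same shape as a single full-batch pass.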
Outputs
| Name | Type | Description |
|---|---|---|
| q_reps | Tensor | Query embeddings (batch_size, hidden_dim) |
| p_reps | Tensor | Passage embeddings (batch_size * group_size, hidden_dim) |
| loss | Tensor | Contrastive loss value |
| scores | Tensor | Similarity scores (batch_size, batch_size * group_size) |
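The scores shape above follows from scoring every query against every passage in the batch. The sketch below shows how a cross-entropy contrastive loss can be computed from such a matrix. It rests on one assumption not stated in the tables: within each group of group_size passages, the positive comes first, so query i targets column i * group_size. The function name `contrastive_loss` is hypothetical.

```python
import torch
import torch.nn.functional as F


def contrastive_loss(q_reps: torch.Tensor,
                     p_reps: torch.Tensor,
                     temperature: float = 1.0):
    """Cross-entropy over temperature-scaled inner-product scores."""
    group_size = p_reps.size(0) // q_reps.size(0)
    # (batch, hidden) @ (hidden, batch * group_size) -> (batch, batch * group_size)
    scores = torch.matmul(q_reps, p_reps.transpose(0, 1)) / temperature
    # Assumed layout: the positive passage is the first entry of each group
    target = torch.arange(q_reps.size(0), device=scores.device) * group_size
    loss = F.cross_entropy(scores, target)
    return loss, scores, target


torch.manual_seed(0)
q = F.normalize(torch.randn(4, 16), dim=-1)   # 4 queries
p = F.normalize(torch.randn(32, 16), dim=-1)  # 4 queries * 8 passages each
loss, scores, target = contrastive_loss(q, p, temperature=0.02)
print(scores.shape, target.tolist())
```

Every passage that is not the target for a given query, including the other queries' positives, acts as a negative in the softmax.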
Usage Examples
# Example 1: Initialize model
from transformers import AutoModel, AutoTokenizer
from modeling import BiEncoderModel
# Load base model
base_model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# Create bi-encoder
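model = BiEncoderModel(
    model=base_model,
    tokenizer=tokenizer,
    normlized=True,
    # Cross-device negatives presume an initialized torch.distributed
    # process group; set to False for a single-process run
    negatives_cross_device=True,
    temperature=0.02,
    sub_batch_size=-1
)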
# Example 2: Forward pass
import torch
# Prepare inputs (from dataloader)
query = {
    'input_ids': torch.randint(0, 32000, (4, 32)),
    'attention_mask': torch.ones(4, 32)
}
passage = {
    'input_ids': torch.randint(0, 32000, (32, 128)),  # 4 queries * 8 passages
    'attention_mask': torch.ones(32, 128)
}
# Forward
output = model(query=query, passage=passage)
print(f"Loss: {output.loss.item()}")
print(f"Query embeddings: {output.q_reps.shape}") # (4, hidden_dim)
print(f"Passage embeddings: {output.p_reps.shape}") # (32, hidden_dim)
print(f"Scores: {output.scores.shape}") # (4, 32)
# Example 3: Inference mode
model.eval()
with torch.no_grad():
    # Encode queries
    query_emb = model.encode(query)
    print(f"Query embeddings: {query_emb.shape}")
    # Encode passages
    passage_emb = model.encode(passage)
    print(f"Passage embeddings: {passage_emb.shape}")
    # Compute similarities
    similarities = model.compute_similarity(query_emb, passage_emb)
    print(f"Similarities: {similarities.shape}")
# Example 4: Save model
model.save("./llara_finetuned")