Implementation:FlagOpen FlagEmbedding LLARA Finetune Modeling

From Leeroopedia
Revision as of 14:59, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/FlagOpen_FlagEmbedding_LLARA_Finetune_Modeling.md)


Knowledge Sources
Domains LLM Embedding, Bi-Encoder, Contrastive Learning
Last Updated 2026-02-09 00:00 GMT

Overview

Bi-encoder model architecture for LLARA fine-tuning with contrastive learning and cross-device negatives.

Description

This module implements the bi-encoder architecture used to fine-tune LLARA embeddings. Its core class, BiEncoderModel, extracts embeddings from the last 8 hidden states of an LLM. Training uses in-batch negative sampling, with cross-device gathering for distributed runs, and a contrastive loss computed as cross-entropy with temperature scaling. The module also supports normalized embeddings with inner product similarity, and optional sub-batch processing to handle memory constraints. The model extracts embeddings from the special token positions (<s1>-<s16>) and averages them to create the final representation.
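The pooling step described above can be sketched roughly as follows. This is a minimal illustration, not the module's code: it assumes the positions of interest are the last 8 tokens of each sequence and that `hidden` holds final-layer hidden states. The helper name `pool_last_k` is hypothetical.

```python
import torch
import torch.nn.functional as F


def pool_last_k(hidden: torch.Tensor, k: int = 8, normalize: bool = False) -> torch.Tensor:
    """Average the hidden states at the last k token positions.

    hidden: (batch, seq_len, hidden_dim) final-layer hidden states.
    Returns a (batch, hidden_dim) tensor.
    """
    # Assumption of this sketch: the special tokens occupy the last k positions.
    reps = hidden[:, -k:, :].mean(dim=1)
    if normalize:
        # L2-normalize so that inner product equals cosine similarity.
        reps = F.normalize(reps, dim=-1)
    return reps


hidden = torch.randn(4, 32, 16)  # toy stand-in for LLM hidden states
emb = pool_last_k(hidden, k=8, normalize=True)
print(emb.shape)  # torch.Size([4, 16])
```

With `normalize=True` every row has unit length, so the inner product between a query and a passage embedding is bounded in [-1, 1], which is why a small temperature is typically paired with normalized embeddings.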

Usage

Use this module to fine-tune LLM-based embedding models such as LLARA, to implement a bi-encoder architecture with contrastive learning, or to train with distributed cross-device negatives for better representation learning. The model is designed to work with the LLARA data loading infrastructure.
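To make the idea of cross-device negatives concrete, the sketch below gathers a tensor from every rank with `torch.distributed.all_gather` so that each device can score its queries against passages from all devices. It is an illustration under assumptions: the helper name `gather_across_devices` is hypothetical, and the fallback to the local tensor outside a distributed run is a choice made for this sketch so it can be executed on a single process.

```python
import torch
import torch.distributed as dist


def gather_across_devices(t: torch.Tensor) -> torch.Tensor:
    """Concatenate t from every rank; return t unchanged outside distributed runs."""
    if not (dist.is_available() and dist.is_initialized()):
        return t
    t = t.contiguous()
    gathered = [torch.empty_like(t) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, t)
    # The gathered copies carry no gradient history, so put the local tensor
    # back in its own slot to keep gradients flowing for this rank.
    gathered[dist.get_rank()] = t
    return torch.cat(gathered, dim=0)


local_passages = torch.randn(8, 16)
all_passages = gather_across_devices(local_passages)
print(all_passages.shape)  # equals the local shape when not running distributed
```

In a run with N ranks the first dimension grows by a factor of N, so each query sees N times as many negatives without increasing the per-device batch size.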

Code Reference

Source Location

Signature

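from dataclasses import dataclass
from typing import Optional

from torch import nn, Tensor
from transformers import AutoModel, AutoTokenizer
from transformers.utils import ModelOutput

@dataclass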
class EncoderOutput(ModelOutput):
    q_reps: Optional[Tensor] = None
    p_reps: Optional[Tensor] = None
    loss: Optional[Tensor] = None
    scores: Optional[Tensor] = None

class BiEncoderModel(nn.Module):
    def __init__(
        self,
        model: AutoModel = None,
        tokenizer: AutoTokenizer = None,
        normlized: bool = False,
        negatives_cross_device: bool = False,
        temperature: float = 1.0,
        sub_batch_size: int = -1
    ):
        pass

    def encode(self, features):
        """Encode text into embeddings using last 8 hidden states"""

    def forward(self, query=None, passage=None):
        """Forward pass with contrastive loss"""

Import

from modeling import BiEncoderModel, EncoderOutput

I/O Contract

Inputs

| Name | Type | Required | Description |
|------|------|----------|-------------|
| model | AutoModel | Yes | Pre-trained LLM backbone |
| tokenizer | AutoTokenizer | Yes | Tokenizer for the LLM |
| normlized | bool | No | Normalize embeddings (default: False) |
| negatives_cross_device | bool | No | Cross-device negative sampling (default: False) |
| temperature | float | No | Temperature for contrastive loss (default: 1.0) |
| sub_batch_size | int | No | Sub-batch size for memory efficiency (default: -1, which disables it) |
| query | Dict or List[Dict] | Yes | Query input features (defaults to None in the forward signature) |
| passage | Dict or List[Dict] | Yes | Passage input features (defaults to None in the forward signature) |
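The `sub_batch_size` option trades speed for lower peak memory by encoding a batch in smaller pieces. The sketch below shows one way such chunking can work. It is not the module's implementation: `encode_in_sub_batches` and `toy_encode` are hypothetical names, and the way the real module splits its inputs is not shown in this page.

```python
from typing import Callable, Dict

import torch


def encode_in_sub_batches(
    features: Dict[str, torch.Tensor],
    encode_fn: Callable[[Dict[str, torch.Tensor]], torch.Tensor],
    sub_batch_size: int = -1,
) -> torch.Tensor:
    """Encode features in chunks of sub_batch_size rows; -1 encodes all rows at once."""
    if sub_batch_size is None or sub_batch_size <= 0:
        return encode_fn(features)
    n_rows = next(iter(features.values())).shape[0]
    chunks = []
    for start in range(0, n_rows, sub_batch_size):
        # Slice every input tensor along the batch dimension.
        sub = {k: v[start:start + sub_batch_size] for k, v in features.items()}
        chunks.append(encode_fn(sub))
    return torch.cat(chunks, dim=0)


def toy_encode(feats: Dict[str, torch.Tensor]) -> torch.Tensor:
    # Stand-in for a real encoder: mean token id as a 1-dimensional "embedding".
    return feats["input_ids"].float().mean(dim=1, keepdim=True)


feats = {
    "input_ids": torch.arange(40).reshape(10, 4),
    "attention_mask": torch.ones(10, 4, dtype=torch.long),
}
full = encode_in_sub_batches(feats, toy_encode)
chunked = encode_in_sub_batches(feats, toy_encode, sub_batch_size=3)
print(torch.allclose(full, chunked))  # True
```

Because each row is encoded independently, chunking changes only the memory profile and not the resulting embeddings.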

Outputs

| Name | Type | Description |
|------|------|-------------|
| q_reps | Tensor | Query embeddings (batch_size, hidden_dim) |
| p_reps | Tensor | Passage embeddings (batch_size * group_size, hidden_dim) |
| loss | Tensor | Contrastive loss value |
| scores | Tensor | Similarity scores (batch_size, batch_size * group_size) |
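The shapes above suggest how `scores` and `loss` relate. The following is a minimal sketch of a temperature-scaled contrastive loss with in-batch negatives, not the module's own code. It assumes that passages are ordered so that query i's positive passage sits at row i * group_size. The function name `contrastive_loss` is hypothetical.

```python
import torch
import torch.nn.functional as F


def contrastive_loss(q_reps: torch.Tensor, p_reps: torch.Tensor, temperature: float = 1.0):
    """Cross-entropy over query-passage inner products with in-batch negatives."""
    group_size = p_reps.size(0) // q_reps.size(0)
    # Inner-product similarity, scaled by temperature: (batch, batch * group_size).
    scores = q_reps @ p_reps.T / temperature
    # Assumption: the positive for query i is the first passage of its group.
    target = torch.arange(q_reps.size(0), device=scores.device) * group_size
    return F.cross_entropy(scores, target), scores


torch.manual_seed(0)
q = F.normalize(torch.randn(4, 16), dim=-1)   # 4 queries
p = F.normalize(torch.randn(32, 16), dim=-1)  # 4 queries * 8 passages each
loss, scores = contrastive_loss(q, p, temperature=0.02)
print(scores.shape)  # torch.Size([4, 32])
print(loss.item())
```

Every passage that is not the query's own positive acts as a negative, including the passages that belong to other queries in the batch.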

Usage Examples

# Example 1: Initialize model
from transformers import AutoModel, AutoTokenizer
from modeling import BiEncoderModel

# Load base model
base_model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Create bi-encoder
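model = BiEncoderModel(
    model=base_model,
    tokenizer=tokenizer,
    normlized=True,  # parameter name is spelled this way in the signature
    negatives_cross_device=True,  # meant for distributed runs (torch.distributed initialized)
    temperature=0.02,
    sub_batch_size=-1  # -1 disables sub-batching
)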

# Example 2: Forward pass
import torch

# Prepare inputs (from dataloader)
query = {
    'input_ids': torch.randint(0, 32000, (4, 32)),
    'attention_mask': torch.ones(4, 32, dtype=torch.long)  # integer mask, as a tokenizer would produce
}

passage = {
    'input_ids': torch.randint(0, 32000, (32, 128)),  # 4 queries * 8 passages
    'attention_mask': torch.ones(32, 128, dtype=torch.long)
}

# Forward
output = model(query=query, passage=passage)
print(f"Loss: {output.loss.item()}")
print(f"Query embeddings: {output.q_reps.shape}")  # (4, hidden_dim)
print(f"Passage embeddings: {output.p_reps.shape}")  # (32, hidden_dim)
print(f"Scores: {output.scores.shape}")  # (4, 32)

# Example 3: Inference mode
model.eval()
with torch.no_grad():
    # Encode queries
    query_emb = model.encode(query)
    print(f"Query embeddings: {query_emb.shape}")

    # Encode passages
    passage_emb = model.encode(passage)
    print(f"Passage embeddings: {passage_emb.shape}")

    # Compute similarities
    similarities = model.compute_similarity(query_emb, passage_emb)
    print(f"Similarities: {similarities.shape}")

# Example 4: Save model
model.save("./llara_finetuned")
