

Implementation:FlagOpen FlagEmbedding BGE M3 Modeling

From Leeroopedia
Revision as of 14:58, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/FlagOpen_FlagEmbedding_BGE_M3_Modeling.md)


Knowledge Sources
Domains: Machine Learning, Multi-Modal Retrieval, Neural Networks, Information Retrieval
Last Updated: 2026-02-09 00:00 GMT

Overview

A unified multi-functionality embedding model that produces dense, sparse, and multi-vector (ColBERT) representations from a single backbone, so one model can serve semantic, lexical, and late-interaction retrieval.

Description

BGEM3Model implements the BGE-M3 (Multi-Functionality, Multi-Linguality, Multi-Granularity) model architecture, which combines three complementary retrieval methods in one model: dense embeddings (one semantic vector per text), sparse embeddings (learned per-token term weights, used for BM25-style lexical matching), and ColBERT-style multi-vector representations (token-level matching). The three outputs can be scored separately or summed into an ensemble score, which lets the same model cover different retrieval scenarios and languages.

The model supports both training and inference, each with its own forward path. During training it combines several losses: a standard contrastive loss, knowledge distillation from teacher scores, self-distillation from the ensemble score, and optional cross-device negatives (negatives_cross_device) for a larger effective batch size. The architecture includes linear projections for the ColBERT and sparse representations, vocabulary-aware sparse embedding generation, and configurable normalization and temperature scaling.

Key features include unified fine-tuning of all three representation types simultaneously, an inference-time optimization for sparse embeddings with unused-token filtering, sub-batch processing for memory efficiency with long documents, and support for both CPU and distributed GPU training with gradient checkpointing.

Usage

Use this model for training multi-functional retrieval systems or for inference to generate dense, sparse, and ColBERT embeddings for queries and documents in multilingual retrieval tasks.

Code Reference

Source Location

Signature

class BGEM3Model(nn.Module):
    def __init__(
        self,
        model_name: str = None,
        normlized: bool = True,
        sentence_pooling_method: str = 'cls',
        negatives_cross_device: bool = False,
        temperature: float = 1.0,
        enable_sub_batch: bool = True,
        unified_finetuning: bool = True,
        use_self_distill: bool = False,
        colbert_dim: int = -1,
        self_distill_start_step: int = -1,
    )

    def forward(
        self,
        query: Dict[str, Tensor] = None,
        passage: Dict[str, Tensor] = None,
        teacher_scores: Tensor = None,
        bi_directions: bool = None
    ) -> EncoderOutput

    def encode(
        self,
        features,
        sub_batch_size=None
    ) -> Tuple[Tensor, Tensor, Tensor]

    def dense_embedding(self, hidden_state, mask) -> Tensor
    def sparse_embedding(self, hidden_state, input_ids, return_embedding: bool = True) -> Tensor
    def colbert_embedding(self, last_hidden_state, mask) -> Tensor

class BGEM3ForInference(BGEM3Model):
    def forward(
        self,
        text_input: Dict[str, Tensor] = None,
        return_dense: bool = True,
        return_sparse: bool = False,
        return_colbert: bool = False,
        return_sparse_embedding: bool = False
    ) -> dict

Import

from modeling import BGEM3Model, BGEM3ForInference

I/O Contract

Training Inputs

| Name | Type | Required | Description |
|------|------|----------|-------------|
| query | Dict[str, Tensor] | Yes | Query input_ids and attention_mask |
| passage | Dict[str, Tensor] | Yes | Passage input_ids and attention_mask |
| teacher_scores | Optional[Tensor] | No | Teacher model scores for distillation (batch_size, group_size) |
| bi_directions | Optional[bool] | No | Use bidirectional loss for parallel corpora |

Training Outputs

| Name | Type | Description |
|------|------|-------------|
| loss | Tensor | Combined training loss (contrastive + distillation + ensemble) |
| q_reps | Tensor | Query representations (not used during training) |
| p_reps | Tensor | Passage representations (not used during training) |

Inference Inputs

| Name | Type | Required | Description |
|------|------|----------|-------------|
| text_input | Dict[str, Tensor] | Yes | input_ids and attention_mask |
| return_dense | bool | No | Return dense embeddings (default: True) |
| return_sparse | bool | No | Return sparse embeddings (default: False) |
| return_colbert | bool | No | Return ColBERT embeddings (default: False) |
| return_sparse_embedding | bool | No | Return full vocabulary-sized sparse vectors instead of per-token weights only (default: False) |

Inference Outputs

| Name | Type | Description |
|------|------|-------------|
| dense_vecs | Optional[Tensor] | Dense embeddings (batch_size, hidden_size) |
| sparse_vecs | Optional[Tensor] | Sparse embeddings (batch_size, vocab_size) or per-token weights |
| colbert_vecs | Optional[Tensor] | ColBERT embeddings (batch_size, seq_len-1, colbert_dim) |

Architecture Components

Base Model

  • Uses AutoModel as the backbone (XLM-RoBERTa in the case of BAAI/bge-m3; other BERT-style encoders also work)
  • Loads pretrained weights from model_name
  • Extracts last_hidden_state for all representations

Dense Embedding

Pooling Methods:

  • CLS: Uses [CLS] token representation (hidden_state[:, 0])
  • Mean: Average of all token embeddings (masked)

Properties:

  • Output dimension: hidden_size (e.g., 1024 for bge-m3)
  • Normalized if normlized=True
  • Temperature scaling applied to similarity scores
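
The two pooling methods above can be sketched as follows. This is a minimal illustration rather than the model's own code: the helper name dense_pool and the toy tensors are assumptions, and L2 normalization is folded into the helper for brevity.

```python
import torch
import torch.nn.functional as F


def dense_pool(hidden_state: torch.Tensor, mask: torch.Tensor,
               method: str = "cls", normalize: bool = True) -> torch.Tensor:
    """Pool token states (batch, seq_len, hidden) into one vector per sequence."""
    if method == "cls":
        pooled = hidden_state[:, 0]                      # first token only
    elif method == "mean":
        m = mask.unsqueeze(-1).to(hidden_state.dtype)    # (batch, seq_len, 1)
        pooled = (hidden_state * m).sum(dim=1) / m.sum(dim=1).clamp(min=1)
    else:
        raise ValueError(f"unknown pooling method: {method}")
    return F.normalize(pooled, dim=-1) if normalize else pooled


# Toy input: one sequence of three tokens, the last one is padding.
hidden = torch.tensor([[[1.0, 0.0], [3.0, 4.0], [9.0, 9.0]]])
mask = torch.tensor([[1, 1, 0]])

cls_vec = dense_pool(hidden, mask, "cls", normalize=False)    # first token
mean_vec = dense_pool(hidden, mask, "mean", normalize=False)  # padding ignored
unit_vec = dense_pool(hidden, mask, "mean", normalize=True)   # L2 norm of 1
```

Masking before averaging matters for batches with padding: without it, the padded position would pull the mean toward its (arbitrary) hidden state.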

Sparse Embedding

Architecture:

  • Linear layer: hidden_size → 1
  • ReLU activation for non-negative weights
  • Max pooling over token positions
  • Scatter to vocabulary dimension

Features:

  • Learned term importance weights
  • Similar to BM25 but learned end-to-end
  • Unused tokens (CLS, EOS, PAD, UNK) set to 0
  • Memory-efficient computation during inference

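Optimization (Issue #1364): During inference (self.training=False), the model uses scatter_reduce to write per-token maxima directly into a (batch, vocab_size) tensor, which avoids the intermediate (batch, seq_len, vocab_size) tensor built in training mode: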

# Training: (batch, seq_len, vocab_size) → max over seq_len
# Inference: (batch, vocab_size) with scatter_reduce
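
The equivalence of the two paths can be checked on a toy example. The sketch below is an illustration under stated assumptions: the vocabulary size, token ids, weights, and the choice of ids 0 and 2 as stand-ins for special tokens are all made up, and the weights are assumed to be non-negative (post-ReLU), which is what makes a zero-initialized "amax" reduction safe.

```python
import torch

vocab_size = 8                                    # toy vocabulary (assumption)
input_ids = torch.tensor([[0, 5, 3, 5, 2]])       # token 5 appears twice
# Post-ReLU term weights, shape (batch, seq_len, 1)
token_weights = torch.tensor([[[0.9], [0.2], [0.7], [0.6], [0.1]]])

# Training-style path: build (batch, seq_len, vocab_size), then max over seq_len.
dense_path = torch.zeros(1, input_ids.size(1), vocab_size)
dense_path = dense_path.scatter(-1, input_ids.unsqueeze(-1), token_weights)
dense_path = dense_path.max(dim=1).values

# Inference-style path: reduce straight into (batch, vocab_size).
reduced_path = torch.zeros(1, vocab_size).scatter_reduce(
    -1, input_ids, token_weights.squeeze(-1), reduce="amax"
)

# Zero out special tokens; ids 0 and 2 stand in for CLS and EOS (assumption).
unused_tokens = [0, 2]
dense_path[:, unused_tokens] = 0.0
reduced_path[:, unused_tokens] = 0.0
```

Both paths keep the largest weight for a repeated token (0.6 for token 5 here); the second never allocates the seq_len × vocab_size intermediate.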

ColBERT Embedding

Architecture:

  • Linear layer: hidden_size → colbert_dim
  • Processes token positions 1 through seq_len-1 (skips position 0, the [CLS] token)
  • Masked based on attention_mask

Properties:

  • Default colbert_dim: hidden_size (or custom if specified)
  • Normalized if normlized=True
  • Enables late interaction (token-level matching)
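
A rough sketch of the ColBERT projection described above: project every token except the first, zero out padded positions, then normalize. The layer sizes, random inputs, and variable names are assumptions chosen for illustration, not values from the model.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden_size, colbert_dim = 8, 4                   # toy sizes (assumption)
colbert_linear = torch.nn.Linear(hidden_size, colbert_dim)

last_hidden_state = torch.randn(2, 5, hidden_size)   # (batch, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]])     # second sequence is padded

with torch.no_grad():
    # Drop position 0 ([CLS]) and project the remaining tokens.
    colbert_vecs = colbert_linear(last_hidden_state[:, 1:])
    # Zero the vectors of padded tokens.
    colbert_vecs = colbert_vecs * attention_mask[:, 1:, None].float()
    # L2-normalize each token vector; all-zero rows stay zero.
    colbert_vecs = F.normalize(colbert_vecs, dim=-1)
```

The result has shape (batch, seq_len-1, colbert_dim), matching the colbert_vecs entry in the inference outputs.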

Projection Layers

  • colbert_linear: For ColBERT token embeddings
  • sparse_linear: For sparse term weights
  • Both initialized randomly or loaded from checkpoint

Training Methodology

Unified Fine-tuning

When unified_finetuning=True:

  1. Train all three representations simultaneously
  2. Ensemble score: Dense + 0.3 * Sparse + ColBERT
  3. Each representation has its own loss
  4. Ensemble loss combines all three

Loss computation:

loss = (dense_loss + ensemble_loss + 0.1 * sparse_loss + colbert_loss) / 4

Contrastive Learning

Standard Mode (no teacher scores):

  • In-batch negatives
  • Cross-entropy loss with targets = [0, group_size, 2*group_size, ...]
  • Temperature scaling for similarity scores

Knowledge Distillation Mode (with teacher scores):

  • Soft labels from teacher model
  • Distillation loss: -Σ teacher_prob * log(student_prob)
  • Applied to each representation separately
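
The two modes can be sketched on random scores as below. The shapes follow the description above (each query's positive passage sits at index i * group_size); applying a softmax to the teacher scores to obtain teacher_prob, and scoring each query only against its own group in distillation mode, are assumptions made for this illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, group_size = 2, 3

# Standard mode: every query is scored against all passages in the batch.
scores = torch.randn(batch_size, batch_size * group_size)
targets = torch.arange(batch_size) * group_size   # index of each query's positive
contrastive_loss = F.cross_entropy(scores, targets)

# Distillation mode: each query is scored against its own group of passages.
student_scores = torch.randn(batch_size, group_size)
teacher_scores = torch.tensor([[0.95, 0.30, 0.20],
                               [0.90, 0.40, 0.25]])
teacher_prob = F.softmax(teacher_scores, dim=-1)            # soft labels (assumption)
student_log_prob = F.log_softmax(student_scores, dim=-1)
distill_loss = -(teacher_prob * student_log_prob).sum(dim=-1).mean()
```

With in-batch negatives, the other queries' passages act as extra negatives for free, which is why the target index has to skip ahead by group_size per query.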

Self-Distillation

When use_self_distill=True and step > self_distill_start_step:

  • Use ensemble scores as pseudo-labels
  • Distill ensemble knowledge to individual representations
  • Additional loss term: (dense_distill + 0.1 * sparse_distill + colbert_distill) / 3
  • Total loss scaled by 0.5
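
Putting the weights from this section together, the loss arithmetic can be sketched with plain floats. The function name combine_losses is hypothetical, and the order of operations (add the self-distillation term, then halve the total) is an interpretation of the bullets above rather than a statement about the source code.

```python
def combine_losses(dense, sparse, colbert, ensemble, self_distill=None):
    """Combine per-representation losses for unified fine-tuning.

    self_distill: optional (dense, sparse, colbert) self-distillation losses.
    """
    # Unified fine-tuning: average of four terms, sparse down-weighted to 0.1.
    loss = (dense + ensemble + 0.1 * sparse + colbert) / 4
    if self_distill is not None:
        d, s, c = self_distill
        # Add the averaged self-distillation term, then scale the total by 0.5.
        loss = (loss + (d + 0.1 * s + c) / 3) * 0.5
    return loss


base = combine_losses(dense=1.0, sparse=2.0, colbert=1.0, ensemble=0.8)
with_distill = combine_losses(1.0, 2.0, 1.0, 0.8, self_distill=(0.6, 1.0, 0.5))
```

For these inputs the base loss is (1.0 + 0.8 + 0.2 + 1.0) / 4 = 0.75, and with self-distillation it becomes (0.75 + 0.4) * 0.5 = 0.575.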

Negatives-Cross-Device

When negatives_cross_device=True:

  • Gather representations from all GPUs
  • Compute loss across all devices
  • Larger effective batch size
  • Better negative sampling

Scoring Functions

Dense Score

scores = (q_reps @ p_reps.T) / temperature
# Dot product similarity; equals cosine similarity when normlized=True (embeddings are L2-normalized)

Sparse Score

scores = (q_sparse @ p_sparse.T) / temperature
# Vocabulary overlap with learned term weights

ColBERT Score

# One query against one passage; q_colbert: (q_tokens, dim), p_colbert: (p_tokens, dim)
token_scores = q_colbert @ p_colbert.T         # (q_tokens, p_tokens)
max_scores = token_scores.max(dim=-1).values   # Max over passage tokens
scores = max_scores.sum() / q_length / temperature
# MaxSim operation: for each query token, find best passage token

Ensemble Score

ensemble = dense_scores + 0.3 * sparse_scores + colbert_scores
# Weighted combination, sparse down-weighted
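
The four scoring functions above can be exercised together in a batched form. This is a sketch under assumptions: the function names, the toy vectors, and the batched einsum layout (queries × query tokens × passages × passage tokens) are illustrative choices, and padded token vectors are assumed to be zeroed beforehand.

```python
import torch


def dot_score(q, p, temperature=1.0):
    """Dense or sparse score: (Q, dim) x (P, dim) -> (Q, P)."""
    return (q @ p.T) / temperature


def colbert_score(q_reps, p_reps, q_mask, temperature=1.0):
    """MaxSim: q_reps (Q, q_len, dim), p_reps (P, p_len, dim), q_mask (Q, q_len)."""
    token_scores = torch.einsum("qin,pjn->qipj", q_reps, p_reps)
    max_scores = token_scores.max(dim=-1).values            # (Q, q_len, P)
    scores = max_scores.sum(dim=1) / q_mask.sum(dim=-1, keepdim=True)
    return scores / temperature


def ensemble_score(dense, sparse, colbert):
    return dense + 0.3 * sparse + colbert


# One query, two passages (toy values).
q_dense = torch.tensor([[1.0, 0.0]])
p_dense = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
q_sparse = torch.tensor([[0.0, 2.0, 0.0]])
p_sparse = torch.tensor([[0.0, 1.0, 0.0], [3.0, 0.0, 0.0]])
q_colbert = torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])
p_colbert = torch.tensor([[[1.0, 0.0], [0.0, 0.0]],
                          [[0.0, 1.0], [1.0, 0.0]]])
q_mask = torch.tensor([[1, 1]])

dense_scores = dot_score(q_dense, p_dense)                     # [[1.0, 0.0]]
sparse_scores = dot_score(q_sparse, p_sparse)                  # [[2.0, 0.0]]
colbert_scores = colbert_score(q_colbert, p_colbert, q_mask)   # [[0.5, 1.0]]
ensemble = ensemble_score(dense_scores, sparse_scores, colbert_scores)
```

Here the first passage wins on dense and sparse evidence while the second wins on token-level matching; the ensemble ranks the first passage higher (2.1 vs. 1.0).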

Memory Optimization

Sub-Batch Processing

When enable_sub_batch=True:

  • Automatically compute sub_batch_size based on sequence length
  • Process long sequences in chunks
  • Prevents OOM errors on long documents

Mapping (sequence length → sub_batch_size):

  • 6000+: 1
  • 5000-6000: 2
  • 4000-5000: 3
  • 3000-4000: 3
  • 2000-3000: 5
  • 1000-2000: 9
  • 512-1000: 16
  • 0-512: 32
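
The mapping above amounts to a threshold lookup, sketched below. The function and constant names are hypothetical, and treating each lower bound as inclusive (for example, a length of exactly 512 maps to 16) is an assumption, since the ranges as listed overlap at their boundaries.

```python
# (minimum sequence length, sub_batch_size), checked from longest to shortest
SUB_BATCH_MAPPING = [(6000, 1), (5000, 2), (4000, 3), (3000, 3),
                     (2000, 5), (1000, 9), (512, 16), (0, 32)]


def compute_sub_batch_size(seq_len: int) -> int:
    """Return the sub-batch size for a given padded sequence length."""
    for min_len, sub_batch_size in SUB_BATCH_MAPPING:
        if seq_len >= min_len:
            return sub_batch_size
    raise ValueError("seq_len must be non-negative")


def split_into_sub_batches(num_items: int, sub_batch_size: int):
    """Yield (start, end) index pairs covering num_items in chunks."""
    for start in range(0, num_items, sub_batch_size):
        yield start, min(start + sub_batch_size, num_items)


size = compute_sub_batch_size(2500)               # falls in the 2000-3000 range
chunks = list(split_into_sub_batches(12, size))   # 12 passages, 5 per chunk
```

Longer sequences get smaller chunks because activation memory grows with sequence length; a batch of 12 passages at length 2500 would be encoded in three passes.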

Gradient Checkpointing

model.gradient_checkpointing_enable()

Trades computation for memory by recomputing activations during backward pass.

Usage Examples

Training

from modeling import BGEM3Model
import torch

# Initialize model
model = BGEM3Model(
    model_name="BAAI/bge-m3",
    normlized=True,
    sentence_pooling_method='cls',
    negatives_cross_device=True,  # requires torch.distributed to be initialized; use False on a single process
    temperature=0.02,
    enable_sub_batch=True,
    unified_finetuning=True,
    use_self_distill=True,
    self_distill_start_step=1000
)

# Enable gradient checkpointing
model.gradient_checkpointing_enable()

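# Training step (token IDs below are elided placeholders; build real inputs with the tokenizer)
# query tensors: (batch_size, q_len); passage tensors: (batch_size * group_size, p_len)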
query = {
    "input_ids": torch.tensor([[101, 2054, 2003, ...], ...]),
    "attention_mask": torch.tensor([[1, 1, 1, ...], ...])
}
passage = {
    "input_ids": torch.tensor([[101, 4910, 4083, ...], ...]),
    "attention_mask": torch.tensor([[1, 1, 1, ...], ...])
}

output = model(query=query, passage=passage)
loss = output.loss
loss.backward()

# With knowledge distillation
teacher_scores = torch.tensor([
    [0.95, 0.3, 0.2, 0.15, 0.1, 0.05, 0.03, 0.02],  # Query 1: one score per passage in its group
    [0.90, 0.4, 0.25, 0.2, 0.15, 0.1, 0.05, 0.03],  # Query 2
    # ... shape (batch_size, group_size)
])

output = model(query=query, passage=passage, teacher_scores=teacher_scores)
loss = output.loss

# Save model
model.save("/path/to/output")
# Saves: pytorch_model.bin, colbert_linear.pt, sparse_linear.pt

Inference

import torch
from modeling import BGEM3ForInference
from transformers import AutoTokenizer

# Load model for inference
model = BGEM3ForInference(
    model_name="BAAI/bge-m3",
    normlized=True
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")

# Encode query
query = "What is machine learning?"
query_input = tokenizer(
    query,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=512
)

# Get all three representations
# return_sparse_embedding=True yields vocabulary-sized sparse vectors;
# the default (False) returns per-token weights only
output = model(
    text_input=query_input,
    return_dense=True,
    return_sparse=True,
    return_colbert=True,
    return_sparse_embedding=True
)

dense_emb = output['dense_vecs']      # (1, 1024)
sparse_emb = output['sparse_vecs']    # (1, vocab_size)
colbert_emb = output['colbert_vecs']  # (1, seq_len-1, 1024)

# Dense retrieval
doc = "Machine learning is a subset of AI"
doc_input = tokenizer(doc, return_tensors="pt", padding=True, truncation=True)
doc_output = model(text_input=doc_input, return_dense=True)
doc_dense = doc_output['dense_vecs']

dense_score = (dense_emb @ doc_dense.T).item()
print(f"Dense similarity: {dense_score:.4f}")

# Sparse retrieval (vocabulary-sized vectors, so the element-wise product lines up)
doc_output = model(text_input=doc_input, return_sparse=True, return_sparse_embedding=True)
doc_sparse = doc_output['sparse_vecs']

sparse_score = (sparse_emb * doc_sparse).sum().item()
print(f"Sparse similarity: {sparse_score:.4f}")

# ColBERT retrieval
doc_output = model(text_input=doc_input, return_colbert=True)
doc_colbert = doc_output['colbert_vecs']

# MaxSim operation
token_scores = torch.einsum('qin,pjn->qipj', colbert_emb, doc_colbert)
max_scores = token_scores.max(dim=-1)[0]
colbert_score = (max_scores.sum() / colbert_emb.shape[1]).item()  # average over query tokens
print(f"ColBERT similarity: {colbert_score:.4f}")

# Ensemble score
ensemble_score = dense_score + 0.3 * sparse_score + colbert_score
print(f"Ensemble similarity: {ensemble_score:.4f}")

Only Dense Embeddings

# Training without sparse and ColBERT
model = BGEM3Model(
    model_name="BAAI/bge-base-en-v1.5",
    unified_finetuning=False,  # Disable sparse and ColBERT
    normlized=True
)

# Only dense loss will be computed
output = model(query=query, passage=passage)

Model Variants

Training Model (BGEM3Model)

  • Full training pipeline with loss computation
  • Supports all training features
  • Returns EncoderOutput with loss

Inference Model (BGEM3ForInference)

  • Optimized for inference
  • Selective embedding computation
  • No loss computation
  • Returns dict with requested embeddings

Checkpoint Management

Saving

model.save("/path/to/checkpoint")

Saves:

  • pytorch_model.bin: Base model weights
  • colbert_linear.pt: ColBERT projection weights
  • sparse_linear.pt: Sparse projection weights
  • config.json: Model configuration

Loading

model = BGEM3Model(model_name="/path/to/checkpoint")

Automatically loads colbert_linear and sparse_linear if present, otherwise initializes randomly.

Configuration Parameters

  • model_name: Pretrained model or checkpoint path
  • normlized: L2 normalize embeddings (default: True)
  • sentence_pooling_method: "cls" or "mean" (default: "cls")
  • negatives_cross_device: Gather negatives across GPUs (default: False)
  • temperature: Similarity scaling temperature (default: 1.0)
  • enable_sub_batch: Auto sub-batching for long docs (default: True)
  • unified_finetuning: Train all three representations (default: True)
  • use_self_distill: Enable self-distillation (default: False)
  • colbert_dim: ColBERT dimension, -1 for hidden_size (default: -1)
  • self_distill_start_step: When to start self-distillation (default: -1)

Performance Tips

  1. Use unified_finetuning=True to train the dense, sparse, and ColBERT representations together
  2. Enable gradient checkpointing for long sequences
  3. Set an appropriate temperature (0.01-0.05 for contrastive learning)
  4. Use negatives_cross_device=True for distributed training
  5. Start self-distillation after about 1000 steps (self_distill_start_step=1000) for stability
  6. Enable enable_sub_batch for documents longer than 2000 tokens
