Implementation:FlagOpen FlagEmbedding BGE M3 Modeling
| Knowledge Sources | |
|---|---|
| Domains | Machine Learning, Multi-Modal Retrieval, Neural Networks, Information Retrieval |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
A unified multi-functionality embedding model that combines dense, sparse, and multi-vector (ColBERT) representations in a single architecture for state-of-the-art retrieval performance.
Description
BGEM3Model implements the BGE-M3 (Multi-Functionality, Multi-Linguality, Multi-Granularity) model architecture, which uniquely combines three complementary retrieval methods: dense embeddings (traditional semantic vectors), sparse embeddings (learned term weighting similar to BM25), and ColBERT-style multi-vector representations (token-level matching). This unified approach enables the model to excel across different retrieval scenarios and languages.
The model supports both training and inference modes, with specialized handling for each. During training it combines several objectives: a standard contrastive loss, knowledge distillation from teacher scores, self-distillation from the model's own ensemble score, and optional cross-device negatives that enlarge the pool of negatives for the dense loss. The architecture adds linear projection heads for the ColBERT and sparse representations, builds vocabulary-sized sparse vectors, and exposes configurable normalization and temperature scaling.
Key features include joint fine-tuning of all three representation types, zeroing of special-token weights in sparse embeddings, a lighter scatter_reduce path for sparse embeddings at inference, automatic sub-batching to limit memory on long documents, and gradient checkpointing. Training runs on a single device or distributed across GPUs.
Usage
Use this model for training multi-functional retrieval systems or for inference to generate dense, sparse, and ColBERT embeddings for queries and documents in multilingual retrieval tasks.
Code Reference
Source Location
- Repository: FlagOpen_FlagEmbedding
- File: research/BGE_M3/modeling.py
- Lines: 1-390
Signature
```python
class BGEM3Model(nn.Module):
    def __init__(
        self,
        model_name: str = None,
        normlized: bool = True,
        sentence_pooling_method: str = 'cls',
        negatives_cross_device: bool = False,
        temperature: float = 1.0,
        enable_sub_batch: bool = True,
        unified_finetuning: bool = True,
        use_self_distill: bool = False,
        colbert_dim: int = -1,
        self_distill_start_step: int = -1,
    ): ...

    def forward(
        self,
        query: Dict[str, Tensor] = None,
        passage: Dict[str, Tensor] = None,
        teacher_scores: Tensor = None,
        bi_directions: bool = None,
    ) -> EncoderOutput: ...

    def encode(self, features, sub_batch_size=None) -> Tuple[Tensor, Tensor, Tensor]: ...
    def dense_embedding(self, hidden_state, mask) -> Tensor: ...
    def sparse_embedding(self, hidden_state, input_ids, return_embedding: bool = True) -> Tensor: ...
    def colbert_embedding(self, last_hidden_state, mask) -> Tensor: ...


class BGEM3ForInference(BGEM3Model):
    def forward(
        self,
        text_input: Dict[str, Tensor] = None,
        return_dense: bool = True,
        return_sparse: bool = False,
        return_colbert: bool = False,
        return_sparse_embedding: bool = False,
    ) -> dict: ...
```
Import
```python
from modeling import BGEM3Model, BGEM3ForInference  # modeling.py from research/BGE_M3 must be importable
```
I/O Contract
Training Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| query | Dict[str, Tensor] | Yes | Query input_ids and attention_mask |
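| passage | Dict[str, Tensor] | Yes | Passage input_ids and attention_mask; batch_size × group_size rows, grouped by query with each query's positive first |
| teacher_scores | Optional[Tensor] | No | Raw teacher scores for distillation, shape (batch_size, group_size); softmax-normalized internally |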
| bi_directions | Optional[bool] | No | Use bidirectional loss for parallel corpora |
Training Outputs
| Name | Type | Description |
|---|---|---|
| loss | Optional[Tensor] | Combined training loss (contrastive or distillation terms, plus self-distillation when enabled); None when the model is not in training mode |
| q_reps | Optional[Tensor] | Not populated by forward (None) |
| p_reps | Optional[Tensor] | Not populated by forward (None) |
Inference Inputs
| Name | Type | Required | Description |
|---|---|---|---|
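| text_input | Dict[str, Tensor] | Yes | Tokenizer output with input_ids and attention_mask |
| return_dense | bool | No | Return dense embeddings (default: True). At least one of return_dense, return_sparse, return_colbert must be True |
| return_sparse | bool | No | Return sparse embeddings (default: False) |
| return_colbert | bool | No | Return ColBERT embeddings (default: False) |
| return_sparse_embedding | bool | No | If True, return (batch_size, vocab_size) vocabulary vectors; if False, return per-token weights (batch_size, seq_len, 1) (default: False). Only used when return_sparse=True |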
Inference Outputs
| Name | Type | Description |
|---|---|---|
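| dense_vecs | Tensor | Dense embeddings (batch_size, hidden_size); present only if return_dense=True |
| sparse_vecs | Tensor | (batch_size, vocab_size) vocabulary vectors if return_sparse_embedding=True, else per-token weights (batch_size, seq_len, 1); present only if return_sparse=True |
| colbert_vecs | Tensor | ColBERT embeddings (batch_size, seq_len-1, colbert_dim); present only if return_colbert=True |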
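With the default return_sparse_embedding=False, sparse_vecs holds one weight per input position instead of a vocabulary vector. The following minimal sketch shows one way to collapse those per-token weights for a single text into a {token_id: weight} mapping and score two texts by their shared tokens. The helper names are hypothetical, and the ids 0-3 are toy stand-ins for special tokens; in practice you would take those ids from the tokenizer and pass output['sparse_vecs'][i] together with input_ids[i].

```python
import torch


def lexical_weights(token_weights, input_ids, unused_ids):
    """Collapse one text's per-token weights (seq_len, 1) into {token_id: weight}."""
    result = {}
    for tid, w in zip(input_ids.tolist(), token_weights.squeeze(-1).tolist()):
        if tid in unused_ids or w <= 0:
            continue  # skip special tokens and zero weights
        result[tid] = max(result.get(tid, 0.0), w)  # a repeated token keeps its max weight
    return result


def lexical_match(q_weights, d_weights):
    """Sum of weight products over shared token ids (a sparse dot product)."""
    return sum(w * d_weights[tid] for tid, w in q_weights.items() if tid in d_weights)


# Toy example: ids 0-3 stand in for special tokens; weights are (seq_len, 1) tensors
special_ids = {0, 1, 2, 3}
q = lexical_weights(torch.tensor([[0.9], [0.2], [0.4], [0.6], [0.8]]),
                    torch.tensor([0, 5, 7, 5, 2]), special_ids)
d = lexical_weights(torch.tensor([[0.7], [0.5], [1.0]]), torch.tensor([0, 5, 9]), special_ids)
score = lexical_match(q, d)
print(q, score)
```

Keeping the max weight for repeated tokens mirrors the max pooling used to build the vocabulary vectors. The dictionary score therefore matches a dot product of the full vectors without materializing vocab_size entries.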
Architecture Components
Base Model
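- Uses a Hugging Face AutoModel encoder as the backbone (BAAI/bge-m3 is XLM-RoBERTa-based)
- Loads pretrained weights from model_name (a local directory or a Hugging Face Hub repo id)
- Uses last_hidden_state as the input to all three representations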
Dense Embedding
Pooling Methods:
- CLS: Uses [CLS] token representation (hidden_state[:, 0])
- Mean: Average of the non-padding token embeddings (weighted by attention_mask)
Properties:
- Output dimension: hidden_size (e.g., 1024 for bge-m3)
- Normalized if normlized=True
- Temperature scaling applied to similarity scores
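The following minimal standalone sketch implements the two pooling options above, followed by the optional L2 normalization. The function name pool_dense is hypothetical. In the module, this logic corresponds to dense_embedding plus the normalization step applied when normlized=True.

```python
import torch
import torch.nn.functional as F


def pool_dense(hidden_state: torch.Tensor, mask: torch.Tensor,
               method: str = "cls", normalize: bool = True) -> torch.Tensor:
    """Pool (batch, seq_len, hidden) token states into (batch, hidden) vectors."""
    if method == "cls":
        # The first position holds the [CLS]-style token
        vecs = hidden_state[:, 0]
    elif method == "mean":
        # Zero out padding positions, then divide by the number of real tokens
        m = mask.unsqueeze(-1).to(hidden_state.dtype)
        vecs = (hidden_state * m).sum(dim=1) / m.sum(dim=1)
    else:
        raise ValueError(f"unknown pooling method: {method}")
    return F.normalize(vecs, dim=-1) if normalize else vecs


torch.manual_seed(0)
hidden = torch.randn(2, 4, 8)
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
cls_vecs = pool_dense(hidden, mask, "cls")
mean_vecs = pool_dense(hidden, mask, "mean", normalize=False)
print(cls_vecs.shape)  # torch.Size([2, 8])
```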
Sparse Embedding
Architecture:
- Linear layer: hidden_size → 1
- ReLU activation for non-negative weights
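- Scatter each token's weight to its vocabulary id
- Max pooling over token positions, so a token that occurs several times keeps its highest weight
Features:
- Learned term importance weights
- Plays the role of BM25-style lexical weights, but is learned end-to-end
- Weights of special tokens (the tokenizer's CLS, EOS, PAD and UNK ids) are set to 0
- Not L2-normalized
Optimization (Issue #1364): the two modes build the vocabulary vector differently:
- Training (self.training=True): materializes a (batch, seq_len, vocab_size) tensor, then takes the max over seq_len
- Inference (self.training=False): scatter_reduce(..., reduce="amax") writes directly into a (batch, vocab_size) tensor, avoiding the large intermediate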
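The two construction paths from the Optimization note can be compared in a standalone sketch. It assumes a recent PyTorch (for Tensor.scatter_reduce with reduce="amax") and uses a hypothetical helper name and toy token ids. It only illustrates that both paths yield the same (batch, vocab_size) vector; it is not a transcription of the module's code.

```python
import torch


def sparse_from_weights(token_weights, input_ids, vocab_size, unused_ids, training):
    """Turn per-token weights (batch, seq_len, 1) into (batch, vocab_size) vectors."""
    batch, seq_len = input_ids.shape
    opts = dict(dtype=token_weights.dtype, device=token_weights.device)
    if training:
        # Materialize (batch, seq_len, vocab_size): each position holds its weight at its token id
        full = torch.zeros(batch, seq_len, vocab_size, **opts)
        full = torch.scatter(full, dim=-1, index=input_ids.unsqueeze(-1), src=token_weights)
        sparse = full.max(dim=1).values  # max over positions
    else:
        # Write straight into (batch, vocab_size), keeping the max weight per token id
        sparse = torch.zeros(batch, vocab_size, **opts)
        sparse = sparse.scatter_reduce(dim=-1, index=input_ids,
                                       src=token_weights.squeeze(-1), reduce="amax")
    sparse[:, unused_ids] = 0.0  # zero special tokens (toy ids here; use the tokenizer's ids in practice)
    return sparse


torch.manual_seed(0)
hidden_size, vocab_size = 8, 12
sparse_linear = torch.nn.Linear(hidden_size, 1)  # hidden_size -> 1 term weight
hidden = torch.randn(2, 5, hidden_size)
input_ids = torch.tensor([[0, 7, 9, 7, 2], [0, 4, 2, 1, 1]])
special_ids = [0, 1, 2, 3]  # toy stand-ins for cls/pad/eos/unk
with torch.no_grad():
    weights = torch.relu(sparse_linear(hidden))  # non-negative, (2, 5, 1)
    train_vec = sparse_from_weights(weights, input_ids, vocab_size, special_ids, training=True)
    infer_vec = sparse_from_weights(weights, input_ids, vocab_size, special_ids, training=False)
print(torch.allclose(train_vec, infer_vec))
```

Because the ReLU keeps the weights non-negative, starting from a zero tensor does not change the maximum. That is why the cheaper amax scatter matches the materialize-then-max path.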
ColBERT Embedding
Architecture:
- Linear layer: hidden_size → colbert_dim
- Projects positions 1 to seq_len-1, i.e. every token except the leading [CLS]
- Zeroes padding positions using attention_mask
Properties:
- colbert_dim=-1 (the default) means hidden_size; any other value sets the output size
- Normalized if normlized=True
- Enables late interaction (token-level matching)
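The following minimal standalone sketch covers the ColBERT projection described above. The function name is hypothetical, the sizes are toy values, and colbert_linear is a plain nn.Linear standing in for the module's head.

```python
import torch
import torch.nn.functional as F


def colbert_vectors(last_hidden_state, mask, colbert_linear, normalize=True):
    """Project every token except position 0 ([CLS]) and zero out padding."""
    vecs = colbert_linear(last_hidden_state[:, 1:])        # (batch, seq_len-1, colbert_dim)
    vecs = vecs * mask[:, 1:].unsqueeze(-1).to(vecs.dtype)  # padding rows become 0
    # F.normalize divides by max(norm, eps), so zeroed padding rows stay 0
    return F.normalize(vecs, dim=-1) if normalize else vecs


torch.manual_seed(0)
hidden_size, colbert_dim = 8, 4  # colbert_dim=-1 in the module would mean hidden_size
colbert_linear = torch.nn.Linear(hidden_size, colbert_dim)
hidden = torch.randn(2, 5, hidden_size)
mask = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])
with torch.no_grad():
    vecs = colbert_vectors(hidden, mask, colbert_linear)
print(vecs.shape)  # torch.Size([2, 4, 4])
```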
Projection Layers
- colbert_linear: For ColBERT token embeddings
- sparse_linear: For sparse term weights
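- Loaded from colbert_linear.pt and sparse_linear.pt when both files exist in the model directory; otherwise randomly initialized (suitable for training, not for inference)
- Set to None when unified_finetuning=False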
Training Methodology
Unified Fine-tuning
When unified_finetuning=True:
1. All three representations are trained simultaneously.
2. Each representation has its own loss.
3. An ensemble score, dense + 0.3 * sparse + ColBERT, gets an additional loss.
Loss computation (the sparse loss weight of 0.1 is separate from the 0.3 sparse weight inside the ensemble score):
```python
loss = (dense_loss + ensemble_loss + 0.1 * sparse_loss + colbert_loss) / 4
```
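The combination rule can be written as a small function over precomputed score matrices. This minimal sketch assumes hard-label cross-entropy for every term (the no-teacher case) and uses the hypothetical name unified_loss.

```python
import torch
import torch.nn.functional as F


def unified_loss(dense_scores, sparse_scores, colbert_scores, targets):
    """Each *_scores tensor is (batch_size, batch_size * group_size);
    targets holds the column of each query's positive passage."""
    dense_loss = F.cross_entropy(dense_scores, targets)
    sparse_loss = F.cross_entropy(sparse_scores, targets)
    colbert_loss = F.cross_entropy(colbert_scores, targets)
    # The ensemble score down-weights sparse by 0.3 before its own loss
    ensemble_loss = F.cross_entropy(dense_scores + 0.3 * sparse_scores + colbert_scores, targets)
    # The sparse loss is further down-weighted by 0.1; the four terms are averaged
    return (dense_loss + ensemble_loss + 0.1 * sparse_loss + colbert_loss) / 4


# With all-zero scores over 4 passages, every cross-entropy term equals log(4)
zeros = torch.zeros(2, 4)
targets = torch.tensor([0, 2])  # batch_size=2, group_size=2
loss = unified_loss(zeros, zeros, zeros, targets)
print(loss.item())  # (1 + 1 + 0.1 + 1) / 4 * log(4)
```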
Contrastive Learning
Standard Mode (no teacher scores):
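- In-batch negatives: each query is scored against all batch_size × group_size passages in the batch
- Cross-entropy loss with targets = [0, group_size, 2*group_size, ...], the column of each query's positive passage
- Temperature scaling for similarity scores
Knowledge Distillation Mode (with teacher scores):
- Soft labels from the teacher: teacher_scores are turned into probabilities with a softmax over each query's group
- Distillation loss: -Σ teacher_prob * log(student_prob)
- Applied to the dense, sparse, ColBERT and ensemble scores separately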
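Both modes can be sketched side by side. The hard-label loss uses the target layout described above. The soft-label loss is a simplified reading of the formula above, in which the student distribution is a softmax over all in-batch passages. It is not a transcription of the module's internal distillation routine, whose details (for example, how other members of a group are treated) may differ. All function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def in_batch_targets(batch_size: int, group_size: int, device=None) -> torch.Tensor:
    # Passages are laid out query by query, positive first: [q0_pos, q0_neg..., q1_pos, ...]
    return torch.arange(batch_size, device=device) * group_size


def contrastive_loss(scores, group_size):
    """Hard-label loss; scores: (batch, batch * group_size)."""
    return F.cross_entropy(scores, in_batch_targets(scores.size(0), group_size, scores.device))


def soft_label_loss(scores, teacher_scores, group_size):
    """Simplified -sum_i p_teacher(i) * log p_student(i) over each query's own group."""
    starts = in_batch_targets(scores.size(0), group_size, scores.device).unsqueeze(1)
    cols = starts + torch.arange(group_size, device=scores.device)  # (batch, group_size)
    log_probs = F.log_softmax(scores, dim=-1)   # student distribution over all in-batch passages
    own = log_probs.gather(1, cols)             # log-probs of the query's own passages
    teacher_probs = F.softmax(teacher_scores, dim=-1)
    return -(teacher_probs * own).sum(dim=-1).mean()


scores = torch.tensor([[5., 0., 0., 0.], [0., 0., 5., 0.]])  # positives score highest
teacher = torch.tensor([[10., -10.], [10., -10.]])           # teacher prefers each positive
print(contrastive_loss(scores, 2).item(), soft_label_loss(scores, teacher, 2).item())
```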
Self-Distillation
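When use_self_distill=True, unified_finetuning=True and the internal training step counter exceeds self_distill_start_step:
- The softmax of the detached ensemble scores serves as pseudo-labels
- Each representation is distilled toward this ensemble distribution
- Additional loss term: (dense_distill + 0.1 * sparse_distill + colbert_distill) / 3, added to the loss
- The resulting total is halved: loss = (loss + distill_term) / 2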
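The following minimal sketch implements the self-distillation term. It detaches the ensemble distribution so that no gradient flows into the pseudo-labels (an assumption stated in the list above), and its function names are hypothetical.

```python
import torch


def self_distill_term(dense_scores, sparse_scores, colbert_scores):
    """Distill the (detached) ensemble distribution into each representation."""
    ensemble = dense_scores + 0.3 * sparse_scores + colbert_scores
    target = torch.softmax(ensemble.detach(), dim=-1)  # pseudo-labels without gradient

    def kd(student):
        # Cross-entropy between the ensemble distribution and one representation's scores
        return -(torch.log_softmax(student, dim=-1) * target).sum(dim=-1).mean()

    return (kd(dense_scores) + 0.1 * kd(sparse_scores) + kd(colbert_scores)) / 3


def add_self_distill(loss, dense_scores, sparse_scores, colbert_scores):
    # Add the distillation term, then halve the total
    return (loss + self_distill_term(dense_scores, sparse_scores, colbert_scores)) / 2


# Uniform scores over 4 passages: every kd term equals log(4), so the term is 0.7 * log(4)
z = torch.zeros(2, 4)
term = self_distill_term(z, z, z)
total = add_self_distill(torch.tensor(1.0), z, z, z)
print(term.item(), total.item())
```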
Negatives-Cross-Device
When negatives_cross_device=True:
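- Gathers query and passage dense embeddings (plus teacher targets in distillation mode) from all ranks
- Computes the dense loss over the gathered batch, so each query sees world_size times as many negatives
- Sparse and ColBERT losses still use only local in-batch negatives
- Requires an initialized torch.distributed process group; otherwise the constructor raises a ValueError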
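The following sketch shows the all-gather step behind negatives_cross_device. It assumes the common pattern of gathering with torch.distributed.all_gather and then putting the local tensor back into its own slot. Outputs of all_gather carry no autograd history, so without that step no gradient would reach this rank's embeddings. The function names are hypothetical, and the demo uses a single-process gloo group on CPU only so the code can run without GPUs.

```python
import os
import tempfile

import torch
import torch.distributed as dist


def gather_with_local_grad(t: torch.Tensor) -> torch.Tensor:
    """All-gather `t` from every rank and concatenate along dim 0, keeping local gradients."""
    t = t.contiguous()
    gathered = [torch.empty_like(t) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, t)
    gathered[dist.get_rank()] = t  # restore the autograd-tracked local tensor
    return torch.cat(gathered, dim=0)


def demo_single_process():
    # Single-process "gloo" group on CPU, just to exercise the function
    init_file = os.path.join(tempfile.mkdtemp(), "pg_init")
    dist.init_process_group("gloo", init_method=f"file://{init_file}", rank=0, world_size=1)
    try:
        q = torch.randn(2, 4, requires_grad=True)
        all_q = gather_with_local_grad(q)
        all_q.sum().backward()
        return all_q.shape, q.grad
    finally:
        dist.destroy_process_group()
```

In a real multi-GPU job, each rank would call gather_with_local_grad on its query and passage embeddings before computing the dense scores.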
Scoring Functions
Dense Score
```python
scores = (q_reps @ p_reps.T) / temperature
# Dot product; equals cosine similarity when normlized=True
```
Sparse Score
```python
scores = (q_sparse @ p_sparse.T) / temperature
# Dot product of vocabulary-sized weight vectors: only tokens present in both texts contribute
```
ColBERT Score
```python
token_scores = torch.einsum('qin,pjn->qipj', q_colbert, p_colbert)  # (Q, q_len-1, P, p_len-1)
max_scores = token_scores.max(dim=-1).values                       # best passage token per query token
scores = max_scores.sum(dim=1) / q_mask[:, 1:].sum(-1, keepdim=True) / temperature
# MaxSim, averaged over the non-padding query tokens ([CLS] excluded)
```
Ensemble Score
```python
ensemble = dense_scores + 0.3 * sparse_scores + colbert_scores
# Weighted sum; sparse scores are down-weighted
```
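The formulas above can be collected into one batched scoring function. In this minimal sketch, the function name is hypothetical and the inputs are random toy tensors rather than model outputs. q_mask is the query attention_mask, including the [CLS] position.

```python
import torch


def score_all(q_dense, p_dense, q_sparse, p_sparse, q_colbert, p_colbert, q_mask, temperature=1.0):
    """Return (dense, sparse, colbert, ensemble) score matrices of shape (Q, P)."""
    dense = q_dense @ p_dense.T / temperature
    sparse = q_sparse @ p_sparse.T / temperature
    # MaxSim: token similarities (Q, q_len-1, P, p_len-1), best passage token per query token
    token_scores = torch.einsum("qin,pjn->qipj", q_colbert, p_colbert)
    max_per_q_token = token_scores.max(dim=-1).values     # (Q, q_len-1, P)
    n_q_tokens = q_mask[:, 1:].sum(dim=-1, keepdim=True)  # non-pad query tokens, [CLS] excluded
    colbert = max_per_q_token.sum(dim=1) / n_q_tokens / temperature
    ensemble = dense + 0.3 * sparse + colbert
    return dense, sparse, colbert, ensemble


torch.manual_seed(0)
Q, P, qL, pL, dim, V = 2, 3, 4, 5, 6, 10
qd, pd = torch.randn(Q, dim), torch.randn(P, dim)
qs, ps = torch.rand(Q, V), torch.rand(P, V)
qc, pc = torch.randn(Q, qL - 1, dim), torch.randn(P, pL - 1, dim)
q_mask = torch.ones(Q, qL, dtype=torch.long)
d, s, c, e = score_all(qd, pd, qs, ps, qc, pc, q_mask)
print(e.shape)  # torch.Size([2, 3])
```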
Memory Optimization
Sub-Batch Processing
When enable_sub_batch=True:
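- sub_batch_size is chosen from the padded sequence length of the batch (input_ids.size(-1)), separately for queries and passages
- The batch is split into sub-batches of sub_batch_size sequences, which are encoded one after another and concatenated; each sequence is still encoded in full
- Reduces peak memory on long inputs; because activations of every sub-batch are still kept for the backward pass, it pairs naturally with gradient checkpointing
Mapping (padded sequence length L → sub_batch_size):

| Sequence length L | sub_batch_size |
|---|---|
| L ≥ 6000 | 1 |
| 5000 ≤ L < 6000 | 2 |
| 4000 ≤ L < 5000 | 3 |
| 3000 ≤ L < 4000 | 3 |
| 2000 ≤ L < 3000 | 5 |
| 1000 ≤ L < 2000 | 9 |
| 512 ≤ L < 1000 | 16 |
| L < 512 | 32 |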
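The following minimal sketch shows the length-based lookup and the sub-batch loop. The thresholds come from the mapping above. The names encode_in_sub_batches and encode_fn are hypothetical stand-ins for the module's encoding step, and encode_fn is assumed to return one tensor per call.

```python
import torch

# (minimum padded length, sub_batch_size), checked from the longest threshold down
SUB_BATCH_MAPPING = [(6000, 1), (5000, 2), (4000, 3), (3000, 3), (2000, 5), (1000, 9), (512, 16), (0, 32)]


def compute_sub_batch_size(input_ids: torch.Tensor) -> int:
    """Pick sub_batch_size from the padded sequence length (last dim of input_ids)."""
    cur_len = input_ids.size(-1)
    for min_len, size in SUB_BATCH_MAPPING:
        if cur_len >= min_len:
            return size
    return SUB_BATCH_MAPPING[-1][1]  # unreachable in practice: the last threshold is 0


def encode_in_sub_batches(features, encode_fn, sub_batch_size):
    """Run encode_fn on row slices of every feature tensor and concatenate the results."""
    n = features["input_ids"].size(0)
    outputs = []
    for start in range(0, n, sub_batch_size):
        chunk = {k: v[start:start + sub_batch_size] for k, v in features.items()}
        outputs.append(encode_fn(chunk))
    return torch.cat(outputs, dim=0)


feats = {"input_ids": torch.arange(10).unsqueeze(1), "attention_mask": torch.ones(10, 1)}
out = encode_in_sub_batches(feats, lambda f: f["input_ids"].float() * 2, sub_batch_size=3)
```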
Gradient Checkpointing
```python
model.gradient_checkpointing_enable()
```
Trades computation for memory by recomputing activations during the backward pass.
Usage Examples
Training
```python
import torch
from transformers import AutoTokenizer
from modeling import BGEM3Model

# Initialize model
model = BGEM3Model(
    model_name="BAAI/bge-m3",
    normlized=True,
    sentence_pooling_method='cls',
    negatives_cross_device=False,  # set True only inside an initialized torch.distributed job
    temperature=0.02,
    enable_sub_batch=True,
    unified_finetuning=True,
    use_self_distill=True,
    self_distill_start_step=1000,
)
# Enable gradient checkpointing and make sure training mode is on
model.gradient_checkpointing_enable()
model.train()
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")

# Training step: 2 queries, group_size = 2 (each query's positive first, then a negative)
queries = ["what is machine learning?", "capital of france"]
passages = [
    "Machine learning is a subset of AI.",  # positive for query 0
    "Bananas are rich in potassium.",       # negative for query 0
    "Paris is the capital of France.",      # positive for query 1
    "The Nile flows through Egypt.",        # negative for query 1
]
query = tokenizer(queries, padding=True, truncation=True, max_length=512, return_tensors="pt")
passage = tokenizer(passages, padding=True, truncation=True, max_length=512, return_tensors="pt")
output = model(query=query, passage=passage)
loss = output.loss
loss.backward()

# With knowledge distillation: raw teacher scores, shape (batch_size, group_size)
teacher_scores = torch.tensor([
    [0.95, 0.30],  # query 0
    [0.90, 0.25],  # query 1
])
output = model(query=query, passage=passage, teacher_scores=teacher_scores)
loss = output.loss
loss.backward()

# Save model
model.save("/path/to/output")
# Writes the backbone (config.json + weights) plus colbert_linear.pt and sparse_linear.pt
```
Inference
```python
import torch
from transformers import AutoTokenizer
from modeling import BGEM3ForInference

# Load model for inference
model = BGEM3ForInference(
    model_name="BAAI/bge-m3",
    normlized=True,
)
model.eval()  # self.training=False also selects the scatter_reduce path for sparse vectors
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")

query = "What is machine learning?"
doc = "Machine learning is a subset of AI"
query_input = tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=512)
doc_input = tokenizer(doc, return_tensors="pt", padding=True, truncation=True, max_length=512)

# Get all three representations in one pass per text.
# return_sparse_embedding=True yields (batch, vocab_size) vectors instead of per-token weights.
flags = dict(return_dense=True, return_sparse=True, return_colbert=True, return_sparse_embedding=True)
with torch.no_grad():
    q_out = model(text_input=query_input, **flags)
    d_out = model(text_input=doc_input, **flags)

dense_emb = q_out['dense_vecs']      # (1, 1024)
sparse_emb = q_out['sparse_vecs']    # (1, vocab_size)
colbert_emb = q_out['colbert_vecs']  # (1, seq_len-1, 1024)

# Dense retrieval: dot product (cosine similarity, since normlized=True)
dense_score = (dense_emb @ d_out['dense_vecs'].T).item()
print(f"Dense similarity: {dense_score:.4f}")

# Sparse retrieval: dot product of the vocabulary-sized weight vectors
sparse_score = (sparse_emb * d_out['sparse_vecs']).sum().item()
print(f"Sparse similarity: {sparse_score:.4f}")

# ColBERT retrieval: MaxSim, averaged over non-padding query tokens ([CLS] excluded)
token_scores = torch.einsum('qin,pjn->qipj', colbert_emb, d_out['colbert_vecs'])
max_scores = token_scores.max(dim=-1).values
colbert_score = (max_scores.sum() / query_input['attention_mask'][:, 1:].sum()).item()
print(f"ColBERT similarity: {colbert_score:.4f}")

# Ensemble score
ensemble_score = dense_score + 0.3 * sparse_score + colbert_score
print(f"Ensemble similarity: {ensemble_score:.4f}")
```
Only Dense Embeddings
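```python
from transformers import AutoTokenizer
from modeling import BGEM3Model

# Training without sparse and ColBERT heads
model = BGEM3Model(
    model_name="BAAI/bge-base-en-v1.5",
    unified_finetuning=False,  # colbert_linear and sparse_linear are set to None
    normlized=True,
)
model.train()
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-base-en-v1.5")  # must match the backbone's vocabulary
query = tokenizer(["what is machine learning?"], padding=True, return_tensors="pt")
passage = tokenizer(["Machine learning is a subset of AI.", "Bananas are rich in potassium."],
                    padding=True, return_tensors="pt")  # group_size = 2, positive first
# Only the dense loss is computed
output = model(query=query, passage=passage)
```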
Model Variants
Training Model (BGEM3Model)
- Full training pipeline with loss computation
- Supports all training features
- Returns EncoderOutput with loss
Inference Model (BGEM3ForInference)
- Computes only the requested representations from a single backbone pass
- Selective embedding computation
- No loss computation
- Returns dict with requested embeddings
Checkpoint Management
Saving
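```python
model.save("/path/to/checkpoint")
```
Saves:
- Backbone weights and config.json via save_pretrained (pytorch_model.bin or model.safetensors, depending on the transformers version)
- colbert_linear.pt: ColBERT projection weights (only when unified_finetuning=True)
- sparse_linear.pt: Sparse projection weights (only when unified_finetuning=True)
Loading
```python
model = BGEM3Model(model_name="/path/to/checkpoint")
```
Loads colbert_linear.pt and sparse_linear.pt only if both are present; otherwise the two heads are randomly initialized.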
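The load-if-both-present behavior can be sketched with plain torch.save and torch.load of state_dicts. The file names come from this section. The helper names are hypothetical, and storing exactly a state_dict in each file is an assumption of this sketch.

```python
import os
import tempfile

import torch


def save_heads(colbert_linear, sparse_linear, output_dir):
    # One state_dict file per projection head, next to the backbone files
    torch.save(colbert_linear.state_dict(), os.path.join(output_dir, "colbert_linear.pt"))
    torch.save(sparse_linear.state_dict(), os.path.join(output_dir, "sparse_linear.pt"))


def load_heads_if_present(colbert_linear, sparse_linear, model_dir) -> bool:
    """Load both heads only if both files exist; otherwise keep the random init."""
    colbert_path = os.path.join(model_dir, "colbert_linear.pt")
    sparse_path = os.path.join(model_dir, "sparse_linear.pt")
    if not (os.path.exists(colbert_path) and os.path.exists(sparse_path)):
        return False
    colbert_linear.load_state_dict(torch.load(colbert_path, map_location="cpu"))
    sparse_linear.load_state_dict(torch.load(sparse_path, map_location="cpu"))
    return True


ckpt_dir = tempfile.mkdtemp()
src_c, src_s = torch.nn.Linear(8, 8), torch.nn.Linear(8, 1)  # toy "trained" heads
dst_c, dst_s = torch.nn.Linear(8, 8), torch.nn.Linear(8, 1)  # freshly initialized heads
loaded_before = load_heads_if_present(dst_c, dst_s, ckpt_dir)  # nothing saved yet
save_heads(src_c, src_s, ckpt_dir)
loaded_after = load_heads_if_present(dst_c, dst_s, ckpt_dir)
```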
Configuration Parameters
- model_name: Pretrained model or checkpoint path
- normlized: L2-normalize dense and ColBERT embeddings (default: True); the parameter name is spelled this way in the source
- sentence_pooling_method: "cls" or "mean" (default: "cls")
- negatives_cross_device: Gather dense-loss negatives across GPUs; requires torch.distributed (default: False)
- temperature: Divisor applied to all similarity scores (default: 1.0); reset to 1.0 when normlized=False
- enable_sub_batch: Automatic sub-batching by sequence length (default: True)
- unified_finetuning: Train all three representations (default: True)
- use_self_distill: Enable self-distillation (default: False)
- colbert_dim: ColBERT dimension, -1 for hidden_size (default: -1)
- self_distill_start_step: Self-distillation starts once the training step counter exceeds this value (default: -1, i.e. from the first step)
Performance Tips
1. Use unified_finetuning=True for best performance
2. Enable gradient checkpointing for long sequences
3. Set an appropriate temperature (0.01-0.05 for contrastive learning)
4. Use negatives_cross_device for distributed training
5. Start self-distillation after 1000 steps for stability
6. Enable sub_batch for documents longer than 2000 tokens