Implementation: FlagOpen FlagEmbedding Matryoshka Compensation Modeling
| Knowledge Sources | |
|---|---|
| Domains | Information Retrieval, Reranking, Neural Networks |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
BiEncoder model with Matryoshka representation learning and compensation mechanism for reranking.
Description
This module implements a neural reranker that supports Matryoshka representation learning with a compensation mechanism. Despite the BiEncoderModel class name, each query and passage are encoded together as a single pair, and the model can dynamically adjust compression layers, compression ratios, and cutoff points during training. The compensation mechanism keeps a temporary copy of the model (tmp_model) whose parameters are preserved and updated to compensate for the reduced representation dimensions.
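The compensation idea described above can be sketched as follows. This is an illustrative toy, not the module's code: the names TinyEncoder and sync_compensation are hypothetical, and the real model copies parameters according to its own cutoff/compression logic.

```python
import copy
import torch
from torch import nn

# Toy stand-in for a transformer stack: a list of per-layer modules.
class TinyEncoder(nn.Module):
    def __init__(self, dim=8, n_layers=4):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_layers))

def sync_compensation(model, tmp_model, compress_layer):
    # Overwrite layers at/after the compression point with the compensation
    # parameters preserved in tmp_model, so they account for the reduced
    # representation seen by the upper layers.
    with torch.no_grad():
        for i in range(compress_layer, len(model.layers)):
            model.layers[i].load_state_dict(tmp_model.layers[i].state_dict())

model = TinyEncoder()
tmp_model = copy.deepcopy(model)
with torch.no_grad():
    # Pretend tmp_model's upper layers were trained for the compressed setting.
    tmp_model.layers[3].weight.add_(1.0)

sync_compensation(model, tmp_model, compress_layer=2)
```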
The model extracts the logit corresponding to the "Yes" token for binary relevance scoring. It supports layer-wise training with multiple cutoff layers, random selection of compression layers during training, and knowledge distillation from teacher scores. The forward pass computes a cross-entropy loss and an optional distillation loss to improve performance on compressed representations.
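A minimal sketch of the scoring and loss computation described above; shapes, the group layout, and the token id are assumptions, not the module's exact code. Each query is grouped with one positive followed by negative passages, the score is the "Yes"-token logit, cross-entropy targets the positive at index 0, and a KL term distills soft teacher scores.

```python
import torch
import torch.nn.functional as F

batch, group_size, vocab = 2, 4, 32
yes_token_id = 7  # assumed id of the "Yes" token from the tokenizer

# Last-position logits for batch * group_size query-passage pairs.
logits = torch.randn(batch * group_size, vocab)
scores = logits[:, yes_token_id]               # relevance score per pair

# Cross-entropy over each group; the positive passage sits at index 0.
grouped = scores.view(batch, group_size)
target = torch.zeros(batch, dtype=torch.long)
ce_loss = F.cross_entropy(grouped, target)

# Optional distillation: match the softened teacher distribution.
teacher_scores = torch.randn(batch, group_size)  # from a stronger reranker
distill_loss = F.kl_div(
    F.log_softmax(grouped, dim=-1),
    F.softmax(teacher_scores, dim=-1),
    reduction="batchmean",
)
loss = ce_loss + distill_loss
```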
Usage
Use this model for training rerankers that can efficiently operate at different representation dimensions through Matryoshka learning, with compensation to maintain performance at lower dimensions.
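As a rough illustration of what operating at a reduced dimension means here, the sketch below mean-pools every `ratio` adjacent hidden states at a sampled compression ratio, matching `compress_method='mean'`. The pooling granularity and shapes are assumptions; the actual module applies compression at configured layers inside the encoder.

```python
import random
import torch

hidden = torch.randn(1, 12, 16)       # (batch, seq_len, hidden_dim)
ratio = random.choice([2, 4])         # sampled from compress_ratios

# Mean-pool every `ratio` adjacent hidden states, dropping any ragged tail.
b, t, d = hidden.shape
t = t - (t % ratio)
compressed = hidden[:, :t].reshape(b, t // ratio, ratio, d).mean(dim=2)
```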
Code Reference
Source Location
- Repository: FlagOpen_FlagEmbedding
- File: research/Matroyshka_reranker/finetune/compensation/modeling.py
- Lines: 1-183
Signature
class BiEncoderModel(nn.Module):
    def __init__(
        self,
        model: nn.Module = None,
        tmp_model: nn.Module = None,
        tokenizer: AutoTokenizer = None,
        compress_method: str = 'mean',
        train_batch_size: int = 4,
        cutoff_layers: List[int] = [2, 4],
        compress_layers: List[int] = [6],
        compress_ratios: List[int] = [2],
        train_method: str = 'distill'
    )

    def encode(self, features, query_lengths, prompt_lengths):
        """Encode input with dynamic compression"""

    def forward(self, pair, query_lengths, prompt_lengths, teacher_scores):
        """Forward pass with loss computation"""
Import
from typing import List

import torch
from torch import nn, Tensor
from transformers import AutoTokenizer
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | nn.Module | Yes | Base transformer model for encoding |
| tmp_model | nn.Module | No | Temporary model for compensation parameters |
| tokenizer | AutoTokenizer | Yes | Tokenizer for finding "Yes" token |
| compress_method | str | No | Pooling method used when compressing representations ('mean' by default) |
| train_batch_size | int | Yes | Batch size for reshaping grouped logits |
| cutoff_layers | List[int] | Yes | Layer indices for early exit |
| compress_layers | List[int] | Yes | Layers to apply compression |
| compress_ratios | List[int] | Yes | Compression ratios to randomly sample |
| train_method | str | Yes | Training method ('distill' or other) |
Outputs
| Name | Type | Description |
|---|---|---|
| loss | Tensor | Combined cross-entropy and distillation loss |
| scores | Tensor | Relevance scores for query-passage pairs |
Usage Examples
from transformers import AutoModelForCausalLM, AutoTokenizer
from modeling import BiEncoderModel  # research/Matroyshka_reranker/finetune/compensation/modeling.py

# Initialize the base model and tokenizer; the base model must expose
# vocabulary logits so the "Yes" token logit can be extracted
base_model = AutoModelForCausalLM.from_pretrained("BAAI/bge-reranker-v2-gemma")
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-v2-gemma")

# Create compensation model
model = BiEncoderModel(
    model=base_model,
    tmp_model=None,  # will be created if needed
    tokenizer=tokenizer,
    train_batch_size=4,
    cutoff_layers=[2, 4, 6],
    compress_layers=[8, 16, 24],
    compress_ratios=[2, 4],
    train_method='distill'
)

# Enable gradient checkpointing
model.gradient_checkpointing_enable()

# Forward pass during training
outputs = model(
    pair=batch['pair'],
    query_lengths=batch['query_lengths'],
    prompt_lengths=batch['prompt_lengths'],
    teacher_scores=batch['teacher_scores']
)
loss = outputs.loss
loss.backward()

# Save model
model.save("output_dir")