Implementation: FlagOpen FlagEmbedding Matryoshka Rank Model
| Knowledge Sources | |
|---|---|
| Domains | Information Retrieval, Neural Reranking, Matryoshka Learning, Multi-Layer Inference |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
A high-level reranker interface for Matryoshka models that supports multi-layer inference, token compression, and flexible early-exit strategies.
Description
The MatroyshkaReranker class provides a user-friendly interface to Matryoshka reranking models built on the Mistral architecture. It extends the AbsReranker base class from FlagEmbedding and handles the complexity of model loading, tokenization, batching, and multi-layer inference. The implementation supports loading models both from raw Mistral checkpoints (with automatic head initialization) and from fine-tuned Matryoshka models with PEFT adapters. Key features include automatic batch-size adjustment based on available GPU memory, token compression to reduce computational cost, and extraction of scores from multiple layers in a single forward pass. The reranker handles query-passage formatting, manages attention masks for compressed sequences, and can optionally normalize scores via sigmoid activation (normalize=True).
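The optional sigmoid normalization mentioned above maps raw relevance logits into (0, 1). A minimal, self-contained sketch (the illustrative logit values are made up, not real model output):

```python
import math

def sigmoid(x: float) -> float:
    # Map a raw logit to a (0, 1) relevance score
    return 1.0 / (1.0 + math.exp(-x))

# Illustrative raw logits for three passages (not real model output)
raw_scores = [2.3, -0.7, 0.1]
normalized = [sigmoid(s) for s in raw_scores]
```

Because sigmoid is monotonic, normalization changes the score scale but never the ranking order.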
Usage
Use this class as the main interface for deploying Matryoshka rerankers in production. It simplifies inference by handling model loading, tokenization, batching, compression, and multi-layer score extraction with a simple API.
Code Reference
Source Location
- Repository: FlagOpen_FlagEmbedding
- File: research/Matroyshka_reranker/inference/rank_model.py
- Lines: 1-394
Signature
class MatroyshkaReranker(AbsReranker):
    def __init__(
        self,
        model_name_or_path: str,
        peft_path: Optional[List[str]] = None,
        use_fp16: bool = False,
        use_bf16: bool = False,
        cutoff_layers: Optional[List[int]] = None,
        compress_layers: List[int] = [8],
        compress_ratio: int = 1,
        prompt: Optional[str] = None,
        batch_size: int = 128,
        query_max_length: Optional[int] = None,
        max_length: int = 512,
        normalize: bool = False,
        from_raw: bool = False,
        start_layer: int = 4,
        **kwargs
    ):
        # Initialize reranker with compression and layer-wise settings
        ...

    @torch.no_grad()
    def compute_score_single_gpu(
        self,
        sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
        batch_size: Optional[int] = None,
        cutoff_layers: Optional[List[int]] = None,
        compress_layers: Optional[List[int]] = None,
        compress_ratio: Optional[int] = None,
        **kwargs
    ) -> List[float]:
        # Compute relevance scores with optional multi-layer output
        ...
Import
import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from FlagEmbedding.abc.inference import AbsReranker
from FlagEmbedding.inference.reranker.encoder_only.base import sigmoid
from mistral_model import CostWiseMistralForCausalLM, CostWiseHead
from mistral_config import CostWiseMistralConfig
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model_name_or_path | str | Yes | Path to base model or HuggingFace model name |
| peft_path | List[str] | No | Paths to PEFT adapters to merge into model |
| sentence_pairs | List[Tuple[str, str]] | Yes | Query-passage pairs to rerank |
| cutoff_layers | List[int] | No | Layers from which to extract scores (e.g., [8, 16, 32]) |
| compress_layers | List[int] | No | Layers at which to compress tokens (default: [8]) |
| compress_ratio | int | No | Token compression ratio (1, 2, 4, or 8; default: 1) |
| batch_size | int | No | Number of pairs per batch (default: 128, auto-adjusted) |
| max_length | int | No | Maximum sequence length (default: 512) |
| normalize | bool | No | Apply sigmoid normalization to scores (default: False) |
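To build intuition for compress_ratio: the actual compression operates on hidden states inside the model, but its effect on cost can be sketched by how many positions survive past the compression layer. This is an illustrative helper, not part of the repo's code:

```python
def compressed_length(seq_len: int, ratio: int) -> int:
    # After the compression layer, roughly every `ratio` consecutive
    # positions collapse into one (illustrative approximation)
    return (seq_len + ratio - 1) // ratio

# With compress_ratio=4, a 512-token pair costs ~128 positions in all
# layers after the compression layer (e.g., layer 8), cutting attention
# and FFN cost in the deep layers accordingly.
```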
Outputs
| Name | Type | Description |
|---|---|---|
| scores | List[float] or List[List[float]] | Relevance scores for each pair; if multiple cutoff_layers, returns list of score lists |
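Since the return shape depends on how many cutoff layers were requested, downstream code should branch on it. A small hypothetical helper (not part of the FlagEmbedding API) that normalizes the output to one score list per layer:

```python
from typing import List, Union

Scores = Union[List[float], List[List[float]]]

def scores_per_layer(scores: Scores, num_layers: int) -> List[List[float]]:
    # Hypothetical helper: always return one score list per requested layer
    if num_layers == 1:
        # A single cutoff layer yields a flat List[float]; wrap it
        return [scores]
    # Multiple cutoff layers already yield List[List[float]]
    return scores
```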
Usage Examples
from rank_model import MatroyshkaReranker
# Initialize reranker from fine-tuned checkpoint
reranker = MatroyshkaReranker(
    model_name_or_path='path/to/mistral-7b',
    peft_path=['path/to/peft/adapter'],
    use_bf16=True,
    cutoff_layers=[8, 16, 24, 32],  # Multi-layer inference
    compress_layers=[8],
    compress_ratio=4,
    max_length=512,
    batch_size=64,
    normalize=True
)
# Rerank query-passage pairs
query = "What is the capital of France?"
passages = [
    "Paris is the capital and largest city of France.",
    "London is the capital of England.",
    "Berlin is the capital of Germany."
]
sentence_pairs = [(query, passage) for passage in passages]
# Get scores from all layers
all_layer_scores = reranker.compute_score(sentence_pairs)
# all_layer_scores[0] = scores from layer 8 (fastest)
# all_layer_scores[1] = scores from layer 16
# all_layer_scores[2] = scores from layer 24
# all_layer_scores[3] = scores from layer 32 (most accurate)
# Use layer 8 for fast inference
fast_scores = all_layer_scores[0]
ranked_indices = sorted(range(len(fast_scores)), key=lambda i: fast_scores[i], reverse=True)
print("Fast ranking (layer 8):")
for idx in ranked_indices:
    print(f"{passages[idx]}: {fast_scores[idx]:.4f}")
# Initialize from raw Mistral model
reranker_raw = MatroyshkaReranker(
    model_name_or_path='mistralai/Mistral-7B-v0.1',
    from_raw=True,
    start_layer=4,
    use_bf16=True,
    cutoff_layers=[8],
    compress_ratio=1
)