Implementation: FlagOpen FlagEmbedding Matryoshka Self-Distillation Modeling
| Knowledge Sources | Details |
|---|---|
| Domains | Information Retrieval, Reranking, Self-Distillation |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
BiEncoder model with Matryoshka learning and self-distillation for layer-wise reranker training.
Description
This module implements a sophisticated reranker model that combines Matryoshka representation learning with self-distillation. Unlike the compensation version, this model uses layer-wise self-distillation where deeper layers teach shallower layers, enabling efficient inference at various computational budgets. The model supports multiple distillation strategies: teacher-based, final-layer, last-layer, and fix-layer distillation.
During training, the model randomly selects compression layers and ratios, computing logits at multiple depths. It uses the full-resolution representation at deeper layers as teacher signals for compressed representations at shallower layers. This creates a nested hierarchy of models where each can be independently used for inference. The training supports both external teacher distillation and internal self-distillation across layers.
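The deeper-teaches-shallower scheme described above can be sketched as a loss function over per-layer score logits. This is a minimal illustration, not the repository's exact code: the list structure of `layer_logits`, the uniform loss weighting, and the KL formulation are assumptions.

```python
import torch
import torch.nn.functional as F

def self_distill_loss(layer_logits, labels):
    """Layer-wise self-distillation sketch: the deepest layer's scores act
    as the teacher for every shallower layer (illustrative, assumed form).

    layer_logits: list of [batch, num_candidates] tensors, shallow -> deep.
    labels: [batch] indices of the positive candidate.
    """
    # Deepest layer is the teacher; detach so no gradient flows into it
    teacher_probs = F.softmax(layer_logits[-1].detach(), dim=-1)
    total = torch.zeros(())
    for logits in layer_logits:
        # Supervised ranking (cross-entropy) loss at every exit depth
        total = total + F.cross_entropy(logits, labels)
    for logits in layer_logits[:-1]:
        # KL term pulls each shallow layer's distribution toward the teacher
        total = total + F.kl_div(F.log_softmax(logits, dim=-1),
                                 teacher_probs, reduction='batchmean')
    return total
```

Each shallower exit thus receives both a hard supervised signal and a soft signal from the deepest exit, which is what lets every depth be served independently at inference time.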
Usage
Use this model to train rerankers that can operate efficiently at different layer depths and representation dimensions, with each layer learning from deeper layers through self-distillation.
Code Reference
Source Location
- Repository: FlagOpen_FlagEmbedding
- File: research/Matroyshka_reranker/finetune/self_distillation/modeling.py
- Lines: 1-216
Signature
class BiEncoderModel(nn.Module):
    def __init__(
        self,
        model: nn.Module = None,
        tokenizer: AutoTokenizer = None,
        compress_method: str = 'mean',
        train_batch_size: int = 4,
        cutoff_layers: List[int] = [2, 4],
        compress_layers: List[int] = [6],
        compress_ratios: List[int] = [2],
        train_method: str = 'distill'
    )

    def encode(self, features, query_lengths, prompt_lengths):
        """Encode with compression"""

    def encode_full(self, features, query_lengths, prompt_lengths):
        """Encode without compression for teacher signals"""

    def forward(self, pair, query_lengths, prompt_lengths, teacher_scores):
        """Forward with layer-wise self-distillation"""
Import
from typing import List
import torch
from torch import nn, Tensor
from transformers import AutoTokenizer
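For intuition, the `compress_method='mean'` option named in the signature can be sketched as averaging groups of consecutive hidden states. The helper below is an illustration under that assumption, not the repository's implementation:

```python
import torch

def mean_compress(hidden_states, ratio):
    """Compress a token sequence by averaging every `ratio` consecutive
    hidden states (sketch of a 'mean' compression method).

    hidden_states: [batch, seq_len, dim]; returns [batch, ceil(seq_len/ratio), dim].
    """
    bsz, seq_len, dim = hidden_states.shape
    # Zero-pad so seq_len divides evenly by the ratio; note this slightly
    # biases the mean of the final group toward zero in this sketch
    pad = (-seq_len) % ratio
    if pad:
        hidden_states = torch.nn.functional.pad(hidden_states, (0, 0, 0, pad))
    return hidden_states.view(bsz, -1, ratio, dim).mean(dim=2)
```

With `compress_ratios=[2, 4, 8]`, a ratio is sampled per step and the sequence above the chosen compression layer is shortened by that factor, which is where the inference-time compute savings come from.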
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | nn.Module | Yes | Base transformer model |
| tokenizer | AutoTokenizer | Yes | Tokenizer used to locate the "Yes" token for scoring |
| compress_method | str | No | Token-compression method (default 'mean') |
| train_batch_size | int | No | Batch size for loss computation (default 4) |
| cutoff_layers | List[int] | No | Layer indices for early exit (default [2, 4]) |
| compress_layers | List[int] | No | Layers at which to compress (default [6]) |
| compress_ratios | List[int] | No | Compression ratios to sample from (default [2]) |
| train_method | str | No | Distillation method: distill, distill_teacher, distill_final_layer, distill_last_layer, or distill_fix_layer (default 'distill') |
Outputs
| Name | Type | Description |
|---|---|---|
| loss | Tensor | Combined loss from all layers with self-distillation |
| scores | Tensor/List[Tensor] | Relevance scores (list if layer-wise, tensor otherwise) |
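The difference between the `train_method` strategies comes down to which signal teaches each exit depth. The mapping below is inferred purely from the strategy names and is a hypothetical sketch, not the repository's logic; `fix_index` is an invented parameter for illustration:

```python
def pick_teacher(train_method, layer_scores, i, teacher_scores=None, fix_index=-1):
    """Hypothetical teacher selection for the exit at index `i`
    (inferred from the strategy names, not the actual implementation)."""
    if train_method == 'distill_teacher':
        return teacher_scores                 # external teacher model's scores
    if train_method == 'distill_final_layer':
        return layer_scores[-1]               # deepest exit teaches every layer
    if train_method == 'distill_last_layer':
        # the next deeper exit teaches this one (cascaded teaching)
        return layer_scores[min(i + 1, len(layer_scores) - 1)]
    if train_method == 'distill_fix_layer':
        return layer_scores[fix_index]        # one fixed exit acts as teacher
    return layer_scores[-1]                   # 'distill': default self-distillation
```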
Usage Examples
from transformers import AutoModel, AutoTokenizer
# BiEncoderModel is defined in
# research/Matroyshka_reranker/finetune/self_distillation/modeling.py
# Initialize components
base_model = AutoModel.from_pretrained("BAAI/bge-reranker-v2-m3")
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-v2-m3")
# Create self-distillation model
model = BiEncoderModel(
model=base_model,
tokenizer=tokenizer,
train_batch_size=4,
cutoff_layers=[4, 8, 12],
compress_layers=[6, 12, 18, 24],
compress_ratios=[2, 4, 8],
train_method='distill_last_layer' # Layer-wise self-distillation
)
# Training forward pass
outputs = model(
pair=batch['pair'],
query_lengths=batch['query_lengths'],
prompt_lengths=batch['prompt_lengths'],
teacher_scores=batch['teacher_scores'] # Optional external teacher
)
# Loss includes:
# 1. Cross-entropy loss at each layer
# 2. Self-distillation from deeper to shallower layers
# 3. Optional external teacher distillation
loss = outputs.loss
# Inference can use any layer depth
# model.config.layer_wise = False # For single-layer inference
# model.cutoff_layers = [8] # Use only 8 layers
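Because `scores` may be either a single tensor or one tensor per cutoff layer (see the Outputs table), serving code needs to pick a depth. A small helper under that assumption (`select_scores` is an illustrative name, not part of the repository's API):

```python
def select_scores(scores, depth_index=-1):
    """Pick the relevance scores for one exit depth when the model returns
    a list of per-layer score tensors; a plain tensor (layer-wise output
    disabled) is returned unchanged. Illustrative helper, not repo API."""
    if isinstance(scores, (list, tuple)):
        return scores[depth_index]  # default: deepest (most accurate) exit
    return scores
```

Shallower indices trade accuracy for latency, which is the point of the nested Matryoshka hierarchy: one trained model, several deployable operating points.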