Implementation:FlagOpen FlagEmbedding Matryoshka Self Distillation Modeling

From Leeroopedia


Knowledge Sources
Domains: Information Retrieval, Reranking, Self-Distillation
Last Updated: 2026-02-09 00:00 GMT

Overview

BiEncoder model with Matryoshka learning and self-distillation for layer-wise reranker training.

Description

This module implements a reranker that combines Matryoshka representation learning with self-distillation. Unlike the compensation variant, it uses layer-wise self-distillation, in which deeper layers teach shallower ones, so the trained model can run inference at several computational budgets. The train_method argument selects the strategy: distill, distill_teacher (teacher-based), distill_final_layer (final-layer), distill_last_layer (last-layer), or distill_fix_layer (fix-layer).
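A minimal sketch of layer-wise self-distillation, written in plain Python for readability. It assumes each layer yields one relevance score per candidate passage of a query, and that a shallower layer is trained to match a deeper layer's softmax distribution through a KL divergence. The helper names and the two teacher choices ("deepest" and "next") are hypothetical illustrations, not the definitions of the module's train_method values.

```python
import math
from typing import Dict, List


def softmax(scores: List[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax over one query's candidate scores."""
    scaled = [s / temperature for s in scores]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(teacher: List[float], student: List[float]) -> float:
    """KL(teacher || student) between two probability distributions."""
    return sum(t * math.log(t / s) for t, s in zip(teacher, student) if t > 0)


def self_distill_loss(
    layer_scores: Dict[int, List[float]],
    teacher_mode: str = "deepest",
    temperature: float = 1.0,
) -> float:
    """Sum of KL terms in which each shallower layer learns from a deeper one.

    teacher_mode="deepest": every layer is distilled from the deepest layer.
    teacher_mode="next":    every layer is distilled from the next deeper layer.
    """
    if teacher_mode not in ("deepest", "next"):
        raise ValueError(f"unknown teacher_mode: {teacher_mode}")
    layers = sorted(layer_scores)  # shallow -> deep
    loss = 0.0
    for i, layer in enumerate(layers[:-1]):
        teacher_layer = layers[-1] if teacher_mode == "deepest" else layers[i + 1]
        # In a PyTorch trainer the teacher side would be detached (no gradient).
        teacher = softmax(layer_scores[teacher_layer], temperature)
        student = softmax(layer_scores[layer], temperature)
        loss += kl_divergence(teacher, student)
    return loss
```

The deepest layer never appears as a student, which matches the idea that it only supplies targets for the layers below it.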

During training, the model randomly selects compression layers and ratios and computes logits at several depths. The full-resolution representations at deeper layers serve as teacher signals for the compressed representations at shallower layers. The result is a nested hierarchy of sub-models, each of which can be used on its own at inference time. Training can combine distillation from an external teacher with internal self-distillation across layers.
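The sampling-and-compression step can be sketched as follows. The sketch assumes that compress_method='mean' means averaging groups of consecutive token vectors by the sampled ratio, and that the layer and ratio are drawn uniformly from compress_layers and compress_ratios. Neither detail is specified on this page, and the helper names mean_compress and sample_compression are hypothetical.

```python
import random
from typing import List, Sequence, Tuple

Vector = List[float]


def mean_compress(tokens: Sequence[Vector], ratio: int) -> List[Vector]:
    """Average consecutive groups of `ratio` token vectors.

    A sequence of n tokens becomes ceil(n / ratio) vectors; the last group
    may contain fewer than `ratio` tokens.
    """
    if ratio < 1:
        raise ValueError("ratio must be >= 1")
    compressed = []
    for start in range(0, len(tokens), ratio):
        group = tokens[start:start + ratio]
        dim = len(group[0])
        compressed.append([sum(vec[d] for vec in group) / len(group) for d in range(dim)])
    return compressed


def sample_compression(
    compress_layers: Sequence[int],
    compress_ratios: Sequence[int],
    rng: random.Random,
) -> Tuple[int, int]:
    """Pick one (layer, ratio) pair uniformly at random for a training step."""
    return rng.choice(compress_layers), rng.choice(compress_ratios)
```

In this sketch, a sequence of n tokens compressed with ratio r shrinks to ceil(n / r) vectors, so larger sampled ratios give cheaper but coarser representations.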

Usage

Use this model to train rerankers that can operate efficiently at different layer depths and representation dimensions, with each layer learning from deeper layers through self-distillation.

Code Reference

Source Location

Signature

class BiEncoderModel(nn.Module):
    def __init__(
        self,
        model: nn.Module = None,
        tokenizer: AutoTokenizer = None,
        compress_method: str = 'mean',
        train_batch_size: int = 4,
        cutoff_layers: List[int] = [2, 4],
        compress_layers: List[int] = [6],
        compress_ratios: List[int] = [2],
        train_method: str = 'distill',
    ):
        ...

    def encode(self, features, query_lengths, prompt_lengths):
        """Encode with compression."""

    def encode_full(self, features, query_lengths, prompt_lengths):
        """Encode without compression (provides teacher signals)."""

    def forward(self, pair, query_lengths, prompt_lengths, teacher_scores):
        """Forward pass with layer-wise self-distillation."""
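Because encode receives query_lengths and prompt_lengths, one plausible reading is that only the passage tokens between the query and the prompt are compressed, while encode_full keeps every token for the teacher signal. The sketch below assumes the layout [query][passage][prompt] and takes the pooling function as a parameter. The names split_spans and encode_compressed are hypothetical, and the layout is an assumption rather than something this page confirms.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
# Any pooling function of the form (span, ratio) -> shorter span.
Compressor = Callable[[Sequence[Vector], int], List[Vector]]


def split_spans(
    tokens: Sequence[Vector], query_length: int, prompt_length: int
) -> Tuple[Sequence[Vector], Sequence[Vector], Sequence[Vector]]:
    """Split one sequence into (query, passage, prompt) spans.

    Assumes the layout [query tokens][passage tokens][prompt tokens].
    """
    passage_end = len(tokens) - prompt_length
    return tokens[:query_length], tokens[query_length:passage_end], tokens[passage_end:]


def encode_compressed(
    tokens: Sequence[Vector],
    query_length: int,
    prompt_length: int,
    ratio: int,
    compress: Compressor,
) -> List[Vector]:
    """Compress only the passage span; query and prompt keep full resolution."""
    query, passage, prompt = split_spans(tokens, query_length, prompt_length)
    return list(query) + compress(passage, ratio) + list(prompt)
```

Keeping the query and prompt uncompressed in this sketch preserves the tokens that frame the relevance judgment, while the usually longer passage absorbs the compression.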

Import

from typing import List

import torch
from torch import nn, Tensor
from transformers import AutoTokenizer

I/O Contract

Inputs

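| Name | Type | Required | Description |
| model | nn.Module | Yes (signature default None) | Base transformer model |
| tokenizer | AutoTokenizer | Yes (signature default None) | Tokenizer used to locate the "Yes" token |
| compress_method | str | No (default 'mean') | Compression method for representations |
| train_batch_size | int | No (default 4) | Batch size used in the loss computation |
| cutoff_layers | List[int] | No (default [2, 4]) | Layer indices for early exit |
| compress_layers | List[int] | No (default [6]) | Layers at which representations are compressed |
| compress_ratios | List[int] | No (default [2]) | Compression ratios to sample from |
| train_method | str | No (default 'distill') | Distillation method: distill, distill_teacher, distill_final_layer, distill_last_layer, or distill_fix_layer |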

Outputs

| Name | Type | Description |
| loss | Tensor | Combined loss from all layers, including self-distillation |
| scores | Tensor or List[Tensor] | Relevance scores (a list if layer-wise, a single tensor otherwise) |
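The combined loss starts from a ranking loss at each layer. A common reranker formulation, assumed here rather than taken from this module, reshapes the flat score vector into train_batch_size groups and applies cross-entropy with the positive passage at index 0 of each group. The plain-Python helpers grouped_cross_entropy and layerwise_ce_loss are hypothetical; the self-distillation and external-teacher terms would be added on top.

```python
import math
from typing import List, Sequence


def grouped_cross_entropy(scores: Sequence[float], train_batch_size: int) -> float:
    """Mean cross-entropy over query groups, assuming the positive is index 0.

    `scores` is the flat list of candidate scores for the whole batch,
    split into `train_batch_size` groups of equal size.
    """
    if len(scores) % train_batch_size != 0:
        raise ValueError("number of scores must be divisible by train_batch_size")
    group_size = len(scores) // train_batch_size
    total = 0.0
    for g in range(train_batch_size):
        group = scores[g * group_size:(g + 1) * group_size]
        peak = max(group)
        log_sum_exp = peak + math.log(sum(math.exp(s - peak) for s in group))
        total += log_sum_exp - group[0]  # -log softmax probability of the positive
    return total / train_batch_size


def layerwise_ce_loss(layer_scores: List[Sequence[float]], train_batch_size: int) -> float:
    """Sum the grouped cross-entropy over every layer's scores."""
    return sum(grouped_cross_entropy(s, train_batch_size) for s in layer_scores)
```

With uniform scores in groups of two, each group contributes log 2, the loss of a ranker that cannot tell the positive from the negative.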

Usage Examples

from transformers import AutoModel, AutoTokenizer

# BiEncoderModel is the class documented on this page; import it from the
# corresponding module of your FlagEmbedding checkout.

# Initialize components.
# The tokenizer is used to locate the "Yes" token for scoring; check that this
# checkpoint and Auto* class match what the module expects before training.
base_model = AutoModel.from_pretrained("BAAI/bge-reranker-v2-m3")
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-v2-m3")

# Create the self-distillation model
model = BiEncoderModel(
    model=base_model,
    tokenizer=tokenizer,
    train_batch_size=4,
    cutoff_layers=[4, 8, 12],
    compress_layers=[6, 12, 18, 24],
    compress_ratios=[2, 4, 8],
    train_method='distill_last_layer'  # Layer-wise self-distillation
)

# Training forward pass; `batch` is a collated batch from your data loader
outputs = model(
    pair=batch['pair'],
    query_lengths=batch['query_lengths'],
    prompt_lengths=batch['prompt_lengths'],
    teacher_scores=batch['teacher_scores']  # Optional external teacher
)

# The loss combines:
# 1. Cross-entropy loss at each layer
# 2. Self-distillation from deeper to shallower layers
# 3. Optional external teacher distillation
loss = outputs.loss

# At inference, any cutoff depth can be used, for example:
# model.config.layer_wise = False  # For single-layer inference
# model.cutoff_layers = [8]  # Use only 8 layers
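If the layer-wise forward pass returns one score list per cutoff layer, as the outputs table suggests, choosing an inference depth reduces to picking one of those lists. The sketch assumes a hypothetical dict keyed by layer number, which could be built with dict(zip(cutoff_layers, scores)) if the returned list follows the order of cutoff_layers (an assumption this page does not confirm); rank_at_layer is a hypothetical helper.

```python
from typing import Dict, List, Sequence


def rank_at_layer(layer_scores: Dict[int, Sequence[float]], layer: int) -> List[int]:
    """Return candidate indices sorted by descending score at one exit layer."""
    if layer not in layer_scores:
        raise KeyError(f"no scores for layer {layer}; available: {sorted(layer_scores)}")
    scores = layer_scores[layer]
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
```

A shallower exit layer trades some ranking quality for lower cost, so the same candidates may be ordered differently at different depths.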
