
Implementation:FlagOpen FlagEmbedding Matryoshka Compensation Modeling

From Leeroopedia


Knowledge Sources
Domains: Information Retrieval, Reranking, Neural Networks
Last Updated: 2026-02-09 00:00 GMT

Overview

A bi-encoder model with Matryoshka representation learning and a compensation mechanism for reranking.

Description

This module implements a neural reranker model that supports Matryoshka representation learning with a compensation mechanism. The model uses a bi-encoder architecture where queries and passages are encoded together, with the ability to dynamically adjust compression layers, ratios, and cutoff points during training. The compensation mechanism uses a temporary model (tmp_model) to preserve and update parameters that compensate for reduced representation dimensions.

The model extracts the logit corresponding to the "Yes" token for binary relevance scoring. It supports layer-wise training with multiple cutoff layers, random selection of compression layers and ratios during training, and knowledge distillation from teacher scores. The forward pass computes a cross-entropy loss and an optional distillation loss to improve performance on compressed representations.
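The grouped cross-entropy plus distillation objective described above can be sketched as follows. This is a minimal sketch, not the exact FlagEmbedding implementation: the group layout (positive passage first), the `yes_token_id` lookup, and the KL form of the distillation term are assumptions.

```python
import torch
import torch.nn.functional as F


def yes_token_scores(last_logits: torch.Tensor, yes_token_id: int) -> torch.Tensor:
    # last_logits: (batch, vocab) LM-head logits at the final position;
    # the relevance score is the logit assigned to the "Yes" token.
    return last_logits[:, yes_token_id]


def reranker_loss(scores: torch.Tensor, train_batch_size: int,
                  teacher_scores: torch.Tensor = None) -> torch.Tensor:
    # Reshape flat scores into (queries, passages_per_query); by assumption
    # the positive passage sits at index 0 of each group.
    grouped = scores.view(train_batch_size, -1)
    target = torch.zeros(train_batch_size, dtype=torch.long)
    loss = F.cross_entropy(grouped, target)
    if teacher_scores is not None:
        # Optional distillation: match the student's score distribution
        # to the teacher's via KL divergence.
        teacher = teacher_scores.view(train_batch_size, -1)
        loss = loss + F.kl_div(
            F.log_softmax(grouped, dim=-1),
            F.softmax(teacher, dim=-1),
            reduction="batchmean",
        )
    return loss
```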

Usage

Use this model for training rerankers that can efficiently operate at different representation dimensions through Matryoshka learning, with compensation to maintain performance at lower dimensions.
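One plausible reading of the 'mean' compression method is pooling adjacent token representations by the sampled ratio, which shortens the sequence the upper layers must process. The sketch below illustrates that idea; it is an assumption for illustration, not the module's exact compression code.

```python
import torch
import torch.nn.functional as F


def compress_hidden_states(hidden: torch.Tensor, compress_ratio: int) -> torch.Tensor:
    # hidden: (batch, seq_len, dim). Mean-pool every `compress_ratio`
    # consecutive tokens, reducing seq_len by roughly that factor.
    b, s, d = hidden.shape
    pad = (-s) % compress_ratio
    if pad:
        # Zero-pad the sequence so it divides evenly; note this slightly
        # biases the mean of the final group toward zero.
        hidden = F.pad(hidden, (0, 0, 0, pad))
    hidden = hidden.view(b, -1, compress_ratio, d)
    return hidden.mean(dim=2)
```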

Code Reference

Source Location

Signature

class BiEncoderModel(nn.Module):
    def __init__(
        self,
        model: nn.Module = None,
        tmp_model: nn.Module = None,
        tokenizer: AutoTokenizer = None,
        compress_method: str = 'mean',
        train_batch_size: int = 4,
        cutoff_layers: List[int] = [2, 4],
        compress_layers: List[int] = [6],
        compress_ratios: List[int] = [2],
        train_method: str = 'distill'
    )

    def encode(self, features, query_lengths, prompt_lengths):
        """Encode input with dynamic compression"""

    def forward(self, pair, query_lengths, prompt_lengths, teacher_scores):
        """Forward pass with loss computation"""

Import

from typing import List

import torch
from torch import nn, Tensor
from transformers import AutoTokenizer

I/O Contract

Inputs

Name Type Required Description
model nn.Module Yes Base transformer model for encoding
tmp_model nn.Module No Temporary model for compensation parameters
tokenizer AutoTokenizer Yes Tokenizer for locating the "Yes" token
compress_method str No Pooling method for compression (default 'mean')
train_batch_size int Yes Batch size for reshaping grouped logits
cutoff_layers List[int] Yes Layer indices for early exit
compress_layers List[int] Yes Layers at which to apply compression
compress_ratios List[int] Yes Compression ratios to randomly sample
train_method str Yes Training method ('distill' or other)

Outputs

Name Type Description
loss Tensor Combined cross-entropy and distillation loss
scores Tensor Relevance scores for query-passage pairs

Usage Examples

from transformers import AutoModelForCausalLM, AutoTokenizer

# Initialize model and tokenizer. The base model must expose LM-head
# logits so the "Yes" token can be scored, so a decoder-style reranker
# checkpoint is expected here.
base_model = AutoModelForCausalLM.from_pretrained("BAAI/bge-reranker-v2-m3")
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-v2-m3")

# Create compensation model
model = BiEncoderModel(
    model=base_model,
    tmp_model=None,  # Will be created if needed
    tokenizer=tokenizer,
    train_batch_size=4,
    cutoff_layers=[2, 4, 6],
    compress_layers=[8, 16, 24],
    compress_ratios=[2, 4],
    train_method='distill'
)

# Enable gradient checkpointing
model.gradient_checkpointing_enable()

# Forward pass during training
outputs = model(
    pair=batch['pair'],
    query_lengths=batch['query_lengths'],
    prompt_lengths=batch['prompt_lengths'],
    teacher_scores=batch['teacher_scores']
)

loss = outputs.loss
loss.backward()

# Save model
model.save("output_dir")
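At inference time, the raw "Yes"-token scores from the Outputs table can be turned into a ranking. A minimal sketch follows; mapping logits through a sigmoid to get [0, 1] relevance probabilities is a common convention and is assumed here rather than taken from the module.

```python
import torch


def rank_passages(scores: torch.Tensor):
    # scores: (num_passages,) raw "Yes"-token logits for one query.
    probs = torch.sigmoid(scores)                  # logits -> [0, 1] relevance
    order = torch.argsort(probs, descending=True)  # best passage first
    return order.tolist(), probs.tolist()
```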
