
Implementation:FlagOpen FlagEmbedding Matryoshka Compensation Data

From Leeroopedia


Knowledge Sources
Domains Information Retrieval, Reranking, Data Processing
Last Updated 2026-02-09 00:00 GMT

Overview

Dataset and data collator for training Matryoshka reranker models with a compensation mechanism.

Description

This module implements a PyTorch Dataset and DataCollator for training reranker models with Matryoshka representation learning and compensation. It loads query-passage pairs from JSON files, with support for positive passages, hard negatives, and teacher scores for knowledge distillation. The dataset tokenizes queries and passages separately, then combines them with a separator and task prompt.

The dataset dynamically samples positive and negative passages per query, with a configurable group size, and supports multiple datasets, each capped at a configurable number of examples. The custom RerankCollator extends DataCollatorForSeq2Seq to handle the special structure of reranking data, returning query lengths, prompt lengths, and teacher scores for the compensation training mechanism.
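The description above can be illustrated with a minimal sketch of how one query-passage pair might be assembled. This is not the library's actual code: the separator, prompt text, and word-level truncation stand in for the real tokenizer-level logic, and are assumptions for illustration only.

```python
# Hedged sketch of pair assembly: query and passage are truncated
# independently, then joined with a separator and a task prompt.
SEP = "\n"
TASK_PROMPT = "Predict whether the passage answers the query."  # hypothetical prompt text

def build_pair(query: str, passage: str,
               query_max_len: int = 256, passage_max_len: int = 512) -> str:
    # Word-level truncation as a stand-in for token-level truncation
    q = " ".join(query.split()[:query_max_len])
    p = " ".join(passage.split()[:passage_max_len])
    # Combine: query + separator + passage + separator + task prompt
    return f"{q}{SEP}{p}{SEP}{TASK_PROMPT}"

pair = build_pair("What is machine learning?", "Machine learning is a field of AI.")
print(pair.count(SEP))  # 2 separators: query|passage and passage|prompt
```

In the real dataset, the same three-part structure is what makes the returned query and prompt lengths meaningful: the model can locate each segment inside the concatenated sequence.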

Usage

Use this dataset class to prepare data for training Matryoshka reranker models that learn to compensate for reduced dimensionality through knowledge distillation and layer-wise training.
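The distillation signal mentioned above can be sketched in a few lines. This is an illustrative form of the loss, not the library's actual implementation: teacher scores for a query's passage group are softmax-normalized into a target distribution, and the student reranker's scores are pushed toward it via KL divergence. Function names here are assumptions.

```python
import math

def softmax(xs):
    m = max(xs)                        # subtract max for numerical stability
    es = [math.exp(x - m) for x in xs]
    z = sum(es)
    return [e / z for e in es]

def distill_kl(student_scores, teacher_scores):
    # KL(teacher || student) over one group of 1 positive + negatives
    p = softmax(teacher_scores)
    q = softmax(student_scores)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

# Identical score vectors give zero divergence
print(round(distill_kl([2.0, 0.5], [2.0, 0.5]), 6))
```

The divergence is zero when the student already ranks the group exactly like the teacher, and grows as the two score distributions diverge.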

Code Reference

Source Location

Signature

class TrainDatasetForReranker(Dataset):
    def __init__(
        self,
        args: DataArguments,
        tokenizer: PreTrainedTokenizer
    )

    def __getitem__(self, item) -> tuple[List[BatchEncoding], List[int], List[int], List[int]]:
        """Returns passages_inputs, query_inputs_length, prompt_inputs_length, scores"""

class RerankCollator(DataCollatorForSeq2Seq):
    def __call__(self, features_lengths, return_tensors='pt'):
        """Returns collated batch with query/prompt lengths and teacher scores"""

Import

from torch.utils.data import Dataset
from transformers import DataCollatorForSeq2Seq, PreTrainedTokenizer
import datasets

I/O Contract

Inputs

Name Type Required Description
train_data List[str] Yes Paths to training data directories or files
tokenizer PreTrainedTokenizer Yes Tokenizer for encoding text
query_max_len int Yes Maximum length for query tokens
passage_max_len int Yes Maximum length for passage tokens
train_group_size int Yes Number of passages per query (1 pos + n-1 neg)
max_example_num_per_dataset int Yes Maximum examples to use per dataset
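The train_group_size input above (1 positive + n-1 negatives per query) can be sketched as a sampling step. The exact sampling policy, including how a shortage of negatives is handled, is an assumption here, not the library's documented behavior; the sketch assumes at least one positive and one negative exist.

```python
import random

def sample_group(pos: list, neg: list, train_group_size: int, rng: random.Random):
    # One positive plus (train_group_size - 1) negatives per query
    chosen_pos = rng.choice(pos)
    need = train_group_size - 1
    if len(neg) < need:
        # Not enough distinct negatives: cycle the pool so sampling can repeat
        pool = neg * ((need // len(neg)) + 1)
        chosen_neg = rng.sample(pool, need)
    else:
        chosen_neg = rng.sample(neg, need)
    return [chosen_pos] + chosen_neg

rng = random.Random(0)
group = sample_group(["p1"], ["n1", "n2"], 4, rng)
print(len(group))  # equals train_group_size
```

With train_group_size=8 as in the usage example below, each query contributes one positive and seven negatives per training step.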

Outputs

Name Type Description
pair BatchEncoding Collated tokenized query-passage pairs
query_lengths List[int] Lengths of query portions in each pair
prompt_lengths List[int] Lengths of task prompt portions
teacher_scores List[float] Teacher model scores for distillation

Usage Examples

from transformers import AutoTokenizer
from arguments import DataArguments

# Initialize tokenizer and arguments
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-v2-m3")
args = DataArguments(
    train_data=["data/train"],
    query_max_len=256,
    passage_max_len=512,
    train_group_size=8,
    max_example_num_per_dataset=100000
)

# Create dataset
dataset = TrainDatasetForReranker(args, tokenizer)

# Create data collator
collator = RerankCollator(
    tokenizer=tokenizer,
    query_max_len=args.query_max_len,
    passage_max_len=args.passage_max_len
)

# Use with DataLoader
from torch.utils.data import DataLoader
dataloader = DataLoader(
    dataset,
    batch_size=4,
    collate_fn=collator
)

# Training data format (JSON):
# {
#     "query": "What is machine learning?",
#     "pos": ["Machine learning is..."],
#     "pos_scores": [1.0],
#     "neg": ["Unrelated passage 1", "Unrelated passage 2"],
#     "neg_scores": [0.2, 0.1]
# }
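The training-data format above can be generated programmatically. The sketch below writes one example per line (JSON-lines); the field names follow the example above, but the one-object-per-line file layout is an assumption about what the loader expects.

```python
import json
import os
import tempfile

# One training example in the field layout shown above; the
# pos_scores/neg_scores fields carry the teacher scores used for distillation.
example = {
    "query": "What is machine learning?",
    "pos": ["Machine learning is a field of study..."],
    "pos_scores": [1.0],
    "neg": ["Unrelated passage 1", "Unrelated passage 2"],
    "neg_scores": [0.2, 0.1],
}

path = os.path.join(tempfile.mkdtemp(), "train.jsonl")
with open(path, "w") as f:
    f.write(json.dumps(example) + "\n")

# Round-trip check: read the line back and parse it
with open(path) as f:
    loaded = json.loads(f.readline())
print(len(loaded["neg"]))  # 2
```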
