Implementation: FlagOpen FlagEmbedding Matryoshka Compensation Data
| Knowledge Sources | |
|---|---|
| Domains | Information Retrieval, Reranking, Data Processing |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Dataset and data collator for training Matryoshka reranker models with a compensation mechanism.
Description
This module implements a PyTorch Dataset and DataCollator for training reranker models with Matryoshka representation learning and compensation. It loads query-passage pairs from JSON files, with support for positive passages, hard negatives, and teacher scores for knowledge distillation. The dataset tokenizes queries and passages separately, then combines them with a separator and task prompt.
The dataset handles dynamic sampling of positive and negative passages per query, with configurable group sizes. It supports multiple datasets with size limits per dataset. The custom RerankCollator extends DataCollatorForSeq2Seq to handle the special structure of reranking data, returning query lengths, prompt lengths, and teacher scores for the compensation training mechanism.
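The per-query sampling described above (one positive plus train_group_size - 1 negatives) can be sketched in isolation. The helper name and the repeat-the-pool strategy for a negative shortfall are illustrative, not the exact code in data.py:

```python
import random

def sample_group(pos, neg, train_group_size, seed=0):
    """Sketch of per-query sampling: 1 positive + (train_group_size - 1) negatives.

    If there are fewer negatives than needed, the negative pool is repeated to
    cover the shortfall (a common convention; data.py may differ in detail).
    """
    rng = random.Random(seed)
    passages = [rng.choice(pos)]                  # one positive per query
    num_neg = train_group_size - 1
    if len(neg) < num_neg:
        pool = neg * ((num_neg // len(neg)) + 1)  # repeat pool to cover shortfall
        negs = rng.sample(pool, num_neg)
    else:
        negs = rng.sample(neg, num_neg)
    passages.extend(negs)
    return passages

group = sample_group(["p1"], ["n1", "n2", "n3"], train_group_size=4)
```

With a group size of 4 this yields the positive followed by three sampled negatives, matching the `(1 pos + n-1 neg)` contract of `train_group_size`.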
Usage
Use this dataset class to prepare data for training Matryoshka reranker models that learn to compensate for reduced dimensionality through knowledge distillation and layer-wise training.
Code Reference
Source Location
- Repository: FlagOpen_FlagEmbedding
- File: research/Matroyshka_reranker/finetune/compensation/data.py
- Lines: 1-231
Signature
class TrainDatasetForReranker(Dataset):
    def __init__(
        self,
        args: DataArguments,
        tokenizer: PreTrainedTokenizer
    )
    def __getitem__(self, item) -> tuple[List[BatchEncoding], List[int], List[int], List[int]]:
        """Returns passages_inputs, query_inputs_length, prompt_inputs_length, scores"""

class RerankCollator(DataCollatorForSeq2Seq):
    def __call__(self, features_lengths, return_tensors='pt'):
        """Returns collated batch with query/prompt lengths and teacher scores"""
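The collator's contract can be illustrated with a pure-Python sketch: pad each tokenized pair to the batch maximum and pass the per-pair lengths and teacher scores through unchanged. The real RerankCollator pads through the tokenizer and returns PyTorch tensors; the names and tuple layout below are illustrative:

```python
def collate(features_lengths, pad_id=0):
    """Sketch of the collator contract: pad token-id lists to a common length
    and carry query/prompt lengths and teacher scores through unchanged."""
    features, query_lens, prompt_lens, scores = zip(*features_lengths)
    max_len = max(len(f) for f in features)
    input_ids = [f + [pad_id] * (max_len - len(f)) for f in features]
    attention_mask = [[1] * len(f) + [0] * (max_len - len(f)) for f in features]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "query_lengths": list(query_lens),
        "prompt_lengths": list(prompt_lens),
        "teacher_scores": list(scores),
    }

# each tuple: (token ids, query length, prompt length, teacher score)
batch = collate([([5, 6, 7], 2, 1, 1.0), ([8, 9], 1, 1, 0.2)])
```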
Import
from torch.utils.data import Dataset
from transformers import DataCollatorForSeq2Seq, PreTrainedTokenizer
import datasets
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| train_data | List[str] | Yes | Paths to training data directories or files |
| tokenizer | PreTrainedTokenizer | Yes | Tokenizer for encoding text |
| query_max_len | int | Yes | Maximum length for query tokens |
| passage_max_len | int | Yes | Maximum length for passage tokens |
| train_group_size | int | Yes | Number of passages per query (1 positive + train_group_size - 1 negatives) |
| max_example_num_per_dataset | int | Yes | Maximum examples to use per dataset |
Outputs
| Name | Type | Description |
|---|---|---|
| pair | BatchEncoding | Collated tokenized query-passage pairs |
| query_lengths | List[int] | Lengths of query portions in each pair |
| prompt_lengths | List[int] | Lengths of task prompt portions |
| teacher_scores | List[float] | Teacher model scores for distillation |
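Downstream training code can use query_lengths and prompt_lengths to locate segments inside each collated pair, e.g. for the compensation mechanism. A minimal sketch, under the assumption (from the description above) that the query comes first and the task prompt last in each pair:

```python
def split_pair(input_ids, query_length, prompt_length):
    """Sketch: recover the query / passage / prompt segments of one collated
    pair from the returned lengths. Assumes query-first, prompt-last layout;
    verify the exact layout against data.py."""
    query = input_ids[:query_length]
    prompt = input_ids[len(input_ids) - prompt_length:]
    passage = input_ids[query_length:len(input_ids) - prompt_length]
    return query, passage, prompt

# illustrative token ids: 2 query tokens, 3 passage tokens, 1 prompt token
q, p, t = split_pair([101, 102, 5, 6, 7, 900], query_length=2, prompt_length=1)
```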
Usage Examples
from transformers import AutoTokenizer
from arguments import DataArguments
from data import TrainDatasetForReranker, RerankCollator
# Initialize tokenizer and arguments
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-v2-m3")
args = DataArguments(
train_data=["data/train"],
query_max_len=256,
passage_max_len=512,
train_group_size=8,
max_example_num_per_dataset=100000
)
# Create dataset
dataset = TrainDatasetForReranker(args, tokenizer)
# Create data collator
collator = RerankCollator(
tokenizer=tokenizer,
query_max_len=args.query_max_len,
passage_max_len=args.passage_max_len
)
# Use with DataLoader
from torch.utils.data import DataLoader
dataloader = DataLoader(
dataset,
batch_size=4,
collate_fn=collator
)
# Training data format (JSON):
# {
# "query": "What is machine learning?",
# "pos": ["Machine learning is..."],
# "pos_scores": [1.0],
# "neg": ["Unrelated passage 1", "Unrelated passage 2"],
# "neg_scores": [0.2, 0.1]
# }
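Training data of this shape is conventionally stored as JSON Lines, one record per line. A sketch that writes and re-reads the documented record (the file name is illustrative):

```python
import json
import os
import tempfile

# One training record in the documented format.
example = {
    "query": "What is machine learning?",
    "pos": ["Machine learning is..."],
    "pos_scores": [1.0],
    "neg": ["Unrelated passage 1", "Unrelated passage 2"],
    "neg_scores": [0.2, 0.1],
}

# Write one record per line (JSON Lines), then read it back.
path = os.path.join(tempfile.mkdtemp(), "train.jsonl")
with open(path, "w") as f:
    f.write(json.dumps(example) + "\n")

with open(path) as f:
    rows = [json.loads(line) for line in f]
```

pos_scores and neg_scores supply the teacher scores used for knowledge distillation, so they should align one-to-one with the pos and neg lists.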