
Implementation:FlagOpen FlagEmbedding Reinforced IR Retriever Dataset

From Leeroopedia


Knowledge Sources
Domains Natural Language Processing, Information Retrieval, Deep Learning
Last Updated 2026-02-09 00:00 GMT

Overview

Custom PyTorch dataset and data collator classes for training embedding models in the Reinforced Information Retrieval framework.

Description

This module implements specialized dataset classes that extend FlagEmbedding's abstract base classes to support advanced training features:

  • Answer-Augmented Training: Includes both passages and generated answers as training signals
  • Knowledge Distillation: Supports teacher scores for both passages and answers
  • In-Batch Negatives: Optionally uses other batch samples as additional negatives
  • Sub-Batch Processing: Splits large batches into smaller chunks to fit in GPU memory
  • Same-Dataset Batching: Groups samples from the same task/dataset in each batch

The module provides four main classes:

  • IREmbedderTrainDataset: Standard dataset for random sampling
  • IREmbedderCollator: Tokenizes and pads samples into batches
  • IREmbedderSameDatasetTrainDataset: Groups samples by dataset for stable training
  • IREmbedderSameDatasetCollator: Handles same-dataset batch collation

Usage

Use these dataset classes when fine-tuning embedding models for retrieval tasks, especially when you have multiple training datasets and want to leverage answer augmentation, knowledge distillation, or same-dataset batching for improved training stability.

Code Reference

Source Location

Signature

class IREmbedderTrainDataset(AbsEmbedderTrainDataset):
    def __init__(self, args: AbsEmbedderDataArguments, tokenizer: PreTrainedTokenizer):
        """Initialize dataset with data arguments and tokenizer"""

    def __getitem__(self, item) -> Tuple:
        """Return query, answer, passages, teacher_scores for one sample"""

class IREmbedderCollator(AbsEmbedderCollator):
    query_max_len: int = 32
    passage_max_len: int = 128
    sub_batch_size: int = -1

    def __call__(self, features) -> dict:
        """Tokenize and collate batch of samples"""

class IREmbedderSameDatasetTrainDataset(AbsEmbedderSameDatasetTrainDataset):
    def __init__(
        self, args, default_batch_size: int, seed: int,
        tokenizer, process_index: int = 0, num_processes: int = 1
    ):
        """Initialize dataset for same-dataset batching"""

    def __getitem__(self, _) -> Tuple:
        """Return batch of samples from same dataset"""

Import

from research.Reinforced_IR.finetune.retriever.dataset import (
    IREmbedderTrainDataset,
    IREmbedderCollator,
    IREmbedderSameDatasetTrainDataset,
    IREmbedderSameDatasetCollator
)

I/O Contract

Dataset Input (Data File Format)

Field Type Required Description
query str Yes The search query text
pos List[str] Yes List of positive passages
neg List[str] Yes List of negative passages
answer str No Generated answer for the query
neg_answer List[str] No Negative answers for in-batch training
pos_scores List[float] No Teacher scores for positives (knowledge distillation)
neg_scores List[float] No Teacher scores for negatives
prompt str No Task-specific instruction
type str No Data type (for same-dataset batching)
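A complete record covering every field in the table above might look like the following (all texts and scores are illustrative, not from a real training file):

```python
import json

# Illustrative training record with all optional fields populated
# (values are made up for demonstration).
record = {
    "query": "what is dense retrieval",
    "pos": ["Dense retrieval encodes queries and passages into vectors ..."],
    "neg": ["Sparse retrieval relies on exact term matching ..."],
    "answer": "Dense retrieval matches queries to passages via embeddings.",
    "neg_answer": ["Sparse retrieval uses inverted indexes."],
    "pos_scores": [0.93],
    "neg_scores": [0.41],
    "prompt": "Represent this question for retrieving relevant passages:",
    "type": "qa",
}

# One JSON object per line, as expected for a .jsonl training file
line = json.dumps(record)
```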

Dataset Output

Name Type Description
query str Query with optional instruction prefix
answer str Answer text (or None)
passages List[str] 1 positive + (train_group_size-1) negatives
teacher_scores List[float] Optional distillation scores
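The passages list above can be sketched as follows. This is an illustrative helper, not the module's actual sampling code, and the with-replacement fallback for short negative lists is an assumption:

```python
import random

def build_passage_group(pos, neg, train_group_size, rng=random):
    """Sketch: select 1 positive and (train_group_size - 1) negatives.

    If fewer negatives exist than needed, resample with replacement
    (the real dataset class may handle this case differently).
    """
    passage = rng.choice(pos)
    n_neg = train_group_size - 1
    if len(neg) >= n_neg:
        negs = rng.sample(neg, n_neg)
    else:
        negs = [rng.choice(neg) for _ in range(n_neg)]
    return [passage] + negs

group = build_passage_group(["p1"], ["n1", "n2", "n3"], train_group_size=4)
```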

Collator Output

Name Type Description
queries BatchEncoding Tokenized queries
answers BatchEncoding Tokenized answers (or None)
passages BatchEncoding Tokenized passages
teacher_scores List[float] Distillation scores (or None)
no_in_batch_neg_flag bool Whether to disable in-batch negatives

Key Features

Answer Augmentation

The dataset supports training with both passages and answers:

# In __getitem__:
query = data['query']
answer = data.get('answer', None)  # Optional answer field

# Answers can be used for:
# 1. Encoding queries (blend query + answer embeddings)
# 2. In-batch negative training (answer_inbatch)
# 3. Knowledge distillation with separate answer scores
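For use (1) above, one minimal way to blend query and answer embeddings is a convex combination; this is a hypothetical sketch, and the framework's actual blending strategy may differ:

```python
def blend_embeddings(q_emb, a_emb, alpha=0.5):
    """Hypothetical sketch: convex blend of query and answer embeddings.

    alpha=1.0 keeps only the query embedding; alpha=0.0 only the answer's.
    """
    return [alpha * q + (1 - alpha) * a for q, a in zip(q_emb, a_emb)]

blended = blend_embeddings([1.0, 0.0], [0.0, 1.0])
```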

Knowledge Distillation

Supports teacher scores for soft-label training:

if self.args.knowledge_distillation:
    # Extract scores for selected samples
    teacher_scores.append(data['pos_scores'][pos_idx])
    for neg_idx in neg_idxs:
        teacher_scores.append(data['neg_scores'][neg_idx])

    # Smooth the positive score relative to the hardest negative
    if pos_score < max(neg_scores):
        # Positive already scores below a negative: keep as-is
        # (indicates a hard example)
        pass
    else:
        # Shrink the positive-negative gap by 20% to soften the target
        pos_score = pos_score + (max_neg_score - pos_score) * 0.2
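A quick numeric check of the adjustment formula above (scores are illustrative):

```python
def adjust_pos_score(pos_score, neg_scores):
    """Mirrors the branch above: hard examples (positive below the best
    negative) are untouched; otherwise the positive moves 20% of the way
    toward the maximum negative score."""
    max_neg_score = max(neg_scores)
    if pos_score < max_neg_score:
        return pos_score
    return pos_score + (max_neg_score - pos_score) * 0.2

# pos=0.90, max neg=0.50  ->  0.90 + (0.50 - 0.90) * 0.2 = 0.82
smoothed = adjust_pos_score(0.90, [0.50, 0.30])
# pos already below max neg  ->  unchanged
hard = adjust_pos_score(0.40, [0.50])
```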

Same-Dataset Batching

Groups samples from the same dataset/task in each batch:

class IREmbedderSameDatasetTrainDataset:
    def refresh_epoch(self):
        # Shuffle dataset order
        self.deterministic_generator.shuffle(self.datasets_inxs)

        # For each dataset, shuffle samples within it
        for dataset_inx in self.datasets_inxs:
            self.deterministic_generator.shuffle(self.each_data_inxs[dataset_inx])

            # Create batches of sample indices from this dataset
            idxs = self.each_data_inxs[dataset_inx]
            for start in range(0, len(idxs), batch_size):
                batch_datas.append(idxs[start:start + batch_size])

        # Shuffle all batches together
        self.deterministic_generator.shuffle(batch_datas)
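The epoch-refresh logic above can be sketched as a self-contained function; names here are illustrative rather than the class's actual attributes:

```python
import random

def make_same_dataset_batches(datasets, batch_size, seed=42):
    """Sketch of the refresh above: shuffle dataset order, shuffle samples
    within each dataset, cut each dataset into batches, then shuffle the
    batch order. Every batch draws from exactly one dataset."""
    rng = random.Random(seed)
    batches = []
    order = list(range(len(datasets)))
    rng.shuffle(order)
    for i in order:
        idxs = list(range(len(datasets[i])))
        rng.shuffle(idxs)
        for start in range(0, len(idxs), batch_size):
            # Store (dataset index, sample index) pairs for each batch
            batches.append([(i, j) for j in idxs[start:start + batch_size]])
    rng.shuffle(batches)
    return batches

batches = make_same_dataset_batches([["a"] * 5, ["b"] * 3], batch_size=2)
```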

Sub-Batch Processing

Splits large batches for memory efficiency:

if self.sub_batch_size > 0:
    # Split into smaller sub-batches
    q_collated = []
    for i in range(0, len(queries), self.sub_batch_size):
        sub_batch = queries[i:i+self.sub_batch_size]
        q_collated.append(self.tokenizer.pad(sub_batch, ...))

    # Returns list of sub-batches instead of single batch
    return {"queries": q_collated, ...}
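The chunking itself reduces to a simple slice loop; this standalone sketch (helper name is illustrative) shows the splitting behavior, including the disabled case:

```python
def split_into_sub_batches(features, sub_batch_size):
    """Sketch of the sub-batching above: chunk a collated list so each
    chunk can be encoded separately and fit in GPU memory."""
    if sub_batch_size <= 0:
        # Sub-batching disabled: keep the full batch as one chunk
        return [features]
    return [features[i:i + sub_batch_size]
            for i in range(0, len(features), sub_batch_size)]

chunks = split_into_sub_batches(list(range(10)), 4)
```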

Usage Examples

Basic Training Setup

from transformers import AutoTokenizer
from research.Reinforced_IR.finetune.retriever.dataset import (
    IREmbedderTrainDataset, IREmbedderCollator
)
from FlagEmbedding.abc.finetune.embedder import AbsEmbedderDataArguments

# Setup
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-base-en-v1.5')
args = AbsEmbedderDataArguments(
    train_data='path/to/train.jsonl',
    train_group_size=8,  # 1 pos + 7 negs per query
    query_max_len=512,
    passage_max_len=512,
    knowledge_distillation=True
)

# Create dataset
train_dataset = IREmbedderTrainDataset(args, tokenizer)

# Create collator
collator = IREmbedderCollator(
    tokenizer=tokenizer,
    query_max_len=512,
    passage_max_len=512,
    sub_batch_size=-1  # No sub-batching
)

# Use with DataLoader
from torch.utils.data import DataLoader
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    collate_fn=collator,
    shuffle=True
)

Same-Dataset Batching

from research.Reinforced_IR.finetune.retriever.dataset import (
    IREmbedderSameDatasetTrainDataset,
    IREmbedderSameDatasetCollator
)

# Create same-dataset dataset
train_dataset = IREmbedderSameDatasetTrainDataset(
    args=args,
    default_batch_size=32,
    seed=42,
    tokenizer=tokenizer,
    process_index=0,  # DDP rank
    num_processes=1   # Total DDP processes
)

# Must use batch_size=1 and num_workers=0
train_loader = DataLoader(
    train_dataset,
    batch_size=1,  # Dataset returns full batch
    num_workers=0,  # Avoid multi-processing
    collate_fn=IREmbedderSameDatasetCollator(tokenizer, ...)
)

# Each batch contains samples from same dataset
for batch in train_loader:
    # All queries in batch.queries are from same task
    # Enables stable in-batch negative training
    pass

Answer-Augmented Training

# Prepare data with answers
import json

data = [
    {
        "query": "What is Python?",
        "answer": "Python is a high-level programming language...",
        "pos": ["Python is a widely used programming language..."],
        "neg": ["Java is an object-oriented language...", ...],
        "neg_answer": ["Java is used for...", ...]  # For in-batch negatives
    },
    ...
]

with open('train_with_answers.jsonl', 'w') as f:
    for item in data:
        f.write(json.dumps(item) + '\n')

# Configure for answer training
args.answer_inbatch = True  # Use neg_answer for in-batch training

# The dataset will:
# 1. Include answer in query encoding
# 2. Use neg_answer as additional negatives
# 3. Support distillation with answer scores

Sub-Batch Processing for Large Batches

# For training with large batches that don't fit in memory
collator = IREmbedderCollator(
    tokenizer=tokenizer,
    query_max_len=512,
    passage_max_len=512,
    sub_batch_size=64  # Process 64 samples at a time
)

# The collator returns lists of sub-batches
batch = next(iter(train_loader))

# batch['queries'] is now a list of sub-batches
for sub_queries, sub_passages in zip(batch['queries'], batch['passages']):
    # Process each sub-batch separately
    embeddings = model(sub_queries)
    # Accumulate gradients...

Knowledge Distillation

# Prepare data with teacher scores
data_with_scores = {
    "query": "What is machine learning?",
    "pos": ["ML is a subset of AI..."],
    "neg": ["AI is intelligence shown by machines...", ...],
    "pos_scores": [0.95],  # Teacher score for positive
    "neg_scores": [0.82, 0.76, 0.71, ...]  # Teacher scores for negatives
}

# Enable distillation
args.knowledge_distillation = True

# The dataset will include teacher_scores in output
batch = next(iter(train_loader))
teacher_scores = batch['teacher_scores']  # Used for distillation loss

# In training loop:
# loss = contrastive_loss + kl_div(student_scores, teacher_scores)
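The distillation term in that loss can be sketched without a deep-learning framework; this is an illustrative KL divergence over one query's passage-group scores, not the framework's actual loss implementation:

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of scores
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def kd_loss(student_scores, teacher_scores):
    """Sketch: KL(teacher || student) over one passage group's scores.
    The full training loss would add a contrastive term on top."""
    p = softmax(teacher_scores)
    q = softmax(student_scores)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

# Identical score distributions give zero divergence
zero = kd_loss([2.0, 1.0, 0.5], [2.0, 1.0, 0.5])
mismatch = kd_loss([0.5, 1.0, 2.0], [2.0, 1.0, 0.5])
```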
