Implementation: FlagOpen FlagEmbedding Reinforced IR Retriever Dataset
| Knowledge Sources | |
|---|---|
| Domains | Natural Language Processing, Information Retrieval, Deep Learning |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Custom PyTorch dataset and data collator classes for training embedding models in the Reinforced Information Retrieval framework.
Description
This module implements specialized dataset classes that extend FlagEmbedding's abstract base classes to support advanced training features:
- Answer-Augmented Training: Includes both passages and generated answers as training signals
- Knowledge Distillation: Supports teacher scores for both passages and answers
- In-Batch Negatives: Optionally uses other batch samples as additional negatives
- Sub-Batch Processing: Splits large batches into smaller chunks to fit in GPU memory
- Same-Dataset Batching: Groups samples from the same task/dataset in each batch
The module provides four main classes:
- IREmbedderTrainDataset: Standard dataset for random sampling
- IREmbedderCollator: Tokenizes and pads samples into batches
- IREmbedderSameDatasetTrainDataset: Groups samples by dataset for stable training
- IREmbedderSameDatasetCollator: Handles same-dataset batch collation
Usage
Use these dataset classes when fine-tuning embedding models for retrieval tasks, especially when you have multiple training datasets and want to leverage answer augmentation, knowledge distillation, or same-dataset batching for improved training stability.
Code Reference
Source Location
- Repository: FlagOpen_FlagEmbedding
- File: research/Reinforced_IR/finetune/retriever/dataset.py
- Lines: 1-533
Signature
class IREmbedderTrainDataset(AbsEmbedderTrainDataset):
    def __init__(self, args: AbsEmbedderDataArguments, tokenizer: PreTrainedTokenizer):
        """Initialize dataset with data arguments and tokenizer"""

    def __getitem__(self, item) -> Tuple:
        """Return query, answer, passages, teacher_scores for one sample"""

class IREmbedderCollator(AbsEmbedderCollator):
    query_max_len: int = 32
    passage_max_len: int = 128
    sub_batch_size: int = -1

    def __call__(self, features) -> dict:
        """Tokenize and collate batch of samples"""

class IREmbedderSameDatasetTrainDataset(AbsEmbedderSameDatasetTrainDataset):
    def __init__(
        self, args, default_batch_size: int, seed: int,
        tokenizer, process_index: int = 0, num_processes: int = 1
    ):
        """Initialize dataset for same-dataset batching"""

    def __getitem__(self, _) -> Tuple:
        """Return batch of samples from same dataset"""
Import
from research.Reinforced_IR.finetune.retriever.dataset import (
    IREmbedderTrainDataset,
    IREmbedderCollator,
    IREmbedderSameDatasetTrainDataset,
    IREmbedderSameDatasetCollator
)
I/O Contract
Dataset Input (Data File Format)
| Field | Type | Required | Description |
|---|---|---|---|
| query | str | Yes | The search query text |
| pos | List[str] | Yes | List of positive passages |
| neg | List[str] | Yes | List of negative passages |
| answer | str | No | Generated answer for the query |
| neg_answer | List[str] | No | Negative answers for in-batch training |
| pos_scores | List[float] | No | Teacher scores for positives (knowledge distillation) |
| neg_scores | List[float] | No | Teacher scores for negatives |
| prompt | str | No | Task-specific instruction |
| type | str | No | Data type (for same-dataset batching) |
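For illustration, a single training record with all optional fields might look like the following (values are invented; only query, pos, and neg are required; the record is pretty-printed here but occupies one line in the .jsonl file):
{
  "query": "what is dense retrieval?",
  "pos": ["Dense retrieval encodes queries and passages as dense vectors ..."],
  "neg": ["Sparse retrieval relies on exact term matching ..."],
  "answer": "Dense retrieval matches queries and passages in embedding space.",
  "neg_answer": ["Sparse retrieval is based on lexical overlap ..."],
  "pos_scores": [0.93],
  "neg_scores": [0.41],
  "prompt": "Represent this question for retrieving relevant passages: ",
  "type": "qa"
}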
Dataset Output
| Name | Type | Description |
|---|---|---|
| query | str | Query with optional instruction prefix |
| answer | str | Answer text (or None) |
| passages | List[str] | 1 positive + (train_group_size-1) negatives |
| teacher_scores | List[float] | Optional distillation scores |
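Per the __getitem__ docstring above, a single sample therefore comes back as a tuple in this order (a sketch with invented values):
query, answer, passages, teacher_scores = train_dataset[0]
# query:          "Represent this question ...: what is dense retrieval?"
# answer:         "Dense retrieval matches queries ..."  (or None)
# passages:       [positive, neg_1, ..., neg_{train_group_size - 1}]
# teacher_scores: [0.93, 0.41, ...]  (or None if distillation is disabled)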
Collator Output
| Name | Type | Description |
|---|---|---|
| queries | BatchEncoding | Tokenized queries |
| answers | BatchEncoding | Tokenized answers (or None) |
| passages | BatchEncoding | Tokenized passages |
| teacher_scores | List[float] | Distillation scores (or None) |
| no_in_batch_neg_flag | bool | Whether to disable in-batch negatives |
Key Features
Answer Augmentation
The dataset supports training with both passages and answers:
# In __getitem__:
query = data['query']
answer = data.get('answer', None) # Optional answer field
# Answers can be used for:
# 1. Encoding queries (blend query + answer embeddings)
# 2. In-batch negative training (answer_inbatch)
# 3. Knowledge distillation with separate answer scores
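How the answer embedding is combined with the query embedding is decided by the training code, not by this dataset module. A minimal sketch of the "blend query + answer embeddings" idea, assuming a simple weighted average (purely illustrative):
import torch
import torch.nn.functional as F

def blend_query_answer(q_emb: torch.Tensor, a_emb: torch.Tensor, alpha: float = 0.5) -> torch.Tensor:
    # Illustrative only: the actual Reinforced_IR training code may combine them differently
    blended = alpha * q_emb + (1 - alpha) * a_emb
    return F.normalize(blended, dim=-1)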
Knowledge Distillation
Supports teacher scores for soft-label training:
if self.args.knowledge_distillation:
    # Extract teacher scores for the selected positive and negatives
    teacher_scores.append(data['pos_scores'][pos_idx])
    for neg_idx in neg_idxs:
        teacher_scores.append(data['neg_scores'][neg_idx])

    # Adjust the positive score relative to the hardest negative
    if pos_score < max(neg_scores):
        # Keep as-is (a negative outscores the positive: hard example)
        pass
    else:
        # Pull the positive score 20% of the way toward the hardest negative
        pos_score = pos_score + (max_neg_score - pos_score) * 0.2
Same-Dataset Batching
Groups samples from the same dataset/task in each batch:
class IREmbedderSameDatasetTrainDataset:
    def refresh_epoch(self):
        # Shuffle the order of the datasets
        self.deterministic_generator.shuffle(self.datasets_inxs)
        batch_datas = []
        for dataset_inx in self.datasets_inxs:
            # Shuffle samples within this dataset
            self.deterministic_generator.shuffle(self.each_data_inxs[dataset_inx])
            # Slice this dataset's shuffled indices into batches
            data_inxs = self.each_data_inxs[dataset_inx]
            for start_index in range(0, len(data_inxs), batch_size):
                batch_datas.append(data_inxs[start_index:start_index + batch_size])
        # Shuffle all batches together so datasets are interleaved across steps
        self.deterministic_generator.shuffle(batch_datas)
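refresh_epoch is typically invoked once per epoch so that batch composition changes between epochs; if your training loop is not driven by a trainer that already does this, a minimal sketch (num_train_epochs is an assumed variable):
for epoch in range(num_train_epochs):
    train_dataset.refresh_epoch()  # reshuffle datasets and rebuild the batch index
    for batch in train_loader:
        ...  # forward/backward as usual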
Sub-Batch Processing
Splits large batches for memory efficiency:
if self.sub_batch_size > 0:
    # Split the queries into smaller sub-batches
    q_collated = []
    for i in range(0, len(queries), self.sub_batch_size):
        sub_batch = queries[i:i + self.sub_batch_size]
        q_collated.append(self.tokenizer.pad(sub_batch, ...))
    # Returns lists of sub-batches instead of a single padded batch
    return {"queries": q_collated, ...}
Usage Examples
Basic Training Setup
from transformers import AutoTokenizer
from research.Reinforced_IR.finetune.retriever.dataset import (
    IREmbedderTrainDataset, IREmbedderCollator
)
from FlagEmbedding.abc.finetune.embedder import AbsEmbedderDataArguments

# Setup
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-base-en-v1.5')
args = AbsEmbedderDataArguments(
    train_data='path/to/train.jsonl',
    train_group_size=8,  # 1 pos + 7 negs per query
    query_max_len=512,
    passage_max_len=512,
    knowledge_distillation=True
)

# Create dataset
train_dataset = IREmbedderTrainDataset(args, tokenizer)

# Create collator
collator = IREmbedderCollator(
    tokenizer=tokenizer,
    query_max_len=512,
    passage_max_len=512,
    sub_batch_size=-1  # No sub-batching
)

# Use with DataLoader
from torch.utils.data import DataLoader

train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    collate_fn=collator,
    shuffle=True
)
Same-Dataset Batching
from research.Reinforced_IR.finetune.retriever.dataset import (
    IREmbedderSameDatasetTrainDataset,
    IREmbedderSameDatasetCollator
)

# Create same-dataset dataset
train_dataset = IREmbedderSameDatasetTrainDataset(
    args=args,
    default_batch_size=32,
    seed=42,
    tokenizer=tokenizer,
    process_index=0,  # DDP rank
    num_processes=1   # Total DDP processes
)

# Must use batch_size=1 and num_workers=0
train_loader = DataLoader(
    train_dataset,
    batch_size=1,    # Dataset returns a full batch per item
    num_workers=0,   # Avoid multiprocessing workers
    collate_fn=IREmbedderSameDatasetCollator(tokenizer, ...)
)

# Each batch contains samples from the same dataset
for batch in train_loader:
    # All queries in batch['queries'] come from the same task,
    # which enables stable in-batch negative training
    pass
Answer-Augmented Training
# Prepare data with answers
import json

data = [
    {
        "query": "What is Python?",
        "answer": "Python is a high-level programming language...",
        "pos": ["Python is a widely used programming language..."],
        "neg": ["Java is an object-oriented language...", ...],
        "neg_answer": ["Java is used for...", ...]  # For in-batch negatives
    },
    ...
]

with open('train_with_answers.jsonl', 'w') as f:
    for item in data:
        f.write(json.dumps(item) + '\n')

# Configure for answer training
args.answer_inbatch = True  # Use neg_answer for in-batch training

# The dataset will:
# 1. Include answer in query encoding
# 2. Use neg_answer as additional negatives
# 3. Support distillation with answer scores
Sub-Batch Processing for Large Batches
# For training with large batches that don't fit in memory
collator = IREmbedderCollator(
    tokenizer=tokenizer,
    query_max_len=512,
    passage_max_len=512,
    sub_batch_size=64  # Process 64 samples at a time
)

# The collator returns lists of sub-batches
batch = next(iter(train_loader))
# batch['queries'] is now a list of sub-batches
for sub_queries, sub_passages in zip(batch['queries'], batch['passages']):
    # Process each sub-batch separately
    embeddings = model(sub_queries)
    # Accumulate gradients...
Knowledge Distillation
# Prepare data with teacher scores
data_with_scores = {
    "query": "What is machine learning?",
    "pos": ["ML is a subset of AI..."],
    "neg": ["AI is intelligence shown by machines...", ...],
    "pos_scores": [0.95],  # Teacher score for positive
    "neg_scores": [0.82, 0.76, 0.71, ...]  # Teacher scores for negatives
}

# Enable distillation
args.knowledge_distillation = True

# The dataset will include teacher_scores in its output
batch = next(iter(train_loader))
teacher_scores = batch['teacher_scores']  # Used for distillation loss

# In the training loop:
# loss = contrastive_loss + kl_div(student_scores, teacher_scores)
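The actual loss lives in the Reinforced_IR trainer, not in this dataset module. As a rough sketch of the idea behind the comment above (InfoNCE plus a KL term), assuming the flat teacher score list has already been reshaped into a (batch, group_size) tensor with the positive in column 0, and with invented tensor names:
import torch
import torch.nn.functional as F

def distillation_loss(student_scores: torch.Tensor,
                      teacher_scores: torch.Tensor,
                      temperature: float = 1.0) -> torch.Tensor:
    """student_scores / teacher_scores: (batch, group_size); column 0 is the positive."""
    # Contrastive (InfoNCE) part: the positive sits at index 0 of each group
    labels = torch.zeros(student_scores.size(0), dtype=torch.long, device=student_scores.device)
    contrastive = F.cross_entropy(student_scores / temperature, labels)
    # Distillation part: KL divergence between student and teacher score distributions
    kl = F.kl_div(
        F.log_softmax(student_scores / temperature, dim=-1),
        F.softmax(teacher_scores / temperature, dim=-1),
        reduction='batchmean'
    )
    return contrastive + kl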