Implementation: FlagOpen FlagEmbedding LLM Dense Retriever Data
| Knowledge Sources | |
|---|---|
| Domains | Natural Language Processing, Information Retrieval, Large Language Models |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
A dataset implementation for training LLM-based dense retrievers, with special-token formatting and in-context learning support.
Description
This module implements a specialized dataset class for training large language models as dense retrievers. It extends standard approaches with LLM-specific features:
- special token formatting (`<instruct>`, `<query>`, `<response>`)
- in-context learning with 0-6 few-shot examples
- symmetric task support for STS and clustering
- dynamic batch construction that groups samples from the same dataset together
- task-specific prompt templates for retrieval, classification, and clustering
The dataset loads from disk or HuggingFace Hub, shuffles across epochs while maintaining task boundaries, and constructs batches optimized for stable training with in-batch negatives.
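To make the token scheme concrete, the sketch below shows how a query with few-shot examples might be rendered. This is an illustration only, assuming a template shape consistent with the description above; the actual template (token order, separators) is defined in data.py, and the helper name `render_query` is hypothetical.

# Illustrative sketch only: a hypothetical renderer matching the described scheme.
# The real template string lives in data.py; separators here are assumptions.
def render_query(instruction, query, examples):
    parts = []
    for ex_query, ex_response in examples:  # 0-6 few-shot examples
        parts.append(f"<instruct>{instruction}\n<query>{ex_query}\n<response>{ex_response}")
    parts.append(f"<instruct>{instruction}\n<query>{query}")
    return "\n".join(parts)

print(render_query(
    "Given a web search query, retrieve relevant passages.",
    "what is dense retrieval?",
    [("what is BM25?", "BM25 is a lexical ranking function over term statistics.")],
))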
Usage
Use this dataset when fine-tuning large language models (LLaMA, Mistral, etc.) as dense retrievers, especially when you want to leverage instruction-following capabilities and in-context learning for improved retrieval performance.
Code Reference
Source Location
- Repository: FlagOpen_FlagEmbedding
- File: research/llm_dense_retriever/finetune/data.py
- Lines: 1-421
Signature
class SameDatasetTrainDataset(Dataset):
    def __init__(
        self, args: DataArguments, batch_size: int, seed: int,
        tokenizer: PreTrainedTokenizer, process_index: int = 0, num_processes: int = 1
    ):
        """Dataset with same-dataset batching"""

def get_query_prompt(query: str, prompt: str, use_special_tokens: bool) -> str:
    """Format query with instruction prompt"""

class SameEmbedCollator(DataCollatorForSeq2Seq):
    def __call__(self, features, return_tensors='pt') -> dict:
        """Collate batch with tokenization"""
Import
from research.llm_dense_retriever.finetune.data import (
SameDatasetTrainDataset, SameEmbedCollator, get_query_prompt
)
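The module-level `get_query_prompt` can be called on its own to inspect a formatted training input; the exact output string depends on the template in data.py:

# Format one query with its task instruction; with use_special_tokens=True
# the result carries the special-token markers described above.
formatted = get_query_prompt(
    query="what is dense retrieval?",
    prompt="Given a web search query, retrieve relevant passages.",
    use_special_tokens=True,
)
print(formatted)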
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| query | str | Yes | Query text |
| prompt | str | Yes | Task instruction |
| pos | List[str] | Yes | Positive passages |
| neg | List[str] | Yes | Negative passages |
| type | str | Yes | Task type (retrieval/sts/clustering/class) |
| pos_scores | List[float] | No | Teacher scores for positives |
| neg_scores | List[float] | No | Teacher scores for negatives |
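Together these fields describe one JSON object per line of the training file. A hypothetical record, with illustrative values:

# One line of train.jsonl, shown as a Python dict; all values are illustrative.
record = {
    "query": "what is dense retrieval?",
    "prompt": "Given a web search query, retrieve relevant passages.",
    "pos": ["Dense retrieval encodes queries and passages as vectors ..."],
    "neg": ["BM25 is a sparse lexical ranking function ..."],
    "type": "retrieval",
    "pos_scores": [0.97],  # optional teacher scores for distillation
    "neg_scores": [0.12],  # optional teacher scores for distillation
}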
Outputs
| Name | Type | Description |
|---|---|---|
| query | Tensor | Padded query tensors |
| passage | Tensor | Padded passage tensors |
| messages | List[str] | Batch metadata |
| teacher_scores | List[float] | Distillation scores (optional) |
Usage Examples
Basic Setup
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from research.llm_dense_retriever.finetune.data import SameDatasetTrainDataset, SameEmbedCollator
from arguments import DataArguments  # resolves when run from research/llm_dense_retriever/finetune/

tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')

args = DataArguments(
    train_data='path/to/train.jsonl',
    query_max_len=512,
    passage_max_len=512,
    train_group_size=8,       # one positive plus (train_group_size - 1) negatives per query
    use_special_tokens=True,  # emit the <instruct>/<query>/<response> markers
)

# The dataset builds complete batches internally (batch_size=32 here),
# so the DataLoader wraps exactly one pre-built batch per step.
dataset = SameDatasetTrainDataset(args, batch_size=32, seed=42, tokenizer=tokenizer)
collator = SameEmbedCollator(tokenizer, query_max_len=512, passage_max_len=512)
loader = DataLoader(dataset, batch_size=1, collate_fn=collator)
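Because each dataset item is already a complete batch, the loop below receives the collated dictionary described in the Outputs table. The exact tensor nesting (plain tensors vs. tokenizer BatchEncoding objects) follows the collator in data.py, so this sketch only inspects the top-level keys:

# Inspect one collated batch; top-level keys follow the Outputs table above.
for batch in loader:
    print(sorted(batch.keys()))  # expected: 'passage', 'query', plus 'messages'/'teacher_scores' when present
    break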