Implementation:FlagOpen FlagEmbedding LLM Dense Retriever Data

From Leeroopedia


Knowledge Sources
Domains Natural Language Processing, Information Retrieval, Large Language Models
Last Updated 2026-02-09 00:00 GMT

Overview

A dataset implementation for training LLM-based dense retrievers, with special-token formatting and in-context learning support.

Description

This module implements a specialized dataset class for training large language models as dense retrievers. It extends standard approaches with LLM-specific features:

- special-token formatting (`<instruct>`, `<query>`, `<response>`)
- in-context learning with 0-6 few-shot examples
- symmetric-task support for STS and clustering
- dynamic batch construction in which samples from the same dataset are grouped together
- task-specific prompt templates for retrieval, classification, and clustering

The dataset loads from disk or HuggingFace Hub, shuffles across epochs while maintaining task boundaries, and constructs batches optimized for stable training with in-batch negatives.
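The special-token layout described above can be sketched as follows. The exact template lives in the source module's `get_query_prompt`; the `format_query` helper below is an illustrative stand-in, not the real implementation.

```python
def format_query(query: str, instruction: str, use_special_tokens: bool) -> str:
    """Illustrative sketch of instruction-plus-query formatting.

    When special tokens are enabled, the task instruction and the query are
    wrapped in <instruct>/<query> markers; the embedding is taken from the
    model's output after the query (conceptually, at <response>).
    """
    if use_special_tokens:
        return f"<instruct>{instruction}\n<query>{query}"
    return f"{instruction}\n{query}"

formatted = format_query(
    "what is a dense retriever?",
    "Given a web search query, retrieve relevant passages.",
    use_special_tokens=True,
)
```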

Usage

Use this dataset when fine-tuning large language models (LLaMA, Mistral, etc.) as dense retrievers, especially when you want to leverage instruction-following capabilities and in-context learning for improved retrieval performance.
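Same-dataset batching, where every batch is drawn from a single source dataset so that in-batch negatives share a task, can be sketched in plain Python. This is a simplification of the actual sampler (which also handles multi-process sharding and per-epoch reshuffling); the function name and structure are hypothetical.

```python
import random

def same_dataset_batches(datasets: dict, batch_size: int, seed: int):
    """Yield (dataset_name, batch) pairs where each batch is homogeneous.

    datasets maps a dataset name to its list of samples. Each dataset is
    shuffled and chunked independently, then the resulting batches are
    shuffled across datasets, preserving the same-source guarantee.
    """
    rng = random.Random(seed)
    batches = []
    for name, samples in datasets.items():
        samples = samples[:]
        rng.shuffle(samples)
        for i in range(0, len(samples), batch_size):
            chunk = samples[i:i + batch_size]
            if len(chunk) == batch_size:      # drop ragged tail batches
                batches.append((name, chunk))
    rng.shuffle(batches)                      # mix tasks across steps
    return batches

batches = same_dataset_batches(
    {"msmarco": list(range(10)), "nli": list(range(100, 107))},
    batch_size=4, seed=42)
# every batch comes from exactly one source dataset
assert all(all(x < 100 for x in b) or all(x >= 100 for x in b)
           for _, b in batches)
```

Keeping batches homogeneous matters because an in-batch negative from a different task (e.g. an STS pair inside a retrieval batch) would be an uninformative or misleading contrastive signal.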

Code Reference

Source Location

Signature

class SameDatasetTrainDataset(Dataset):
    def __init__(
        self, args: DataArguments, batch_size: int, seed: int,
        tokenizer: PreTrainedTokenizer, process_index: int = 0, num_processes: int = 1
    ):
        """Dataset with same-dataset batching"""

def get_query_prompt(query: str, prompt: str, use_special_tokens: bool) -> str:
    """Format query with instruction prompt"""

class SameEmbedCollator(DataCollatorForSeq2Seq):
    def __call__(self, features, return_tensors='pt') -> dict:
        """Collate batch with tokenization"""

Import

from research.llm_dense_retriever.finetune.data import (
    SameDatasetTrainDataset, SameEmbedCollator, get_query_prompt
)

I/O Contract

Inputs

Name        Type         Required  Description
query       str          Yes       Query text
prompt      str          Yes       Task instruction
pos         List[str]    Yes       Positive passages
neg         List[str]    Yes       Negative passages
type        str          Yes       Task type (retrieval/sts/clustering/class)
pos_scores  List[float]  No        Teacher scores for positives
neg_scores  List[float]  No        Teacher scores for negatives
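A training record matching this contract might look like the following; the field values are illustrative, and in `train.jsonl` each record is one JSON object per line.

```python
import json

record = {
    "query": "how do dense retrievers work?",
    "prompt": "Given a question, retrieve passages that answer it.",
    "pos": ["Dense retrievers encode queries and passages into vectors ..."],
    "neg": ["Sparse retrieval relies on exact term matching ..."],
    "type": "retrieval",
    "pos_scores": [0.92],   # optional teacher scores for distillation
    "neg_scores": [0.13],
}

# the five required fields from the input contract
required = {"query", "prompt", "pos", "neg", "type"}
assert required <= record.keys()

line = json.dumps(record)   # one such line per sample in train.jsonl
```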

Outputs

Name            Type         Description
query           Tensor       Padded query tensors
passage         Tensor       Padded passage tensors
messages        List[str]    Batch metadata
teacher_scores  List[float]  Distillation scores (optional)
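The padding the collator performs can be illustrated without torch: token-ID sequences in a batch are right-padded to the longest sequence, with an attention mask marking real tokens. This is a simplified stand-in for what `DataCollatorForSeq2Seq` does with tensors, not its actual code.

```python
def pad_batch(sequences, pad_id=0):
    """Right-pad token-ID lists to the batch maximum length.

    Returns (input_ids, attention_mask), where the mask is 1 for real
    tokens and 0 for padding, mirroring the real collator's tensor output.
    """
    max_len = max(len(s) for s in sequences)
    input_ids = [s + [pad_id] * (max_len - len(s)) for s in sequences]
    attention_mask = [[1] * len(s) + [0] * (max_len - len(s))
                      for s in sequences]
    return input_ids, attention_mask

ids, mask = pad_batch([[5, 6, 7], [8, 9]], pad_id=0)
# ids  -> [[5, 6, 7], [8, 9, 0]]
# mask -> [[1, 1, 1], [1, 1, 0]]
```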

Usage Examples

Basic Setup

from transformers import AutoTokenizer
from research.llm_dense_retriever.finetune.data import SameDatasetTrainDataset, SameEmbedCollator
from arguments import DataArguments

tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
args = DataArguments(
    train_data='path/to/train.jsonl',
    query_max_len=512,
    passage_max_len=512,
    train_group_size=8,
    use_special_tokens=True
)

dataset = SameDatasetTrainDataset(args, batch_size=32, seed=42, tokenizer=tokenizer)
collator = SameEmbedCollator(tokenizer, query_max_len=512, passage_max_len=512)

from torch.utils.data import DataLoader
# The dataset yields pre-built batches, so the DataLoader batch size stays 1.
loader = DataLoader(dataset, batch_size=1, collate_fn=collator)
