Jump to content

Connect Leeroopedia MCP: Equip your AI agents to search best practices, build plans, verify code, diagnose failures, and look up hyperparameter defaults.

Implementation:FlagOpen FlagEmbedding LLM Embedder Retrieval Args

From Leeroopedia


Knowledge Sources
Domains Natural Language Processing, Information Retrieval, Configuration Management
Last Updated 2026-02-09 00:00 GMT

Overview

A comprehensive configuration system for retrieval tasks supporting dense retrieval, BM25, reranking, and training with extensive hyperparameter management.

Description

This module provides dataclass-based configuration classes for the llm_embedder retrieval system. Its roughly 413 lines define structured arguments organized into six main configuration classes:

  • BaseArgs: Common settings for data paths, metrics, corpus handling
  • DenseRetrievalArgs: Dense retriever configuration (encoders, pooling, FAISS indexing)
  • BM25Args: BM25 configuration (Anserini integration, k1/b parameters)
  • RankerArgs: Cross-encoder reranker settings
  • RetrievalArgs: Unified retrieval interface combining dense + BM25
  • RetrievalTrainingArgs: Training hyperparameters (learning rate, batch size, loss weights)

The configuration system supports path resolution with the 'llm-embedder:' prefix for relative paths, automatic defaults for common use cases, flexible metric selection (MRR, Recall, NDCG, MAP), multiple retrieval methods (dense, BM25, random, none), and HuggingFace Transformers TrainingArguments integration.

Usage

Use these argument classes when configuring retrieval experiments, when setting up training runs for embedding models, or when you need standardized configuration management for reproducible retrieval research.

Code Reference

Source Location

Signature

@dataclass
class BaseArgs:
    """Base arguments for data and evaluation"""
    # Root directory used to resolve 'llm-embedder:' prefixed paths.
    data_root: str = "/data/llm-embedder"
    train_data: Optional[List[str]] = None
    eval_data: Optional[str] = None
    # Optional[str] instead of bare `str`: the default is None.
    corpus: Optional[str] = None
    # ... metrics, cutoffs, save options

@dataclass
class DenseRetrievalArgs(BaseArgs):
    """Arguments for dense retrieval"""
    # HuggingFace model id or local path for the query-side encoder.
    query_encoder: str = "BAAI/bge-base-en"
    # Key/passage-side encoder; defaults to the same model as the query encoder.
    key_encoder: str = "BAAI/bge-base-en"
    query_max_length: int = 256  # max tokens for query inputs
    # field(default_factory=...) is required for list defaults in a dataclass.
    pooling_method: List[str] = field(default_factory=lambda: ["cls"])
    # ... FAISS configuration, batch sizes

@dataclass
class BM25Args(BaseArgs):
    """Arguments for BM25 retrieval"""
    # Path to a local Anserini installation used for indexing/search.
    anserini_dir: str = '/share/peitian/Apps/anserini'
    k1: float = 0.82  # BM25 term-frequency saturation parameter
    b: float = 0.68   # BM25 document-length normalization parameter
    # ... indexing/searching parameters

@dataclass
class RetrievalTrainingArgs(TrainingArguments):
    """Training configuration for retrieval models"""
    # Candidates per query: 1 positive + (train_group_size - 1) negatives.
    train_group_size: int = 8
    cos_temperature: float = 0.01  # temperature applied to cosine-similarity logits
    contrastive_weight: float = 0.2  # weight of the contrastive loss term
    distill_weight: float = 1.0      # weight of the distillation loss term
    # ... learning rate, batch size, scheduler

Import

from research.llm_embedder.src.retrieval.args import (
    BaseArgs, DenseRetrievalArgs, BM25Args, RankerArgs,
    RetrievalArgs, RetrievalTrainingArgs
)

Configuration Classes

BaseArgs

@dataclass
class BaseArgs:
    """Base arguments for data paths, evaluation metrics, and result handling."""
    # Data paths
    model_cache_dir: Optional[str] = None    # cache dir for downloaded models
    dataset_cache_dir: Optional[str] = None  # cache dir for datasets
    # Root directory used to resolve 'llm-embedder:' prefixed paths.
    data_root: str = "/data/llm-embedder"
    train_data: Optional[List[str]] = None
    eval_data: Optional[str] = None
    # Optional[str] instead of bare `str`: the default is None.
    corpus: Optional[str] = None
    key_template: str = "{title} {text}"  # How to format corpus entries

    # Evaluation metrics
    # Mutable defaults must go through field(default_factory=...): a bare list
    # literal makes @dataclass raise "ValueError: mutable default" at class
    # creation, and would otherwise be shared across instances.
    metrics: List[str] = field(default_factory=lambda: ["mrr", "recall", "ndcg"])
    cutoffs: List[int] = field(default_factory=lambda: [1, 5, 10, 100])

    # Negative mining
    filter_answers: bool = False  # presumably filters negatives containing the answer — verify
    max_neg_num: int = 100        # cap on negatives kept per query

    # Save/load options
    load_result: bool = False
    save_result: bool = True
    save_name: Optional[str] = None
    save_to_output: bool = False

DenseRetrievalArgs

@dataclass
class DenseRetrievalArgs(BaseArgs):
    """Arguments for dense retrieval: encoders, tokenization, pooling, FAISS search."""
    # Model configuration
    query_encoder: str = "BAAI/bge-base-en"  # model id/path for the query encoder
    key_encoder: str = "BAAI/bge-base-en"    # model id/path for the key (passage) encoder
    tie_encoders: bool = True  # Share weights?

    # Tokenization
    query_max_length: int = 256
    key_max_length: int = 256
    truncation_side: str = "right"  # side the tokenizer truncates from

    # Pooling
    # A bare list literal here makes @dataclass raise
    # "ValueError: mutable default <class 'list'> for field pooling_method";
    # list defaults must use field(default_factory=...).
    pooling_method: List[str] = field(default_factory=lambda: ["cls"])  # cls, mean, dense, decoder

    # Instructions
    add_instruction: bool = True  # presumably prepends task instructions to inputs — verify
    version: str = "bge"

    # Dense search
    dense_metric: str = "cos"  # cos, ip, l2
    faiss_index_factory: str = "Flat"  # FAISS index_factory string
    hits: int = 200            # candidates returned per query
    batch_size: int = 1000

    # Caching
    load_encode: bool = False
    save_encode: bool = False
    load_index: bool = False
    save_index: bool = False
    embedding_name: str = "embeddings"

    # Compute
    dtype: str = "fp16"
    cpu: bool = False  # force CPU execution

BM25Args

@dataclass
class BM25Args(BaseArgs):
    """Arguments for BM25 retrieval via Anserini."""
    # Anserini setup
    # Path to a local Anserini installation used for indexing and searching.
    anserini_dir: str = '/share/peitian/Apps/anserini'

    # BM25 hyperparameters
    k1: float = 0.82  # term-frequency saturation
    b: float = 0.68   # document-length normalization

    # Indexing options
    storeDocvectors: bool = False  # NOTE(review): camelCase presumably mirrors Anserini's flag name
    language: str = "en"
    threads: int = 32  # parallelism for indexing

    # Retrieval
    hits: int = 200  # candidates returned per query

    # Caching
    load_index: bool = False       # presumably reuse an existing index — verify
    load_collection: bool = False  # presumably reuse a converted collection — verify

RetrievalTrainingArgs

@dataclass
class RetrievalTrainingArgs(TrainingArguments):
    """Training hyperparameters for retrieval models (extends HF TrainingArguments)."""
    # Output
    output_dir: str = 'data/outputs/'
    eval_method: str = "retrieval"

    # Training strategy
    use_train_config: bool = False
    inbatch_same_dataset: Optional[str] = None  # epoch, random
    negative_cross_device: bool = True  # presumably shares in-batch negatives across devices — verify

    # Loss configuration
    cos_temperature: float = 0.01      # temperature on cosine-similarity logits
    teacher_temperature: float = 1.0   # presumably softens teacher scores for distillation — verify
    student_temperature: float = 1.0   # presumably softens student scores for distillation — verify
    contrastive_weight: float = 0.2    # weight of the contrastive loss term
    distill_weight: float = 1.0        # weight of the distillation loss term
    stable_distill: bool = False

    # Data sampling
    max_sample_num: Optional[int] = None
    train_group_size: int = 8  # 1 pos + 7 negs
    select_positive: str = "first"
    select_negative: str = "random"
    teacher_scores_margin: Optional[float] = None
    teacher_scores_min: Optional[float] = None

    # Optimization
    per_device_train_batch_size: int = 16
    learning_rate: float = 5e-6
    warmup_ratio: float = 0.1
    weight_decay: float = 0.01

    # Mixed precision
    fp16: bool = True

    # Distributed training
    ddp_find_unused_parameters: bool = False
    remove_unused_columns: bool = False  # keep non-model columns for the data collator

    # Evaluation
    evaluation_strategy: str = 'steps'
    save_steps: int = 2000
    logging_steps: int = 100
    early_exit_steps: Optional[int] = None  # presumably stops training early at this step — verify

    # Logging
    report_to: str = "none"
    log_path: str = "data/results/performance.log"

Usage Examples

Parse Arguments from Command Line

from transformers import HfArgumentParser
from research.llm_embedder.src.retrieval.args import DenseRetrievalArgs

# HfArgumentParser turns each dataclass field into a command-line flag.
parser = HfArgumentParser([DenseRetrievalArgs])
# Returns one instance per dataclass passed above; unpack the 1-tuple.
args, = parser.parse_args_into_dataclasses()

print(f"Query encoder: {args.query_encoder}")
print(f"Batch size: {args.batch_size}")
print(f"Metrics: {args.metrics}")

Configure Dense Retrieval

from research.llm_embedder.src.retrieval.args import DenseRetrievalArgs

# Build a dense-retrieval configuration programmatically (no CLI parsing).
args = DenseRetrievalArgs(
    # Data
    data_root="/data/llm-embedder",
    # 'llm-embedder:' prefixed paths are resolved against data_root in __post_init__.
    eval_data="llm-embedder:beir/nfcorpus/test.json",
    corpus="llm-embedder:beir/nfcorpus/corpus.json",

    # Model
    query_encoder="BAAI/llm-embedder",
    key_encoder="BAAI/llm-embedder",
    tie_encoders=True,

    # Search
    dense_metric="cos",
    hits=100,
    batch_size=256,

    # Pooling
    pooling_method=["decoder"],  # Use decoder final token

    # Evaluation
    metrics=["recall", "ndcg"],
    cutoffs=[1, 5, 10, 20, 100]
)

# Path resolution happens automatically in __post_init__
print(args.eval_data)  # /data/llm-embedder/beir/nfcorpus/test.json

Configure Training

from research.llm_embedder.src.retrieval.args import RetrievalTrainingArgs

# Any HF TrainingArguments keyword (run_name, num_train_epochs, eval_steps,
# gradient_checkpointing, ...) is accepted alongside the retrieval-specific fields.
training_args = RetrievalTrainingArgs(
    # Output
    output_dir="outputs/llm-embedder-training",
    run_name="llama2-7b-retriever",

    # Data
    per_device_train_batch_size=8,
    train_group_size=8,  # 1 pos + 7 negs

    # Optimization
    learning_rate=2e-5,
    warmup_ratio=0.1,
    weight_decay=0.01,
    num_train_epochs=3,

    # Loss
    contrastive_weight=0.2,
    distill_weight=0.8,
    cos_temperature=0.02,

    # Negative sampling
    inbatch_same_dataset="epoch",
    negative_cross_device=True,

    # Evaluation
    evaluation_strategy="steps",
    eval_steps=500,
    save_steps=500,

    # Hardware
    fp16=True,
    gradient_checkpointing=True,

    # Logging
    logging_steps=10,
    report_to="wandb"
)

Configure BM25 + Reranking

from research.llm_embedder.src.retrieval.args import BM25Args, RankerArgs

# First-stage BM25
bm25_args = BM25Args(
    eval_data="data/test.json",
    corpus="data/corpus.json",
    anserini_dir="/opt/anserini",
    k1=0.9,  # stronger term-frequency saturation than the 0.82 default
    b=0.4,   # weaker length normalization than the 0.68 default
    hits=1000,  # Retrieve many candidates
    threads=16
)

# Second-stage reranking
ranker_args = RankerArgs(
    ranker="BAAI/bge-reranker-large",
    ranker_method="cross-encoder",
    batch_size=32,
    hits=100,  # Keep top-100 after reranking
    query_max_length=512,
    key_max_length=512
)

Path Resolution

from research.llm_embedder.src.retrieval.args import BaseArgs

# Paths with the 'llm-embedder:' prefix are resolved relative to data_root.
args = BaseArgs(
    data_root="/mnt/data",
    train_data=["llm-embedder:train/file1.json", "llm-embedder:train/file2.json"],
    eval_data="llm-embedder:test/data.json",
    corpus="llm-embedder:corpus/passages.json"
)

# Paths are resolved in __post_init__:
# train_data: ["/mnt/data/train/file1.json", "/mnt/data/train/file2.json"]
# eval_data: "/mnt/data/test/data.json"
# corpus: "/mnt/data/corpus/passages.json"

# Can also use absolute paths directly
args2 = BaseArgs(
    train_data=["/absolute/path/train.json"],
    eval_data="/absolute/path/test.json"
)
# Absolute paths are left unchanged

Related Pages

Page Connections

Double-click a node to navigate. Hold to expand connections.
Principle
Implementation
Heuristic
Environment