Implementation:FlagOpen FlagEmbedding LLM Embedder Retrieval Args
| Knowledge Sources | |
|---|---|
| Domains | Natural Language Processing, Information Retrieval, Configuration Management |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
A comprehensive configuration system for retrieval tasks supporting dense retrieval, BM25, reranking, and training with extensive hyperparameter management.
Description
This module provides dataclass-based configuration classes for the llm_embedder retrieval system. It defines 413 lines of structured arguments organized into six main configuration classes:
- BaseArgs: Common settings for data paths, metrics, corpus handling
- DenseRetrievalArgs: Dense retriever configuration (encoders, pooling, FAISS indexing)
- BM25Args: BM25 configuration (Anserini integration, k1/b parameters)
- RankerArgs: Cross-encoder reranker settings
- RetrievalArgs: Unified retrieval interface combining dense + BM25
- RetrievalTrainingArgs: Training hyperparameters (learning rate, batch size, loss weights)
The configuration system supports path resolution with the 'llm-embedder:' prefix for relative paths, automatic defaults for common use cases, flexible metric selection (MRR, Recall, NDCG, MAP), multiple retrieval methods (dense, BM25, random, none), and HuggingFace Transformers TrainingArguments integration.
Usage
Use these argument classes when configuring retrieval experiments, when setting up training runs for embedding models, or when you need standardized configuration management for reproducible retrieval research.
Code Reference
Source Location
- Repository: FlagOpen_FlagEmbedding
- File: research/llm_embedder/src/retrieval/args.py
- Lines: 1-413
Signature
@dataclass
class BaseArgs:
"""Base arguments for data and evaluation"""
data_root: str = "/data/llm-embedder"
train_data: Optional[List[str]] = None
eval_data: Optional[str] = None
corpus: Optional[str] = None
# ... metrics, cutoffs, save options
@dataclass
class DenseRetrievalArgs(BaseArgs):
"""Arguments for dense retrieval"""
query_encoder: str = "BAAI/bge-base-en"
key_encoder: str = "BAAI/bge-base-en"
query_max_length: int = 256
pooling_method: List[str] = field(default_factory=lambda: ["cls"])
# ... FAISS configuration, batch sizes
@dataclass
class BM25Args(BaseArgs):
"""Arguments for BM25 retrieval"""
anserini_dir: str = '/share/peitian/Apps/anserini'
k1: float = 0.82
b: float = 0.68
# ... indexing/searching parameters
@dataclass
class RetrievalTrainingArgs(TrainingArguments):
"""Training configuration for retrieval models"""
train_group_size: int = 8
cos_temperature: float = 0.01
contrastive_weight: float = 0.2
distill_weight: float = 1.0
# ... learning rate, batch size, scheduler
Import
from research.llm_embedder.src.retrieval.args import (
BaseArgs, DenseRetrievalArgs, BM25Args, RankerArgs,
RetrievalArgs, RetrievalTrainingArgs
)
Configuration Classes
BaseArgs
@dataclass
class BaseArgs:
# Data paths
model_cache_dir: Optional[str] = None
dataset_cache_dir: Optional[str] = None
data_root: str = "/data/llm-embedder"
train_data: Optional[List[str]] = None
eval_data: Optional[str] = None
corpus: Optional[str] = None
key_template: str = "{title} {text}" # How to format corpus entries
# Evaluation metrics
metrics: List[str] = field(default_factory=lambda: ["mrr", "recall", "ndcg"])
cutoffs: List[int] = field(default_factory=lambda: [1, 5, 10, 100])
# Negative mining
filter_answers: bool = False
max_neg_num: int = 100
# Save/load options
load_result: bool = False
save_result: bool = True
save_name: Optional[str] = None
save_to_output: bool = False
DenseRetrievalArgs
@dataclass
class DenseRetrievalArgs(BaseArgs):
# Model configuration
query_encoder: str = "BAAI/bge-base-en"
key_encoder: str = "BAAI/bge-base-en"
tie_encoders: bool = True # Share weights?
# Tokenization
query_max_length: int = 256
key_max_length: int = 256
truncation_side: str = "right"
# Pooling
pooling_method: List[str] = field(default_factory=lambda: ["cls"]) # cls, mean, dense, decoder
# Instructions
add_instruction: bool = True
version: str = "bge"
# Dense search
dense_metric: str = "cos" # cos, ip, l2
faiss_index_factory: str = "Flat"
hits: int = 200
batch_size: int = 1000
# Caching
load_encode: bool = False
save_encode: bool = False
load_index: bool = False
save_index: bool = False
embedding_name: str = "embeddings"
# Compute
dtype: str = "fp16"
cpu: bool = False
BM25Args
@dataclass
class BM25Args(BaseArgs):
# Anserini setup
anserini_dir: str = '/share/peitian/Apps/anserini'
# BM25 hyperparameters
k1: float = 0.82
b: float = 0.68
# Indexing options
storeDocvectors: bool = False
language: str = "en"
threads: int = 32
# Retrieval
hits: int = 200
# Caching
load_index: bool = False
load_collection: bool = False
RetrievalTrainingArgs
@dataclass
class RetrievalTrainingArgs(TrainingArguments):
# Output
output_dir: str = 'data/outputs/'
eval_method: str = "retrieval"
# Training strategy
use_train_config: bool = False
inbatch_same_dataset: Optional[str] = None # epoch, random
negative_cross_device: bool = True
# Loss configuration
cos_temperature: float = 0.01
teacher_temperature: float = 1.0
student_temperature: float = 1.0
contrastive_weight: float = 0.2
distill_weight: float = 1.0
stable_distill: bool = False
# Data sampling
max_sample_num: Optional[int] = None
train_group_size: int = 8 # 1 pos + 7 negs
select_positive: str = "first"
select_negative: str = "random"
teacher_scores_margin: Optional[float] = None
teacher_scores_min: Optional[float] = None
# Optimization
per_device_train_batch_size: int = 16
learning_rate: float = 5e-6
warmup_ratio: float = 0.1
weight_decay: float = 0.01
# Mixed precision
fp16: bool = True
# Distributed training
ddp_find_unused_parameters: bool = False
remove_unused_columns: bool = False
# Evaluation
evaluation_strategy: str = 'steps'
save_steps: int = 2000
logging_steps: int = 100
early_exit_steps: Optional[int] = None
# Logging
report_to: str = "none"
log_path: str = "data/results/performance.log"
Usage Examples
Parse Arguments from Command Line
from transformers import HfArgumentParser
from research.llm_embedder.src.retrieval.args import DenseRetrievalArgs
parser = HfArgumentParser([DenseRetrievalArgs])
args, = parser.parse_args_into_dataclasses()
print(f"Query encoder: {args.query_encoder}")
print(f"Batch size: {args.batch_size}")
print(f"Metrics: {args.metrics}")
Configure Dense Retrieval
from research.llm_embedder.src.retrieval.args import DenseRetrievalArgs
args = DenseRetrievalArgs(
# Data
data_root="/data/llm-embedder",
eval_data="llm-embedder:beir/nfcorpus/test.json",
corpus="llm-embedder:beir/nfcorpus/corpus.json",
# Model
query_encoder="BAAI/llm-embedder",
key_encoder="BAAI/llm-embedder",
tie_encoders=True,
# Search
dense_metric="cos",
hits=100,
batch_size=256,
# Pooling
pooling_method=["decoder"], # Use decoder final token
# Evaluation
metrics=["recall", "ndcg"],
cutoffs=[1, 5, 10, 20, 100]
)
# Path resolution happens automatically in __post_init__
print(args.eval_data) # /data/llm-embedder/beir/nfcorpus/test.json
Configure Training
from research.llm_embedder.src.retrieval.args import RetrievalTrainingArgs
training_args = RetrievalTrainingArgs(
# Output
output_dir="outputs/llm-embedder-training",
run_name="llama2-7b-retriever",
# Data
per_device_train_batch_size=8,
train_group_size=8, # 1 pos + 7 negs
# Optimization
learning_rate=2e-5,
warmup_ratio=0.1,
weight_decay=0.01,
num_train_epochs=3,
# Loss
contrastive_weight=0.2,
distill_weight=0.8,
cos_temperature=0.02,
# Negative sampling
inbatch_same_dataset="epoch",
negative_cross_device=True,
# Evaluation
evaluation_strategy="steps",
eval_steps=500,
save_steps=500,
# Hardware
fp16=True,
gradient_checkpointing=True,
# Logging
logging_steps=10,
report_to="wandb"
)
Configure BM25 + Reranking
from research.llm_embedder.src.retrieval.args import BM25Args, RankerArgs
# First-stage BM25
bm25_args = BM25Args(
eval_data="data/test.json",
corpus="data/corpus.json",
anserini_dir="/opt/anserini",
k1=0.9,
b=0.4,
hits=1000, # Retrieve many candidates
threads=16
)
# Second-stage reranking
ranker_args = RankerArgs(
ranker="BAAI/bge-reranker-large",
ranker_method="cross-encoder",
batch_size=32,
hits=100, # Keep top-100 after reranking
query_max_length=512,
key_max_length=512
)
Path Resolution
from research.llm_embedder.src.retrieval.args import BaseArgs
args = BaseArgs(
data_root="/mnt/data",
train_data=["llm-embedder:train/file1.json", "llm-embedder:train/file2.json"],
eval_data="llm-embedder:test/data.json",
corpus="llm-embedder:corpus/passages.json"
)
# Paths are resolved in __post_init__:
# train_data: ["/mnt/data/train/file1.json", "/mnt/data/train/file2.json"]
# eval_data: "/mnt/data/test/data.json"
# corpus: "/mnt/data/corpus/passages.json"
# Can also use absolute paths directly
args2 = BaseArgs(
train_data=["/absolute/path/train.json"],
eval_data="/absolute/path/test.json"
)
# Absolute paths are left unchanged