Implementation:FlagOpen FlagEmbedding LLM Embedder Retrieval Trainer
| Knowledge Sources | |
|---|---|
| Domains | Machine_Learning, Model_Training, Retrieval_Evaluation |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Custom HuggingFace Trainer for training and evaluating retrieval models with specialized evaluation metrics.
Description
RetrievalTrainer extends the HuggingFace Trainer, keeping its training loop and overriding evaluation with retrieval-specific logic:
Training: Inherits standard Trainer capabilities for gradient accumulation, mixed precision, distributed training, and checkpointing. Saves model weights, tokenizer, and training arguments together.
Evaluation: Implements two evaluation modes:
- Retrieval mode: Encodes the corpus into embeddings, builds a FAISS/BM25 index, searches for the top-k passages per query, and computes nDCG/Recall/Precision metrics
- Rerank mode: Scores pre-retrieved candidates using a cross-encoder, reranks them, and evaluates the reranked lists
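The metric step shared by both modes can be illustrated with a minimal, dependency-free sketch. It assumes binary relevance labels and plain passage-ID lists; the function names (`recall_at_k`, `ndcg_at_k`, `evaluate_rankings`) and the data layout are hypothetical and are not taken from the repository, whose actual metric code may differ.

```python
import math
from typing import Dict, List, Set


def recall_at_k(ranked: List[str], relevant: Set[str], k: int) -> float:
    """Fraction of the relevant passages that appear in the top-k results."""
    if not relevant:
        return 0.0
    hits = sum(1 for pid in ranked[:k] if pid in relevant)
    return hits / len(relevant)


def ndcg_at_k(ranked: List[str], relevant: Set[str], k: int) -> float:
    """Binary-relevance nDCG: DCG of the ranking divided by the ideal DCG."""
    dcg = sum(
        1.0 / math.log2(rank + 2)
        for rank, pid in enumerate(ranked[:k])
        if pid in relevant
    )
    # The ideal ranking places every relevant passage first.
    ideal_hits = min(len(relevant), k)
    idcg = sum(1.0 / math.log2(rank + 2) for rank in range(ideal_hits))
    return dcg / idcg if idcg > 0 else 0.0


def evaluate_rankings(
    results: Dict[str, List[str]],  # query id -> ranked passage ids
    qrels: Dict[str, Set[str]],     # query id -> relevant passage ids
    k: int = 10,
) -> Dict[str, float]:
    """Average nDCG@k and Recall@k over all queries in `results`."""
    if not results:
        return {f"ndcg@{k}": 0.0, f"recall@{k}": 0.0}
    n = len(results)
    ndcg = sum(ndcg_at_k(r, qrels.get(q, set()), k) for q, r in results.items())
    recall = sum(recall_at_k(r, qrels.get(q, set()), k) for q, r in results.items())
    return {f"ndcg@{k}": ndcg / n, f"recall@{k}": recall / n}


# Toy data: two queries with their ranked results and relevance judgments.
results = {"q1": ["p1", "p2", "p3"], "q2": ["p9", "p4"]}
qrels = {"q1": {"p1"}, "q2": {"p4", "p5"}}
metrics = evaluate_rankings(results, qrels, k=2)
```

In retrieval mode the ranked lists would come from the index search; in rerank mode they would come from sorting the pre-retrieved candidates by cross-encoder score. The metric computation is the same in either case.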
Key features:
- Handles index/embedding caching (load_index, save_index, load_encode, save_encode)
- Supports cross-device result gathering in distributed evaluation
- Temporarily disables inbatch_same_dataset during evaluation
- Logs metrics to both console and file logger
- Broadcasts evaluation metrics from rank 0 to all processes for synchronized training
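Temporarily disabling a setting such as inbatch_same_dataset during evaluation is commonly done with a save-and-restore pattern. The sketch below is one way to express it, not the repository's code: the helper `temporarily_set`, the `holder` object, and the flag values are assumptions, since this page does not state which object owns the flag or what values it takes.

```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def temporarily_set(obj, name, value):
    """Set obj.<name> to `value` inside the with-block, then restore the old value."""
    previous = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield obj
    finally:
        # Restore even if evaluation raises, so training resumes unchanged.
        setattr(obj, name, previous)


# Hypothetical stand-in for whichever object holds the flag.
holder = SimpleNamespace(inbatch_same_dataset=True)

with temporarily_set(holder, "inbatch_same_dataset", None):
    value_during_eval = holder.inbatch_same_dataset  # disabled while evaluating

value_after_eval = holder.inbatch_same_dataset  # original value is back
```

The `try`/`finally` restores the flag even if evaluation fails partway through, so a failed evaluation cannot leave training with the flag switched off.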
The EarlyExitCallBack provides early stopping based on global step count rather than epochs.
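A step-based early exit can be sketched as follows. To stay dependency-free, the class below does not inherit from transformers.TrainerCallback and is driven by a toy loop. `StepBasedEarlyExit` and `run_steps` are hypothetical names, and stopping on `>=` is an assumption about the threshold behavior. In real use, the same `on_step_end` logic would sit in a TrainerCallback subclass, where `state.global_step` and `control.should_training_stop` are provided by the Trainer.

```python
from types import SimpleNamespace


class StepBasedEarlyExit:
    """Requests a training stop once a global step threshold is reached."""

    def __init__(self, early_exit_steps=None):
        # None means the callback never requests a stop.
        self.early_exit_steps = early_exit_steps

    def on_step_end(self, args, state, control, **kwargs):
        if (
            self.early_exit_steps is not None
            and state.global_step >= self.early_exit_steps
        ):
            control.should_training_stop = True
        return control


def run_steps(callback, total_steps):
    """Toy training loop standing in for the Trainer; returns the last step run."""
    state = SimpleNamespace(global_step=0)
    control = SimpleNamespace(should_training_stop=False)
    for _ in range(total_steps):
        state.global_step += 1
        callback.on_step_end(args=None, state=state, control=control)
        if control.should_training_stop:
            break
    return state.global_step


stopped_at = run_steps(StepBasedEarlyExit(early_exit_steps=3), total_steps=10)
ran_to_end = run_steps(StepBasedEarlyExit(early_exit_steps=None), total_steps=10)
```

Counting global steps rather than epochs makes the cutoff independent of dataset size, which helps when comparing runs over corpora of different lengths.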
Usage
Use this as a drop-in replacement for HuggingFace Trainer when training retrieval or reranking models that need specialized evaluation beyond standard loss metrics.
Code Reference
Source Location
- Repository: FlagOpen_FlagEmbedding
- File: research/llm_embedder/src/retrieval/trainer.py
- Lines: 1-199
Signature
class RetrievalTrainer(Trainer):
    def __init__(self, *args, corpus: Dataset, model_args, file_logger, **kwargs)
    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval")
    def _save(self, output_dir=None, state_dict=None)

class EarlyExitCallBack(TrainerCallback):
    def __init__(self, early_exit_steps=None)
    def on_step_end(self, args, state, control, **kwargs)
Import
from research.llm_embedder.src.retrieval.trainer import RetrievalTrainer, EarlyExitCallBack
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | PreTrainedModel | Yes | Retrieval/reranking model to train |
| args | TrainingArguments | Yes | Training configuration (includes eval_method) |
| corpus | Dataset | Yes | Corpus for evaluation retrieval |
| train_dataset | Dataset | Yes | Training data |
| eval_dataset | Dataset | Yes | Evaluation queries |
| model_args | ModelArguments | Yes | Model-specific args (hits, batch_size, etc.) |
| file_logger | FileLogger | Yes | Logger for saving metrics |
| data_collator | DataCollator | Yes | Collator for batching |
Outputs
| Name | Type | Description |
|---|---|---|
| metrics | Dict | Evaluation metrics (nDCG, Recall, etc.) |
| checkpoint | Directory | Saved model, tokenizer, and args |
| log_file | File | Metrics logged with configuration |
Usage Examples
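from dataclasses import dataclass, field

from transformers import TrainingArguments
from research.llm_embedder.src.retrieval.trainer import RetrievalTrainer, EarlyExitCallBack
from research.llm_embedder.src.utils.util import FileLogger

# Stock TrainingArguments has no eval_method field, so passing it directly raises a TypeError.
# The subclass below is a hypothetical stand-in; prefer the repository's own
# training-arguments class if it provides one.
@dataclass
class RetrievalTrainingArguments(TrainingArguments):
    eval_method: str = field(default="retrieval")  # "retrieval" or "rerank"

# model, tokenizer, collator, model_args, corpus_dataset, train_dataset, and eval_dataset
# are assumed to have been constructed beforehand.

# Configure training
training_args = RetrievalTrainingArguments(
    output_dir="./output",
    eval_method="retrieval",  # or "rerank"
    per_device_train_batch_size=32,
    evaluation_strategy="steps",  # named eval_strategy in newer transformers releases
    eval_steps=500,
    save_steps=500,
    logging_steps=100,
)

# Initialize trainer
trainer = RetrievalTrainer(
    model=model,
    args=training_args,
    corpus=corpus_dataset,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    model_args=model_args,
    file_logger=FileLogger("results.log"),
    data_collator=collator,
    tokenizer=tokenizer,
)

# Stop training after 5000 global steps, regardless of epoch count
trainer.add_callback(EarlyExitCallBack(early_exit_steps=5000))

# Train
trainer.train()

# Evaluate
metrics = trainer.evaluate()
# Illustrative shape of the result: {'eval_ndcg@10': 0.512, 'eval_recall@10': 0.734, ...}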