
Implementation:FlagOpen FlagEmbedding LLM Embedder Retrieval Trainer

From Leeroopedia


Knowledge Sources
Domains: Machine_Learning, Model_Training, Retrieval_Evaluation
Last Updated: 2026-02-09 00:00 GMT

Overview

Custom HuggingFace Trainer for training and evaluating retrieval models with specialized evaluation metrics.

Description

RetrievalTrainer extends the HuggingFace Trainer with retrieval-specific evaluation:

Training: Inherits standard Trainer capabilities for gradient accumulation, mixed precision, distributed training, and checkpointing. Saves model weights, tokenizer, and training arguments together.

Evaluation: Implements two evaluation modes:

  • Retrieval mode: Encodes the corpus into embeddings, builds a FAISS or BM25 index, searches for the top-k passages per query, and computes nDCG, recall, and precision metrics
  • Rerank mode: Scores pre-retrieved candidates with a cross-encoder, reranks them, and evaluates the new ranking
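To make the two modes concrete, here is a minimal, dependency-free sketch, assuming binary relevance labels and inner-product similarity. Brute-force search stands in for the FAISS index and a plain score_fn stands in for the cross-encoder. All names (search_top_k, rerank, ndcg_at_k, recall_at_k, evaluate_rankings) are hypothetical and are not RetrievalTrainer's actual API.

```python
import math
from typing import Callable, Dict, List, Sequence, Set


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Inner product, the similarity used for dense retrieval in this sketch."""
    return sum(x * y for x, y in zip(a, b))


def search_top_k(query_embs: List[List[float]], corpus_embs: List[List[float]],
                 k: int) -> List[List[int]]:
    """Brute-force inner-product search (a stand-in for a FAISS index)."""
    rankings = []
    for q in query_embs:
        scores = [dot(q, p) for p in corpus_embs]
        order = sorted(range(len(corpus_embs)), key=lambda i: scores[i], reverse=True)
        rankings.append(order[:k])
    return rankings


def rerank(candidates: List[int], score_fn: Callable[[int], float]) -> List[int]:
    """Rerank mode: reorder pre-retrieved candidate ids by a relevance score."""
    return sorted(candidates, key=score_fn, reverse=True)


def ndcg_at_k(ranked: List[int], relevant: Set[int], k: int) -> float:
    """nDCG@k with binary relevance: each hit is discounted by log2(rank + 2)."""
    dcg = sum(1.0 / math.log2(r + 2) for r, pid in enumerate(ranked[:k]) if pid in relevant)
    idcg = sum(1.0 / math.log2(r + 2) for r in range(min(len(relevant), k)))
    return dcg / idcg if idcg > 0 else 0.0


def recall_at_k(ranked: List[int], relevant: Set[int], k: int) -> float:
    """Fraction of the relevant passages that appear in the top k."""
    return len(set(ranked[:k]) & relevant) / len(relevant) if relevant else 0.0


def evaluate_rankings(rankings: List[List[int]], qrels: List[Set[int]],
                      k: int = 10, prefix: str = "eval") -> Dict[str, float]:
    """Average per-query metrics and prefix the keys (similar to metric_key_prefix)."""
    n = len(rankings)
    return {
        f"{prefix}_ndcg@{k}": sum(ndcg_at_k(r, rel, k) for r, rel in zip(rankings, qrels)) / n,
        f"{prefix}_recall@{k}": sum(recall_at_k(r, rel, k) for r, rel in zip(rankings, qrels)) / n,
    }


# Toy data: 3 passages, 2 queries, binary relevance labels (qrels).
corpus_embs = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]
query_embs = [[1.0, 0.1], [0.1, 1.0]]
qrels = [{0}, {2}]

rankings = search_top_k(query_embs, corpus_embs, k=2)  # [[0, 2], [1, 2]]
print(evaluate_rankings(rankings, qrels, k=2))
print(rerank([5, 7, 9], score_fn=lambda pid: {5: 0.1, 7: 0.9, 9: 0.5}[pid]))  # [7, 9, 5]
```

A real evaluation would replace the brute-force loop with a FAISS index (or BM25 scores) over the full corpus. The metric definitions here are the standard binary-relevance forms and may differ in detail from the trainer's own implementation.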

Key features:

  • Handles caching of the search index and corpus embeddings (load_index, save_index, load_encode, save_encode)
  • Gathers results across devices during distributed evaluation
  • Temporarily disables inbatch_same_dataset during evaluation and restores it afterward
  • Logs metrics to both the console and the file logger
  • Broadcasts evaluation metrics from rank 0 to all processes so every process sees the same values during training
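Two of these features follow common Python patterns: temporarily overriding a setting and restoring it even if evaluation fails, and reusing cached embeddings. This sketch assumes inbatch_same_dataset is a plain attribute on some config object and that embeddings can be cached to a JSON file. temporarily_set, encode_with_cache, Config, toy_encode, and the cache format are hypothetical stand-ins, not the trainer's code.

```python
import contextlib
import json
import os
import tempfile
from typing import Callable, Iterator, List


@contextlib.contextmanager
def temporarily_set(obj: object, name: str, value: object) -> Iterator[object]:
    """Set obj.<name> = value inside the block; restore the old value afterwards,
    even if the block raises."""
    old = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield obj
    finally:
        setattr(obj, name, old)


def encode_with_cache(texts: List[str], encode_fn: Callable[[str], List[float]],
                      cache_path: str, load_encode: bool = True,
                      save_encode: bool = True) -> List[List[float]]:
    """Load cached embeddings if allowed and present; otherwise encode, optionally saving."""
    if load_encode and os.path.exists(cache_path):
        with open(cache_path) as f:
            return json.load(f)
    embs = [encode_fn(t) for t in texts]
    if save_encode:
        with open(cache_path, "w") as f:
            json.dump(embs, f)
    return embs


# Usage with a hypothetical config object.
class Config:
    inbatch_same_dataset = True


cfg = Config()
with temporarily_set(cfg, "inbatch_same_dataset", False):
    during_eval = cfg.inbatch_same_dataset  # False while "evaluating"
after_eval = cfg.inbatch_same_dataset       # restored to True

calls: List[str] = []


def toy_encode(text: str) -> List[float]:
    """Toy encoder that records each call so cache hits are observable."""
    calls.append(text)
    return [float(len(text))]


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "corpus_embs.json")
    first = encode_with_cache(["ab", "c"], toy_encode, path)   # encodes and saves
    second = encode_with_cache(["ab", "c"], toy_encode, path)  # loads from disk
```

Using try/finally inside the context manager is what makes the override "temporary": the original value comes back even when evaluation raises.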

EarlyExitCallBack ends training early based on the global step count (early_exit_steps) rather than on the number of epochs.
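A step-based early exit can be sketched as a callback that sets the stop flag on the trainer's control object. In transformers, TrainerCallback.on_step_end receives a TrainerState (which carries global_step) and a TrainerControl (which carries should_training_stop). The sketch uses small stand-in dataclasses so it runs without transformers. It assumes that early_exit_steps=None disables the check and that training stops once global_step reaches the limit; both are assumptions, not details taken from the EarlyExitCallBack source.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class State:
    """Stand-in for transformers.TrainerState (only global_step is modeled)."""
    global_step: int = 0


@dataclass
class Control:
    """Stand-in for transformers.TrainerControl (only should_training_stop is modeled)."""
    should_training_stop: bool = False


class StepEarlyExit:
    """Hypothetical step-based early exit. With transformers installed it would
    subclass transformers.TrainerCallback and be passed to trainer.add_callback()."""

    def __init__(self, early_exit_steps: Optional[int] = None):
        self.early_exit_steps = early_exit_steps

    def on_step_end(self, args: Any, state: State, control: Control, **kwargs: Any) -> Control:
        # Assumption: None disables the check; stop once global_step reaches the limit.
        if self.early_exit_steps is not None and state.global_step >= self.early_exit_steps:
            control.should_training_stop = True
        return control


# Simulated training loop: the trainer increments global_step, then calls on_step_end.
callback = StepEarlyExit(early_exit_steps=3)
state, control = State(), Control()
while not control.should_training_stop and state.global_step < 100:
    state.global_step += 1
    callback.on_step_end(None, state, control)
print(state.global_step)  # 3
```

Counting optimizer steps instead of epochs gives a fixed training budget that does not depend on dataset size.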

Usage

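Use this in place of the HuggingFace Trainer when training retrieval or reranking models that need evaluation beyond the standard loss metrics. It is not a pure drop-in replacement: the constructor also requires the keyword-only arguments corpus, model_args, and file_logger.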

Code Reference

Source Location

Signature

class RetrievalTrainer(Trainer):
    def __init__(self, *args, corpus: Dataset, model_args, file_logger, **kwargs): ...
    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"): ...
    def _save(self, output_dir=None, state_dict=None): ...

class EarlyExitCallBack(TrainerCallback):
    def __init__(self, early_exit_steps=None): ...
    def on_step_end(self, args, state, control, **kwargs): ...

Import

from research.llm_embedder.src.retrieval.trainer import RetrievalTrainer, EarlyExitCallBack

I/O Contract

Inputs

  • model (PreTrainedModel, required): Retrieval or reranking model to train
  • args (TrainingArguments, required): Training configuration; must include eval_method ("retrieval" or "rerank"), which the stock transformers TrainingArguments does not define
  • corpus (Dataset, required): Corpus searched during retrieval evaluation
  • train_dataset (Dataset, required): Training data
  • eval_dataset (Dataset, required): Evaluation queries
  • model_args (ModelArguments, required): Model-specific arguments (e.g., hits, batch_size)
  • file_logger (FileLogger, required): Logger that saves evaluation metrics
  • data_collator (DataCollator, required): Collator for batching

Outputs

  • metrics (Dict): Evaluation metrics (nDCG, recall, etc.)
  • checkpoint (Directory): Saved model, tokenizer, and training arguments
  • log_file (File): Metrics logged together with the run configuration

Usage Examples

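from transformers import TrainingArguments
from research.llm_embedder.src.retrieval.trainer import RetrievalTrainer, EarlyExitCallBack
from research.llm_embedder.src.utils.util import FileLogger

# Assumed to exist already: model, tokenizer, corpus_dataset, train_dataset,
# eval_dataset, model_args, and collator.

# Configure training. NOTE: eval_method is not a field of the stock
# transformers.TrainingArguments, which rejects unknown keywords with a
# TypeError; use a TrainingArguments subclass that defines it.
training_args = TrainingArguments(
    output_dir="./output",
    eval_method="retrieval",  # or "rerank"
    per_device_train_batch_size=32,
    evaluation_strategy="steps",  # renamed eval_strategy in newer transformers releases
    eval_steps=500,
    save_steps=500,
    logging_steps=100,
)

# Initialize the trainer; corpus, model_args, and file_logger are required keyword arguments
trainer = RetrievalTrainer(
    model=model,
    args=training_args,
    corpus=corpus_dataset,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    model_args=model_args,
    file_logger=FileLogger("results.log"),
    data_collator=collator,
    tokenizer=tokenizer,
)

# Add step-based early exit (5000 global steps)
trainer.add_callback(EarlyExitCallBack(early_exit_steps=5000))

# Train
trainer.train()

# Evaluate
metrics = trainer.evaluate()
# Illustrative output shape: {'eval_ndcg@10': 0.512, 'eval_recall@10': 0.734, ...}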
