Implementation: FlagOpen FlagEmbedding Reinforced IR Retriever Modeling
| Knowledge Sources | |
|---|---|
| Domains | Information Retrieval, Neural Networks, Embedder Training |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
BiEncoder retrieval model with support for query augmentation, answer embeddings, and knowledge distillation for Reinforced IR.
Description
This module implements a sophisticated bi-encoder model for the Reinforced IR approach that supports multiple training paradigms. It extends the standard BiEncoderOnlyEmbedderModel with the ability to encode queries, passages, and LLM-generated answers separately, enabling flexible training objectives. The model supports three training types: retrieval-only (standard), retrieval+answer (queries and answers should retrieve passages), and retrieval+answer+passage (answers should also generate passages).
The architecture computes in-batch or cross-device negative losses for contrastive learning, with optional knowledge distillation from teacher scores. It includes special handling for answer embeddings with configurable temperature and normalization. The model can enforce unit norm constraints on answer representations to maintain consistent similarity scales. During training, it combines multiple loss terms: query-passage similarity, answer-passage similarity (if enabled), and passage reconstruction from answers (if enabled with reduced weight).
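The multi-term objective described above can be sketched as follows. This is an illustrative, hedged reconstruction, not the repository's exact code: the helper names (`info_nce`, `combined_loss`), the interpretation of "passage reconstruction" as a reversed answer-passage contrastive term, and the exact weighting are assumptions; only the in-batch-negative setup, the separate temperatures, and the 0.25x reduced weight come from the description.

```python
import torch
import torch.nn.functional as F

def info_nce(q, p, temperature):
    # In-batch negatives: each row's positive is the passage at the same index.
    scores = q @ p.T / temperature            # (B, B) similarity matrix
    targets = torch.arange(q.size(0))         # diagonal is the positive pair
    return F.cross_entropy(scores, targets)

def combined_loss(q_emb, a_emb, p_emb, training_type,
                  temperature=0.02, answer_temperature=0.05):
    # Sketch of the multi-objective loss; term structure is illustrative.
    loss = info_nce(q_emb, p_emb, temperature)                        # query -> passage
    if 'answer' in training_type:
        loss = loss + info_nce(a_emb, p_emb, answer_temperature)      # answer -> passage
    if training_type == 'retrieval_answer_passage':
        # "answers should also generate passages", at reduced weight
        loss = loss + 0.25 * info_nce(p_emb, a_emb, answer_temperature)
    return loss

torch.manual_seed(0)
q = F.normalize(torch.randn(4, 8), dim=-1)   # query embeddings
a = F.normalize(torch.randn(4, 8), dim=-1)   # answer embeddings
p = F.normalize(torch.randn(4, 8), dim=-1)   # passage embeddings
loss = combined_loss(q, a, p, 'retrieval_answer_passage')
```

Each added term is a standard InfoNCE loss; enabling more training signals can only add non-negative terms, so the full-mode loss upper-bounds the retrieval-only loss on the same batch.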
Usage
Use this model to train dense retrievers for the Reinforced IR framework, particularly when leveraging LLM-generated augmentations to improve retrieval performance through auxiliary training signals.
Code Reference
Source Location
- Repository: FlagOpen_FlagEmbedding
- File: research/Reinforced_IR/finetune/retriever/modeling.py
- Lines: 1-215
Signature
```python
class BiIREmbedderModel(BiEncoderOnlyEmbedderModel):
    def __init__(
        self,
        base_model: AutoModel,
        tokenizer: AutoTokenizer = None,
        negatives_cross_device: bool = False,
        temperature: float = 1.0,
        answer_temperature: float = None,
        sub_batch_size: int = -1,
        kd_loss_type: str = 'kl_div',
        sentence_pooling_method: str = 'cls',
        normalize_embeddings: bool = False,
        normalize_answer: bool = True,
        training_type: str = 'retrieval_answer'
    )

    def forward(
        self,
        queries: Union[Dict[str, Tensor], List[Dict[str, Tensor]]] = None,
        answers: Union[Dict[str, Tensor], List[Dict[str, Tensor]]] = None,
        passages: Union[Dict[str, Tensor], List[Dict[str, Tensor]]] = None,
        teacher_scores: Union[None, List[float]] = None,
        teacher_scores_answers: Union[None, List[float]] = None,
        no_in_batch_neg_flag: bool = False
    ):
        """Forward pass with multi-objective training."""
```
Import
```python
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from FlagEmbedding.abc.finetune.embedder.AbsModeling import AbsEmbedderModel, EmbedderOutput
from FlagEmbedding.finetune.embedder.encoder_only.base.modeling import BiEncoderOnlyEmbedderModel
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| base_model | AutoModel | Yes | Transformer model for encoding |
| tokenizer | AutoTokenizer | No | Tokenizer for the model |
| temperature | float | No | Temperature for query-passage similarity (default: 1.0) |
| answer_temperature | float | No | Temperature for answer-passage similarity (default: None; the usage example sets 0.05) |
| normalize_embeddings | bool | No | Normalize query/passage embeddings (default: False) |
| normalize_answer | bool | No | Enforce L2 norm=1 for answer embeddings (default: True) |
| training_type | str | No | Training mode: 'retrieval', 'retrieval_answer', or 'retrieval_answer_passage' (default: 'retrieval_answer') |
| kd_loss_type | str | No | Distillation loss type: 'kl_div' or 'm3_kd_loss' (default: 'kl_div') |
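The `'kl_div'` distillation option can be sketched as matching the student's score distribution over candidate passages to the teacher's. This is a common formulation of KL-divergence distillation, shown here as an assumption about the intent of `kd_loss_type='kl_div'`; the repository's exact implementation may differ in reduction or score preprocessing.

```python
import torch
import torch.nn.functional as F

def kd_kl_div(student_scores, teacher_scores):
    # KL(teacher || student) over the candidate-passage axis.
    # F.kl_div expects log-probabilities for the student input and
    # probabilities for the teacher target.
    student_log_probs = F.log_softmax(student_scores, dim=-1)
    teacher_probs = F.softmax(teacher_scores, dim=-1)
    return F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')

# Two queries, three candidate passages each (scores are illustrative).
student = torch.tensor([[2.0, 0.5, 0.1], [1.0, 1.0, 0.2]])
teacher = torch.tensor([[3.0, 0.2, 0.1], [0.8, 1.2, 0.3]])
loss = kd_kl_div(student, teacher)
```

The loss is zero exactly when the two score distributions agree, so it leaves a well-calibrated student unchanged while pulling a miscalibrated one toward the teacher's ranking.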
Outputs
| Name | Type | Description |
|---|---|---|
| loss | Tensor | Combined loss from all enabled objectives |
Usage Examples
```python
from transformers import AutoModel, AutoTokenizer
# BiIREmbedderModel is defined in research/Reinforced_IR/finetune/retriever/modeling.py

# Initialize the base encoder
base_model = AutoModel.from_pretrained("BAAI/bge-base-en-v1.5")
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-base-en-v1.5")

# Create the Reinforced IR model
model = BiIREmbedderModel(
    base_model=base_model,
    tokenizer=tokenizer,
    temperature=0.02,
    answer_temperature=0.05,
    normalize_embeddings=True,
    normalize_answer=True,
    training_type='retrieval_answer',  # train with query and answer objectives
    kd_loss_type='kl_div'
)

# Training forward pass
outputs = model(
    queries=batch['queries'],
    answers=batch['answers'],  # LLM-generated augmentations
    passages=batch['passages'],
    teacher_scores=batch.get('teacher_scores'),
    no_in_batch_neg_flag=False
)
loss = outputs.loss

# The loss combines:
# 1. Query-passage contrastive loss
# 2. Answer-passage contrastive loss (if training_type includes 'answer')
# 3. Answer L2-norm regularization
# 4. Optional KD loss from teacher scores
# In 'retrieval_answer_passage' mode, it also includes a passage
# reconstruction loss from answers (weighted 0.25x).
```
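The answer L2-norm regularization mentioned above (the `normalize_answer=True` behavior of keeping answer representations at unit norm for consistent similarity scales) can be sketched as a simple penalty on deviation from norm 1. The function name and squared-error formulation are illustrative assumptions; the repository may enforce the constraint differently.

```python
import torch

def answer_norm_penalty(answer_emb):
    # Penalize each answer embedding's L2 norm for deviating from 1, so
    # answer-passage dot products stay on a consistent scale (illustrative
    # formulation, not necessarily the repository's exact regularizer).
    norms = answer_emb.norm(p=2, dim=-1)
    return ((norms - 1.0) ** 2).mean()

emb = torch.tensor([[3.0, 4.0],    # norm 5.0 -> penalty (5-1)^2 = 16
                    [0.6, 0.8]])   # norm 1.0 -> penalty 0
penalty = answer_norm_penalty(emb)  # mean of 16 and 0 = 8.0
```

A soft penalty like this, as opposed to hard normalization, lets the model trade off norm consistency against the contrastive objectives during training.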