Implementation:FlagOpen FlagEmbedding LLM Reranker Instruction Data
| Knowledge Sources | |
|---|---|
| Domains | Reranking, Instruction_Tuning, Data_Processing |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Data loading and collation for training LLM-based rerankers with instruction-following capabilities.
Description
This module prepares training data for instruction-tuned rerankers:
TrainDatasetForReranker loads query-passage pairs with:
- Support for directories or single JSON files
- Sampling control via max_example_num_per_dataset
- Random selection of 1 positive + N-1 negatives per query
- Automatic padding/cycling of negatives if insufficient samples exist
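The sampling rule above can be sketched in a few lines. This is a minimal illustration rather than the module's code: the helper name `sample_group`, the record keys `query`/`pos`/`neg`, and the exact cycling strategy (repeat the negative pool, then sample without replacement) are assumptions made for the example.

```python
import math
import random


def sample_group(pos, neg, group_size, rng=random):
    """Return 1 positive followed by group_size - 1 negatives.

    Hypothetical helper; assumes `pos` and `neg` are non-empty lists.
    """
    n_neg = group_size - 1
    passages = [rng.choice(pos)]
    if len(neg) < n_neg:
        # Too few negatives: repeat the pool until it is large enough,
        # so sampling without replacement still yields n_neg items.
        repeats = math.ceil(n_neg / len(neg))
        pool = neg * repeats
    else:
        pool = neg
    passages.extend(rng.sample(pool, n_neg))
    return passages


# Illustrative record; the field names are assumptions.
record = {
    "query": "what is AI",
    "pos": ["AI is the study of intelligent agents."],
    "neg": ["Paris is the capital of France.", "Bread is baked from flour."],
}
group = sample_group(record["pos"], record["neg"], group_size=8,
                     rng=random.Random(0))
```

With only two negatives available and a group size of 8, the pool is repeated so that seven negatives can still be drawn, which means duplicates appear within the group.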
Data formatting: each training example is laid out as [BOS]query\npassage\nprompt[Yes/No]:
- Query and passage have optional instruction prefixes
- Queries and passages have separate max lengths; a query may borrow part of the passage budget, and a passage may borrow part of the query budget
- The prompt (e.g., "Is this passage relevant?") precedes the binary label
- Labels mask all tokens except the final answer token ("Yes"/"No")
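The layout and label masking described above can be sketched without a real tokenizer. Everything here is illustrative: `toy_tokenize`, `build_example`, the `BOS_ID`/`SEP_ID` values, and the borrowing fractions (a quarter and a half) are assumptions chosen for the example, not confirmed details of the module.

```python
IGNORE_INDEX = -100  # label value ignored by PyTorch's cross-entropy loss
BOS_ID = 1           # hypothetical beginning-of-sequence id
SEP_ID = 2           # hypothetical id standing in for the "\n" separator

_vocab = {}


def toy_tokenize(text):
    """Hypothetical tokenizer stand-in: one id per whitespace-separated word."""
    return [_vocab.setdefault(word, len(_vocab) + 10) for word in text.split()]


def build_example(query, passage, prompt, answer,
                  query_max_len=32, passage_max_len=128):
    tail = toy_tokenize(prompt) + toy_tokenize(answer)
    # Assumed budgets: each side may borrow part of the other's allowance.
    q_ids = toy_tokenize(query)[: query_max_len + passage_max_len // 4]
    p_ids = toy_tokenize(passage)[: passage_max_len + query_max_len // 2]
    # The full sequence is capped at query_max_len + passage_max_len;
    # only the passage is trimmed to make it fit.
    head = [BOS_ID] + q_ids + [SEP_ID]
    total = query_max_len + passage_max_len
    room = max(0, total - len(head) - 1 - len(tail))
    input_ids = head + p_ids[:room] + [SEP_ID] + tail
    # Supervise only the final (answer) token; mask everything else.
    labels = [IGNORE_INDEX] * (len(input_ids) - 1) + input_ids[-1:]
    return {
        "input_ids": input_ids,
        "attention_mask": [1] * len(input_ids),
        "labels": labels,
    }


example = build_example("what is AI", "AI is the study of intelligent agents",
                        "Is this passage relevant?", "Yes")
```

Because every label except the last is `IGNORE_INDEX`, the loss depends only on the model's prediction at the answer position.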
RerankCollator handles:
- Batching variable-length sequences
- Label padding with -100 for masked positions
- Ensuring only the final token contributes to loss (teaching binary relevance judgment)
- Optional position IDs for models requiring explicit position tracking
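The padding behaviour in the list above can be illustrated with plain Python lists. This is a sketch under stated assumptions: `pad_batch` is a hypothetical helper, `pad_id=0` is an arbitrary choice, and the real collator returns tensors rather than lists.

```python
IGNORE_INDEX = -100  # padded label positions are excluded from the loss


def pad_batch(features, pad_id=0, padding_side="right"):
    """Pad every sequence to the longest one in the batch (hypothetical helper)."""
    max_len = max(len(f["input_ids"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        gap = max_len - len(f["input_ids"])
        pads = {
            "input_ids": [pad_id] * gap,
            "attention_mask": [0] * gap,
            "labels": [IGNORE_INDEX] * gap,
        }
        for key in batch:
            if padding_side == "right":
                row = f[key] + pads[key]
            else:
                row = pads[key] + f[key]
            batch[key].append(row)
    # Mirror the documented output shape: a dict with a single "pair" key.
    return {"pair": batch}


features = [
    {"input_ids": [1, 11, 12, 2, 13, 2, 14, 15],
     "attention_mask": [1] * 8,
     "labels": [IGNORE_INDEX] * 7 + [15]},
    {"input_ids": [1, 11, 2, 16, 2, 14, 15],
     "attention_mask": [1] * 7,
     "labels": [IGNORE_INDEX] * 6 + [15]},
]
batch = pad_batch(features)["pair"]
```

After padding, each row still has exactly one label that is not `IGNORE_INDEX`, so padding does not change which token is supervised.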
This approach teaches LLMs to judge passage relevance by predicting "Yes"/"No" in an instruction-following format.
Usage
Use this for fine-tuning instruction-tuned LLMs (LLaMA, Mistral, etc.) as rerankers that output relevance judgments in natural language.
Code Reference
Source Location
- Repository: FlagOpen_FlagEmbedding
- File: research/llm_reranker/finetune_for_instruction/data.py
- Lines: 1-192
Signature
class TrainDatasetForReranker(Dataset):
    def __init__(self, args: DataArguments, tokenizer: PreTrainedTokenizer)
    def __getitem__(self, item) -> List[BatchEncoding]

@dataclass
class RerankCollator(DataCollatorForSeq2Seq):
    query_max_len: int = 32
    passage_max_len: int = 128
    def __call__(self, features, return_tensors='pt')
Import
from research.llm_reranker.finetune_for_instruction.data import TrainDatasetForReranker, RerankCollator
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| args | DataArguments | Yes | Configuration with train_data, query/passage instructions, train_group_size |
| tokenizer | PreTrainedTokenizer | Yes | Tokenizer for the LLM |
| features | List | Yes | List of batch encodings from dataset |
| query_max_len | int | No | Max query tokens (default: 32) |
| passage_max_len | int | No | Max passage tokens (default: 128) |
Outputs
| Name | Type | Description |
|---|---|---|
| passages_inputs | List[BatchEncoding] | List of tokenized sequences with input_ids, attention_mask, labels, and optionally position_ids |
| batch | Dict | Collated batch with "pair" key containing padded tensors |
Usage Examples
from transformers import AutoTokenizer
from research.llm_reranker.finetune_for_instruction.data import TrainDatasetForReranker, RerankCollator
from research.llm_reranker.finetune_for_instruction.arguments import DataArguments
# Initialize dataset
args = DataArguments(
    train_data="rerank_train.jsonl",
    train_group_size=8,
    query_instruction_for_retrieval="",
    passage_instruction_for_retrieval="",
    query_max_len=256,
    passage_max_len=512,
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
dataset = TrainDatasetForReranker(args, tokenizer)
# Initialize collator
collator = RerankCollator(
    tokenizer=tokenizer,
    query_max_len=256,
    passage_max_len=512,
)
# Get a training example
passages_inputs = dataset[0]  # List of 8 formatted sequences (train_group_size=8)
# With instruction prefixes such as "Query: " and "Passage: ", each sequence looks like:
# "[BOS]Query: what is AI\nPassage: AI is...\nIs this passage relevant? Answer: Yes"
# Collate batch
batch = collator(passages_inputs)
# batch["pair"]["input_ids"]: [8, max_len] with padding
# batch["pair"]["labels"]: [8, max_len], -100 everywhere except each sequence's answer token