

Implementation:FlagOpen FlagEmbedding LLM Reranker Layerwise Modeling

From Leeroopedia


Knowledge Sources
Domains Reranking, Knowledge_Distillation, Layer_Wise_Training
Last Updated 2026-02-09 00:00 GMT

Overview

Layer-wise reranking model with multiple prediction heads and self-distillation from deeper to shallower layers.

Description

BiEncoderModel extends a standard reranker for layer-wise reranking in three ways:

Multi-layer prediction:

  • Attaches a reranking head to each transformer layer from start_layer through the final layer
  • Each head predicts relevance scores independently
  • Uses output_hidden_states and cutoff_layers to extract intermediate representations

Training with self-distillation:

  • Computes standard contrastive loss on all layer predictions
  • Uses the deepest layer as teacher to distill knowledge to shallower layers
  • Student layers learn to match the teacher's softmax distribution via KL divergence
  • Combined loss = Σ(contrastive losses) + Σ(distillation losses)
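The combined objective above can be sketched in plain Python. This is a minimal illustration rather than the repository's code: it assumes each layer yields one list of scores per query with the positive passage at index 0, that all loss terms are weighted equally, and that the helper names (softmax, contrastive_loss, kl_divergence, layerwise_loss) are hypothetical.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def contrastive_loss(scores, positive_index=0):
    """Cross-entropy over one query's passages (assumed: positive at index 0)."""
    return -math.log(softmax(scores)[positive_index])

def kl_divergence(teacher_scores, student_scores):
    """KL(teacher || student) between the two softmax distributions."""
    p = softmax(teacher_scores)
    q = softmax(student_scores)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def layerwise_loss(per_layer_scores):
    """per_layer_scores is ordered shallowest first; the last entry is the teacher."""
    teacher = per_layer_scores[-1]
    contrastive = sum(contrastive_loss(s) for s in per_layer_scores)
    distill = sum(kl_divergence(teacher, s) for s in per_layer_scores[:-1])
    return contrastive + distill

# Toy scores for one query with three passages, seen by three layers.
per_layer_scores = [
    [0.2, 0.1, 0.0],    # shallowest head
    [1.0, 0.2, -0.3],
    [2.5, 0.1, -1.0],   # deepest head (teacher)
]
total = layerwise_loss(per_layer_scores)
```

In a PyTorch version the teacher scores would typically be detached before computing the distillation term, so that gradients from that term reach only the student layers.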

Inference flexibility:

  • Can use any intermediate layer for prediction (early exit)
  • Deeper layers provide more accurate rankings but cost more compute
  • Shallower layers give faster inference at some cost in accuracy
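Early exit amounts to picking one entry from the per-layer score list and ranking with it. The sketch below assumes the list holds one entry per consecutive layer beginning at start_layer, so layer N sits at index N - start_layer; layer_to_index and rank_with_layer are hypothetical helpers, not part of the module.

```python
def layer_to_index(layer, start_layer=8):
    """Map a transformer layer number to its position in the score list."""
    if layer < start_layer:
        raise ValueError(f"layer {layer} has no head; heads start at {start_layer}")
    return layer - start_layer

def rank_with_layer(all_scores, layer, start_layer=8):
    """Return passage indices sorted by the chosen layer's scores, best first."""
    idx = layer_to_index(layer, start_layer)
    if idx >= len(all_scores):
        raise ValueError(f"layer {layer} is beyond the last available head")
    scores = all_scores[idx]
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)

# Toy scores for three passages from heads at layers 8, 9, and 10.
all_scores = [
    [0.1, 0.3, 0.2],   # layer 8
    [0.5, 0.1, 0.9],   # layer 9
    [2.0, -1.0, 0.5],  # layer 10
]
ranking_at_9 = rank_with_layer(all_scores, layer=9)
```

Validating the layer number up front avoids silently indexing from the end of the list when a layer below start_layer is requested.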

The encode() method returns a list of scores from each layer, while forward() combines all losses during training.

Usage

Use this to train efficient rerankers that can trade accuracy for speed at inference time by choosing which layer's head produces the score.

Code Reference

Source Location

Signature

class BiEncoderModel(nn.Module):
    def __init__(self, model: None, tokenizer: AutoTokenizer = None,
                 train_batch_size: int = 4, start_layer: int = 8)

    def encode(self, features)
    def forward(self, pair: Union[Dict[str, Tensor], List[Dict[str, Tensor]]])

Import

from research.llm_reranker.finetune_for_layerwise.modeling import BiEncoderModel

I/O Contract

Inputs

Name Type Required Description
model PreTrainedModel Yes Base LLM with layer-wise heads
tokenizer AutoTokenizer Yes Tokenizer for the model
train_batch_size int No Number of queries per batch (default: 4)
start_layer int No First layer with reranking head (default: 8)
pair Dict or List[Dict] Yes Tokenized inputs with input_ids, attention_mask, labels, position_ids

Outputs

Name Type Description
loss Tensor Combined contrastive + distillation loss (training only)
scores List[Tensor] Relevance scores from each layer [layer_8_scores, ..., final_layer_scores]
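Because train_batch_size counts queries rather than query-passage pairs, a flat score vector has to be grouped per query before a contrastive loss can be applied. The sketch below shows one way to do that in plain Python, assuming each query's passages are contiguous in the batch; group_scores is a hypothetical helper, and the module itself may reshape differently.

```python
def group_scores(flat_scores, train_batch_size):
    """Split a flat list of pair scores into one list per query."""
    if len(flat_scores) % train_batch_size != 0:
        raise ValueError("number of pairs must be divisible by train_batch_size")
    group_size = len(flat_scores) // train_batch_size
    return [
        flat_scores[i * group_size:(i + 1) * group_size]
        for i in range(train_batch_size)
    ]

# 32 pair scores = 4 queries x 8 passages each.
flat = [float(i) for i in range(32)]
groups = group_scores(flat, train_batch_size=4)
```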

Usage Examples

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from research.llm_reranker.finetune_for_layerwise.modeling import BiEncoderModel

# Initialize model with layer-wise configuration
base_model = AutoModelForCausalLM.from_pretrained(
    "openbmb/MiniCPM-2B-sft-bf16",
    trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained("openbmb/MiniCPM-2B-sft-bf16")

start_layer = 8
model = BiEncoderModel(
    model=base_model,
    tokenizer=tokenizer,
    train_batch_size=4,
    start_layer=start_layer  # Heads from layer 8 through the final layer
)

# Training
# pair_ids, pair_mask, and labels are placeholders for tensors
# produced by your own tokenization / data collation step.
pair_inputs = {
    "input_ids": pair_ids,        # [32, seq_len] (4 queries × 8 passages)
    "attention_mask": pair_mask,
    "labels": labels
}

outputs = model(pair=pair_inputs)
loss = outputs.loss
# loss = contrastive_loss(layer_8) + ... + contrastive_loss(final_layer)
#      + distill_loss(layer_8←final) + ... + distill_loss(second_to_last←final)
loss.backward()

# Inference with different layers
model.eval()
with torch.no_grad():
    all_scores = model.encode(pair_inputs)
    # all_scores[0]: scores from layer 8 (fastest)
    # all_scores[-1]: scores from the final layer (most accurate)

    # Use layer 16 for balanced speed/accuracy
    layer_16_scores = all_scores[16 - start_layer]  # Index is offset by start_layer
    ranked_indices = torch.argsort(layer_16_scores, descending=True)
