Implementation:FlagOpen FlagEmbedding MiniCPM Reranker Layerwise
| Knowledge Sources | |
|---|---|
| Domains | Machine Learning, Information Retrieval, Reranking |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Layer-wise MiniCPM reranker model implementation supporting multi-layer output heads for efficient relevance scoring.
Description
This module implements a customized MiniCPM language model architecture for reranking with layer-wise prediction. LayerWiseMiniCPMForCausalLM extends the base MiniCPM architecture with output heads attached at several transformer layers, so the model can produce predictions at different depths. This supports training that supervises intermediate layers as well as early-exit inference, where the score from a shallower layer is used to save compute.
Key features include:
- Layer-wise cutoff for extracting hidden states at specific layers
- Multiple head types: raw (vocabulary), complex (learned), or reranking (single score)
- Multi-head support with separate heads per layer
- Flash Attention 2 and SDPA support for efficient computation
- Gradient checkpointing for memory-efficient training
- RoPE (Rotary Position Embedding) with scaling support
- Compatible with HuggingFace Transformers API
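To make the multi-head feature more concrete, the sketch below shows one way per-layer heads could be indexed. It assumes one head per layer from `start_layer` through `num_hidden_layers` inclusive, addressed by `cutoff_layer - start_layer`. The helper names `build_head_layers` and `head_index` are hypothetical and are not part of the module.

```python
from typing import List


def build_head_layers(start_layer: int, num_hidden_layers: int) -> List[int]:
    """Layers that own a head, assuming one head per layer in
    [start_layer, num_hidden_layers] (an assumption of this sketch)."""
    if not 1 <= start_layer <= num_hidden_layers:
        raise ValueError("start_layer must lie in [1, num_hidden_layers]")
    return list(range(start_layer, num_hidden_layers + 1))


def head_index(cutoff_layer: int, start_layer: int, num_hidden_layers: int) -> int:
    """Position of the head that serves `cutoff_layer` under the same assumption."""
    if not start_layer <= cutoff_layer <= num_hidden_layers:
        raise ValueError(f"layer {cutoff_layer} has no head")
    return cutoff_layer - start_layer


# Illustrative values only: heads on layers 20..40 of a 40-layer model.
head_layers = build_head_layers(start_layer=20, num_hidden_layers=40)
indices = [head_index(layer, 20, 40) for layer in (20, 25, 40)]
```

Under these assumptions, requesting a cutoff layer below `start_layer` has no matching head, which is why the sketch rejects it explicitly.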
The model inherits MiniCPM's architecture, including scaled input embeddings and depth-scaled residual connections. The small base model makes it suitable for resource-constrained deployments.
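The two scaling ideas mentioned here can be sketched without any framework. The snippet assumes the usual MiniCPM formulation: embeddings are multiplied by a constant `scale_emb`, and each residual branch is scaled by `scale_depth / sqrt(num_hidden_layers)`. The function names and the numeric values are illustrative assumptions, not values read from this module.

```python
import math
from typing import List


def scale_embedding(embedding: List[float], scale_emb: float) -> List[float]:
    """Multiply a token embedding by a constant factor."""
    return [scale_emb * x for x in embedding]


def residual_update(
    residual: List[float],
    branch_output: List[float],
    scale_depth: float,
    num_hidden_layers: int,
) -> List[float]:
    """Add a sub-layer output to the residual stream, scaled by
    scale_depth / sqrt(num_hidden_layers)."""
    factor = scale_depth / math.sqrt(num_hidden_layers)
    return [r + factor * b for r, b in zip(residual, branch_output)]


# Illustrative numbers chosen so the arithmetic is easy to follow.
scaled = scale_embedding([0.5, -1.0], scale_emb=12.0)
updated = residual_update([1.0, 2.0], [1.0, 1.0], scale_depth=2.0, num_hidden_layers=4)
```

With `scale_depth=2.0` and four layers the factor is exactly 1, so the update reduces to a plain residual addition; deeper models shrink each branch's contribution.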
Usage
Use this model for training layer-wise rerankers, implementing early-exit strategies for efficient inference, or fine-tuning MiniCPM-based models for document reranking tasks.
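An early-exit strategy can take many forms; the following is a minimal sketch of one possible policy, not the policy used by the repository. It assumes each exit layer yields a scalar relevance logit and stops at the first layer whose sigmoid probability is far enough from 0.5. The names `early_exit` and `margin` are hypothetical, and the scores are precomputed here for simplicity, whereas a real early exit would stop running deeper layers.

```python
import math
from typing import Dict, Tuple


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def early_exit(scores_by_layer: Dict[int, float], margin: float = 0.4) -> Tuple[int, float]:
    """Return (layer, probability) for the shallowest layer whose probability
    differs from 0.5 by at least `margin`; fall back to the deepest layer."""
    if not scores_by_layer:
        raise ValueError("at least one layer score is required")
    layers = sorted(scores_by_layer)
    for layer in layers:
        prob = sigmoid(scores_by_layer[layer])
        if abs(prob - 0.5) >= margin:
            return layer, prob
    deepest = layers[-1]
    return deepest, sigmoid(scores_by_layer[deepest])


# Hypothetical logits from three exit layers.
exit_layer, exit_prob = early_exit({20: 0.3, 25: 2.5, 30: 3.0})
fallback_layer, fallback_prob = early_exit({20: 0.1, 40: -0.2})
```

The margin trades accuracy for compute: a larger margin exits less often at shallow layers.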
Code Reference
Source Location
- Repository: FlagOpen_FlagEmbedding
- File: research/llm_reranker/finetune_for_layerwise/modeling_minicpm_reranker.py
Signature
class LayerWiseMiniCPMModel(MiniCPMPreTrainedModel):
    def __init__(self, config: LayerWiseMiniCPMConfig)
    def forward(self, input_ids, attention_mask=None, position_ids=None,
                past_key_values=None, inputs_embeds=None, use_cache=None,
                output_attentions=None, output_hidden_states=None,
                return_dict=None, cutoff_layers=None)

class LayerWiseMiniCPMForCausalLM(MiniCPMPreTrainedModel):
    def __init__(self, config)
    def forward(self, input_ids, attention_mask=None, position_ids=None,
                past_key_values=None, inputs_embeds=None, labels=None,
                use_cache=None, output_attentions=None,
                output_hidden_states=None, return_dict=None,
                cutoff_layers=None, only_for_one_logit=None)
    def chat(self, tokenizer, query: str, history: List[Dict] = None,
             role: str = "user", max_length: int = 4096,
             num_beams=1, do_sample=True, top_p=0.8, temperature=0.3,
             logits_processor=None, **kwargs)

# Attention implementations
class MiniCPMAttention(nn.Module)
class MiniCPMFlashAttention2(MiniCPMAttention)
class MiniCPMSdpaAttention(MiniCPMAttention)

# Layer components
class MiniCPMDecoderLayer(nn.Module)
class MiniCPMMLP(nn.Module)
class MiniCPMRMSNorm(nn.Module)
class MiniCPMRotaryEmbedding(nn.Module)
Import
from modeling_minicpm_reranker import LayerWiseMiniCPMForCausalLM
from transformers import AutoTokenizer
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.LongTensor | Yes | Input token IDs (batch_size, seq_len) |
| attention_mask | torch.Tensor | No | Attention mask for padding |
| cutoff_layers | Union[int, List[int]] | No | Layer index, or list of layer indices, whose hidden states are extracted and scored |
| labels | torch.LongTensor | No | Labels for computing loss |
| only_for_one_logit | int | No | Extract single logit dimension (for reranking) |
Outputs
| Name | Type | Description |
|---|---|---|
| loss | torch.Tensor | Training loss (if labels provided) |
| logits | Tuple[torch.Tensor] | Logits from each cutoff layer |
| hidden_states | Tuple[torch.Tensor] | Hidden states from cutoff layers |
| past_key_values | Tuple | Cached key-value pairs for generation |
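Because `cutoff_layers` accepts either a single integer or a list, and `logits` comes back as a tuple with one entry per cutoff layer, it can help to pair the two explicitly. The sketch below assumes the tuple follows the order of the requested layers. `normalize_cutoff_layers` and `logits_by_layer` are hypothetical helpers, and strings stand in for tensors.

```python
from typing import Any, Dict, List, Sequence, Union


def normalize_cutoff_layers(cutoff_layers: Union[int, Sequence[int]]) -> List[int]:
    """Accept a single layer index or a sequence and always return a list."""
    if isinstance(cutoff_layers, int):
        return [cutoff_layers]
    return list(cutoff_layers)


def logits_by_layer(
    cutoff_layers: Union[int, Sequence[int]], logits: Sequence[Any]
) -> Dict[int, Any]:
    """Map each requested layer to its logits, assuming matching order."""
    layers = normalize_cutoff_layers(cutoff_layers)
    if len(layers) != len(logits):
        raise ValueError("expected one logits entry per cutoff layer")
    return dict(zip(layers, logits))


# Placeholder values stand in for the tensors returned by the model.
paired = logits_by_layer([20, 25, 40], ("s20", "s25", "s40"))
single = normalize_cutoff_layers(30)
```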
Usage Examples
from modeling_minicpm_reranker import LayerWiseMiniCPMForCausalLM
from transformers import AutoTokenizer
import torch
# Load model and tokenizer.
# The output heads are built when the model is constructed, so the head settings
# are passed at load time (this relies on the config class defining these fields);
# changing model.config after loading does not rebuild the heads.
# Heads missing from the checkpoint are newly initialized and need fine-tuning.
model = LayerWiseMiniCPMForCausalLM.from_pretrained(
    "openbmb/MiniCPM-2B-sft-bf16",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    head_type="rerank",     # Single score output
    head_multi=True,        # Separate heads per layer
    start_layer=20,         # Start from layer 20
    num_hidden_layers=40,
)
tokenizer = AutoTokenizer.from_pretrained("openbmb/MiniCPM-2B-sft-bf16")
# Prepare query-document pair
query = "What is machine learning?"
document = "Machine learning is a subset of artificial intelligence..."
text = f"Query: {query}\nDocument: {document}\nRelevant:"
inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
# Get relevance scores from multiple layers
outputs = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    cutoff_layers=[20, 25, 30, 35, 40],  # Extract from multiple layers
    only_for_one_logit=0,                # Get single relevance score
)
# outputs.logits is a tuple of tensors, one per cutoff layer
layer_20_score = outputs.logits[0]
layer_40_score = outputs.logits[-1]
print(f"Layer 20 score: {layer_20_score}")
print(f"Layer 40 score: {layer_40_score}")
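# Training with supervision at two exit layers.
# Positions labeled -100 are excluded from the loss. Masking every position
# leaves nothing to supervise, so the final token is kept as the target here;
# in real training it would be the relevance label token.
labels = torch.full_like(inputs["input_ids"], -100)
labels[:, -1] = inputs["input_ids"][:, -1]
loss_outputs = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    labels=labels,
    cutoff_layers=[25, 40],
)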
total_loss = loss_outputs.loss
# Interactive chat mode
response, history = model.chat(
    tokenizer,
    query="Explain transformers",
    max_length=2048,
    temperature=0.7,
)
print(f"Response: {response}")
# Save fine-tuned model
model.save_pretrained("minicpm-reranker-layerwise")
tokenizer.save_pretrained("minicpm-reranker-layerwise")
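In a reranking pipeline the per-document scores from a chosen layer are typically used to sort candidates. The sketch below assumes the scores have already been converted to plain floats, one per document; `rank_documents` is a hypothetical helper and the scores are made-up values, not model output.

```python
from typing import List, Sequence, Tuple


def rank_documents(
    documents: Sequence[str], scores: Sequence[float]
) -> List[Tuple[str, float]]:
    """Sort documents by descending relevance score."""
    if len(documents) != len(scores):
        raise ValueError("expected one score per document")
    return sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)


# Made-up scores for three candidate documents.
documents = ["doc A", "doc B", "doc C"]
scores = [0.2, 1.7, -0.5]
ranked = rank_documents(documents, scores)
top_document = ranked[0][0]
```

Scores taken from different exit layers are not necessarily on the same scale, so it is safer to rank all candidates with scores from the same layer.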