
Implementation:FlagOpen FlagEmbedding MiniCPM Reranker Layerwise

From Leeroopedia


Knowledge Sources
Domains: Machine Learning, Information Retrieval, Reranking
Last Updated: 2026-02-09 00:00 GMT

Overview

Layer-wise MiniCPM reranker model implementation supporting multi-layer output heads for efficient relevance scoring.

Description

This module implements a customized MiniCPM language model architecture for reranking with layer-wise prediction. LayerWiseMiniCPMForCausalLM extends the base MiniCPM architecture with output heads attached at several transformer layers, so a relevance score can be read at intermediate depths as well as at the final layer. This supports training that supervises several layers at once, and early-exit inference in which layers deeper than the chosen cutoff can be skipped to save compute.
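To make the early-exit idea concrete, here is a framework-free toy sketch, not the module's code. It assumes each layer is a plain function on a vector and that a hypothetical per-layer head turns the hidden state into one score; the names layerwise_scores, toy_layers, and toy_heads are illustrative. What it shows is that only the layers up to the deepest requested cutoff are run.

```python
from typing import Callable, Dict, List, Sequence

Vector = List[float]


def layerwise_scores(
    x: Vector,
    layers: Sequence[Callable[[Vector], Vector]],
    heads: Dict[int, Callable[[Vector], float]],
    cutoff_layers: Sequence[int],
) -> Dict[int, float]:
    """Run layers 1..max(cutoff_layers) and apply a head after each cutoff layer.

    Depths are 1-indexed: cutoff 4 means "after the 4th layer". Layers deeper
    than the largest cutoff are never executed, which is the early-exit saving.
    """
    deepest = max(cutoff_layers)
    if deepest > len(layers):
        raise ValueError(f"cutoff {deepest} exceeds the {len(layers)} available layers")
    wanted = set(cutoff_layers)
    scores: Dict[int, float] = {}
    h = x
    for depth, layer in enumerate(layers[:deepest], start=1):
        h = layer(h)  # hidden state after `depth` layers
        if depth in wanted:
            scores[depth] = heads[depth](h)  # hypothetical per-layer head
    return scores


# Usage: six toy "layers" that each add 1.0, and heads that sum the vector.
def add_one(v: Vector) -> Vector:
    return [a + 1.0 for a in v]


toy_layers = [add_one] * 6
toy_heads = {depth: sum for depth in range(1, 7)}
print(layerwise_scores([0.0, 0.0], toy_layers, toy_heads, [2, 4]))  # {2: 4.0, 4: 8.0}
```

Asking only for shallow cutoffs (here 2 and 4) leaves layers 5 and 6 untouched, which is where an early exit saves time.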

Key features include:

  • Layer-wise cutoff for extracting hidden states at specific layers
  • Multiple head types: raw (vocabulary), complex (learned), or reranking (single score)
  • Multi-head support with separate heads per layer
  • Flash Attention 2 and SDPA support for efficient computation
  • Gradient checkpointing for memory-efficient training
  • RoPE (Rotary Position Embedding) with scaling support
  • Compatible with HuggingFace Transformers API

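The model inherits MiniCPM's compact architecture, which multiplies token embeddings by a constant factor (scale_emb) and scales the output of each residual branch by scale_depth / sqrt(num_hidden_layers) before adding it back. Its small size makes it suitable for resource-constrained deployments.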
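Both scaling rules can be written out directly. The sketch below uses plain Python lists; the function names are hypothetical, and the example values scale_emb=12 and scale_depth=1.4 with 40 layers are what we recall from the MiniCPM-2B config, so check them against the checkpoint's config.json.

```python
import math
from typing import List

Vector = List[float]


def scale_embedding(token_embedding: Vector, scale_emb: float) -> Vector:
    """Multiply a looked-up token embedding by the constant scale_emb."""
    return [scale_emb * v for v in token_embedding]


def depth_scaled_residual(
    residual: Vector, sublayer_out: Vector, scale_depth: float, num_hidden_layers: int
) -> Vector:
    """residual + sublayer_out * scale_depth / sqrt(num_hidden_layers)."""
    factor = scale_depth / math.sqrt(num_hidden_layers)
    return [r + factor * s for r, s in zip(residual, sublayer_out)]


# Assumed MiniCPM-2B values (verify in config.json): scale_emb=12, scale_depth=1.4, 40 layers.
print(scale_embedding([1.0, -2.0], 12.0))  # [12.0, -24.0]
print(round(1.4 / math.sqrt(40), 4))       # 0.2214, the per-branch residual factor
```

Shrinking each branch by 1/sqrt(num_hidden_layers) keeps the sum of many residual updates from growing with depth.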

Usage

Use this model for training layer-wise rerankers, implementing early-exit strategies for efficient inference, or fine-tuning MiniCPM-based models for document reranking tasks.

Code Reference

Source Location

Signature

class LayerWiseMiniCPMModel(MiniCPMPreTrainedModel):
    def __init__(self, config: LayerWiseMiniCPMConfig)

    def forward(self, input_ids, attention_mask=None, position_ids=None,
                past_key_values=None, inputs_embeds=None, use_cache=None,
                output_attentions=None, output_hidden_states=None,
                return_dict=None, cutoff_layers=None)

class LayerWiseMiniCPMForCausalLM(MiniCPMPreTrainedModel):
    def __init__(self, config)

    def forward(self, input_ids, attention_mask=None, position_ids=None,
                past_key_values=None, inputs_embeds=None, labels=None,
                use_cache=None, output_attentions=None,
                output_hidden_states=None, return_dict=None,
                cutoff_layers=None, only_for_one_logit=None)

    def chat(self, tokenizer, query: str, history: List[Dict] = None,
             role: str = "user", max_length: int = 4096,
             num_beams=1, do_sample=True, top_p=0.8, temperature=0.3,
             logits_processor=None, **kwargs)

# Attention implementations
class MiniCPMAttention(nn.Module)
class MiniCPMFlashAttention2(MiniCPMAttention)
class MiniCPMSdpaAttention(MiniCPMAttention)

# Layer components
class MiniCPMDecoderLayer(nn.Module)
class MiniCPMMLP(nn.Module)
class MiniCPMRMSNorm(nn.Module)
class MiniCPMRotaryEmbedding(nn.Module)
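Two of the layer components listed above reduce to short formulas. The sketch below is a plain-Python illustration under stated assumptions, not the module's code: rms_norm follows the usual RMSNorm definition (divide by the root mean square, no mean subtraction or bias), and apply_rope uses the Llama-style layout that pairs dimension j with j + d/2. Only linear position scaling is shown; which scaling variants MiniCPMRotaryEmbedding actually supports should be checked in the source.

```python
import math
from typing import List

Vector = List[float]


def rms_norm(x: Vector, weight: Vector, eps: float = 1e-6) -> Vector:
    """weight * x / sqrt(mean(x**2) + eps); no mean subtraction, no bias."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [w * v * inv_rms for w, v in zip(weight, x)]


def apply_rope(
    x: Vector, position: float, base: float = 10000.0, scaling_factor: float = 1.0
) -> Vector:
    """Rotate each pair (x[j], x[j + d/2]) by angle pos * base**(-2j/d).

    scaling_factor > 1 applies linear position interpolation: positions are
    divided by the factor before the angles are computed.
    """
    d = len(x)
    if d % 2:
        raise ValueError("RoPE needs an even dimension")
    half = d // 2
    pos = position / scaling_factor
    out = [0.0] * d
    for j in range(half):
        theta = pos * base ** (-2.0 * j / d)
        c, s = math.cos(theta), math.sin(theta)
        x1, x2 = x[j], x[j + half]
        out[j] = x1 * c - x2 * s
        out[j + half] = x2 * c + x1 * s
    return out


print([round(v, 4) for v in rms_norm([3.0, 4.0], [1.0, 1.0], eps=0.0)])  # [0.8485, 1.1314]
print([round(v, 6) for v in apply_rope([1.0, 0.0], math.pi / 2)])         # [0.0, 1.0]
```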

Import

from modeling_minicpm_reranker import LayerWiseMiniCPMForCausalLM
from transformers import AutoTokenizer

I/O Contract

Inputs

Name | Type | Required | Description
input_ids | torch.LongTensor | Yes | Input token IDs, shape (batch_size, seq_len)
attention_mask | torch.Tensor | No | Padding mask, shape (batch_size, seq_len); 1 for real tokens, 0 for padding
cutoff_layers | Union[int, List[int]] | No | Layer(s) whose hidden states are extracted and passed to a head
labels | torch.LongTensor | No | Target labels for computing the loss
only_for_one_logit | int | No | Select a single logit dimension (used for reranking)
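Because cutoff_layers accepts either an int or a list, callers typically normalize it before use. A minimal sketch of that step: normalize_cutoff_layers is hypothetical, the valid range [start_layer, num_hidden_layers] is an assumption based on one head per layer from start_layer onward, and treating None as "final layer only" is likewise an assumption.

```python
from typing import List, Optional, Sequence, Union


def normalize_cutoff_layers(
    cutoff_layers: Optional[Union[int, Sequence[int]]],
    start_layer: int,
    num_hidden_layers: int,
) -> List[int]:
    """Turn the cutoff_layers argument into a validated list of layer numbers."""
    if cutoff_layers is None:
        return [num_hidden_layers]  # assumption: default to the final layer
    layers = [cutoff_layers] if isinstance(cutoff_layers, int) else list(cutoff_layers)
    if not layers:
        raise ValueError("cutoff_layers must not be empty")
    for layer in layers:
        if not start_layer <= layer <= num_hidden_layers:
            raise ValueError(
                f"cutoff layer {layer} is outside [{start_layer}, {num_hidden_layers}]"
            )
    return layers


print(normalize_cutoff_layers(28, 8, 40))        # [28]
print(normalize_cutoff_layers([20, 40], 20, 40))  # [20, 40]
```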

Outputs

Name | Type | Description
loss | torch.Tensor | Training loss (only when labels are provided)
logits | Tuple[torch.Tensor] | Logits, one tensor per cutoff layer
hidden_states | Tuple[torch.Tensor] | Hidden states from the cutoff layers
past_key_values | Tuple | Cached key-value pairs for generation
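Each logits tensor holds a value for every token position, while a reranker needs one number per query-document pair. A common approach, assumed here rather than taken from the source, is to read the logit at the last real token and, if a bounded score is wanted, pass it through a sigmoid. The sketch uses nested lists in place of a (batch_size, seq_len) tensor; last_token_scores and sigmoid are hypothetical helpers.

```python
import math
from typing import List, Sequence


def last_token_scores(
    layer_logits: Sequence[Sequence[float]],
    attention_mask: Sequence[Sequence[int]],
    padding_side: str = "left",
) -> List[float]:
    """Pick each sequence's logit at its last real (non-padding) token."""
    scores = []
    for row, mask in zip(layer_logits, attention_mask):
        if padding_side == "left":
            idx = len(row) - 1   # real tokens end at the final position
        elif padding_side == "right":
            idx = sum(mask) - 1  # last position whose mask is 1
        else:
            raise ValueError("padding_side must be 'left' or 'right'")
        scores.append(row[idx])
    return scores


def sigmoid(z: float) -> float:
    """Numerically stable logistic function mapping a logit to (0, 1)."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


# One layer's logits for two right-padded sequences (made-up numbers).
logits = [[0.1, 0.2, 0.9], [0.5, 0.7, 0.0]]
mask = [[1, 1, 1], [1, 1, 0]]
print(last_token_scores(logits, mask, padding_side="right"))  # [0.9, 0.7]
print(sigmoid(0.0))  # 0.5
```

With left padding the last position is always a real token, which is why indexing the final position alone can suffice in that setup.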

Usage Examples

from modeling_minicpm_reranker import LayerWiseMiniCPMForCausalLM
from transformers import AutoTokenizer
import torch

# Load model and tokenizer.
# Head settings are read when the model is constructed, so pass them to
# from_pretrained (keyword arguments matching config attributes override
# them) instead of assigning model.config.* after loading, which would not
# rebuild the heads. The base checkpoint has no per-layer score heads, so
# they start untrained until fine-tuned.
model = LayerWiseMiniCPMForCausalLM.from_pretrained(
    "openbmb/MiniCPM-2B-sft-bf16",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    head_type="rerank",  # single score output
    head_multi=True,     # separate head per layer
    start_layer=20,      # first layer with a head (MiniCPM-2B has 40 layers)
)
tokenizer = AutoTokenizer.from_pretrained("openbmb/MiniCPM-2B-sft-bf16")

# Prepare a query-document pair (illustrative prompt template)
query = "What is machine learning?"
document = "Machine learning is a subset of artificial intelligence..."
text = f"Query: {query}\nDocument: {document}\nRelevant:"

inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)

# Get relevance scores from multiple layers (inference only, so no gradients)
model.eval()
with torch.no_grad():
    outputs = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        cutoff_layers=[20, 25, 30, 35, 40],  # each within [start_layer, 40]
        only_for_one_logit=0,  # request a single relevance logit
        return_dict=True,
    )

# outputs.logits is a tuple with one tensor per cutoff layer; the relevance
# score is the logit at the last token position
layer_20_score = outputs.logits[0][:, -1].view(-1).float()
layer_40_score = outputs.logits[-1][:, -1].view(-1).float()
print(f"Layer 20 score: {layer_20_score}")
print(f"Layer 40 score: {layer_40_score}")

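# Training with early-exit supervision at two cutoff layers
model.train()
# Placeholder labels: following the Hugging Face convention, positions set to
# -100 are ignored by the loss, so with every position at -100 nothing is
# actually supervised. Supply real targets in practice.
labels = torch.full_like(inputs["input_ids"], -100)
loss_outputs = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    labels=labels,
    cutoff_layers=[25, 40],
)
total_loss = loss_outputs.loss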

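# Interactive chat mode. chat() decodes generated tokens, so it only produces
# sensible text with a vocabulary-sized head (head_type="raw"), not with the
# single-score heads configured above.
response, history = model.chat(
    tokenizer,
    query="Explain transformers",
    max_length=2048,
    temperature=0.7
)
print(f"Response: {response}")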

# Save fine-tuned model
model.save_pretrained("minicpm-reranker-layerwise")
tokenizer.save_pretrained("minicpm-reranker-layerwise")
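The training snippet above feeds token-level labels to the model. For reranking, fine-tuning is often done instead with a group-wise objective: each query is paired with one positive and several negative documents, and a softmax cross-entropy pushes the positive's score above the rest. The sketch below applies that loss at every supervised cutoff layer and averages the results. It is an assumption about how layer-wise supervision could be set up, not FlagEmbedding's training code; group_cross_entropy, layerwise_group_loss, and the toy scores are illustrative.

```python
import math
from typing import Dict, Sequence


def group_cross_entropy(scores: Sequence[float], positive_index: int = 0) -> float:
    """-log softmax(scores)[positive_index], computed with the log-sum-exp trick."""
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_z - scores[positive_index]


def layerwise_group_loss(scores_per_layer: Dict[int, Sequence[float]]) -> float:
    """Average the group loss over every supervised cutoff layer."""
    if not scores_per_layer:
        raise ValueError("need at least one supervised layer")
    losses = [group_cross_entropy(s) for s in scores_per_layer.values()]
    return sum(losses) / len(losses)


# Made-up scores for one query: positive at index 0, two negatives,
# read at cutoff layers 25 and 40.
toy = {25: [1.0, 0.5, 0.2], 40: [2.0, 0.1, -0.3]}
print(round(layerwise_group_loss(toy), 4))
```

Averaging keeps the loss scale independent of how many layers are supervised; summing instead is an equally reasonable choice that weights deeper supervision sets more heavily.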
