Implementation: FlagOpen FlagEmbedding Matryoshka Mistral Model Inference
| Knowledge Sources | |
|---|---|
| Domains | Information Retrieval, Neural Reranking, Efficient Inference, Matryoshka Learning |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
A custom Mistral-based model with token compression and flexible layer cutoffs for efficient Matryoshka reranker inference.
Description
This implementation provides the same CostWiseMistralForCausalLM architecture used during training but optimized for inference scenarios. The model enables flexible early-exit inference by supporting arbitrary layer cutoffs, allowing users to trade off accuracy for speed. During inference, users can specify which layer's output to use (e.g., layer 8, 16, or 32), with deeper layers generally providing better accuracy at higher computational cost. The token compression mechanism can also be applied during inference to further reduce costs by compressing passage tokens based on attention weights. This architecture is the deployment version of Matryoshka rerankers, enabling cost-adaptive ranking where the computation depth can be adjusted based on query difficulty or latency requirements.
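For intuition, here is a minimal sketch of which cutoff layers a layer-wise config exposes, assuming (as the usage example below suggests) that a score head is available every layer_sep layers starting at start_layer. The helper name is hypothetical and not part of the repository.
from typing import List

# Hypothetical helper, not part of the repository: enumerate the exit
# layers implied by a layer-wise config, assuming a score head exists
# every `layer_sep` layers starting at `start_layer`.
def valid_cutoff_layers(num_layers: int, start_layer: int, layer_sep: int) -> List[int]:
    return list(range(start_layer, num_layers + 1, layer_sep))

# 32-layer Mistral with start_layer=4, layer_sep=4 (as configured below):
print(valid_cutoff_layers(32, 4, 4))  # [4, 8, 12, 16, 20, 24, 28, 32]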
Usage
Use this model for efficient reranking inference where you can dynamically select the computation depth (layer) based on your accuracy-latency tradeoffs. Shallower layers provide faster inference while deeper layers offer better accuracy.
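To quantify the tradeoff on your own data, a minimal timing sketch, assuming the model and the batch tensors from the usage examples below are already prepared on the GPU:
import time
import torch

# Latency comparison sketch: assumes `model`, `input_ids`, `attention_mask`,
# `query_lengths`, and `prompt_lengths` are prepared as in the usage
# examples below, with the model on a CUDA device.
def time_depth(cutoff_layer: int, n_runs: int = 10) -> float:
    torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.no_grad():
        for _ in range(n_runs):
            model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                cutoff_layers=[cutoff_layer],
                query_lengths=query_lengths,
                prompt_lengths=prompt_lengths,
            )
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_runs

print(f"layer 8:  {time_depth(8):.4f}s/batch")
print(f"layer 32: {time_depth(32):.4f}s/batch")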
Code Reference
Source Location
- Repository: FlagOpen_FlagEmbedding
- File: research/Matroyshka_reranker/inference/mistral_model.py
- Lines: 1-706
Signature
class CostWiseMistralForCausalLM(MistralPreTrainedModel):
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        compress_layer: Optional[List[int]] = None,
        compress_ratio: Optional[int] = None,
        cutoff_layers: Optional[List[int]] = None,
        **kwargs
    ) -> CostWiseCausalLMOutputWithPast:
        # Flexible inference with optional early exit at cutoff_layers

def token_compress(compress_ratio, hidden_states, attention_mask,
                   query_lengths, prompt_lengths, weights):
    # Dynamic token compression for efficient inference
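The exact compression rule lives in token_compress above; the following is only an illustrative sketch of the general idea (attention-weighted selection of passage tokens), not the repository's implementation. It assumes the passage sits between a query prefix and a trailing prompt, keeps those spans intact, and retains the top 1/compress_ratio passage tokens by weight.
import torch

# Illustrative sketch only, not the repository's token_compress: compress
# the passage span of one sequence by keeping the top-1/compress_ratio
# tokens ranked by an attention-derived weight, leaving the query prefix
# and trailing prompt untouched.
def compress_passage(hidden_states: torch.Tensor,  # (seq_len, hidden)
                     weights: torch.Tensor,        # (seq_len,) attention-based scores
                     query_len: int,
                     prompt_len: int,
                     compress_ratio: int) -> torch.Tensor:
    seq_len = hidden_states.size(0)
    # Passage tokens sit between the query prefix and the trailing prompt.
    passage = slice(query_len, seq_len - prompt_len)
    p_hidden, p_weights = hidden_states[passage], weights[passage]
    keep = max(1, p_hidden.size(0) // compress_ratio)
    # Keep the highest-weight passage tokens, preserving their order.
    idx = p_weights.topk(keep).indices.sort().values
    return torch.cat([hidden_states[:query_len],
                      p_hidden[idx],
                      hidden_states[seq_len - prompt_len:]], dim=0)

# Example: 64-token sequence, 8-token query, 4-token prompt, 4x compression:
h = torch.randn(64, 4096)
w = torch.rand(64)
print(compress_passage(h, w, 8, 4, 4).shape)  # torch.Size([25, 4096])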
Import
from typing import List, Optional

import torch
from torch import nn
from transformers.models.mistral.modeling_mistral import (
    MistralPreTrainedModel, MistralDecoderLayer
)
from mistral_config import CostWiseMistralConfig
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.LongTensor | Yes | Tokenized query-passage pairs |
| attention_mask | torch.Tensor | No | Attention mask for padding |
| compress_layer | List[int] | No | Layers at which to compress tokens (e.g., [8]) |
| compress_ratio | int | No | Compression ratio: 1 (no compression), 2, 4, or 8 |
| cutoff_layers | List[int] | No | Layer(s) from which to extract scores for ranking |
| query_lengths | torch.Tensor | No | Length of query portion in each sample |
| prompt_lengths | torch.Tensor | No | Length of prompt portion in each sample |
Outputs
| Name | Type | Description |
|---|---|---|
| logits | torch.FloatTensor or Tuple | Relevance scores from specified layer(s) |
| hidden_states | Tuple[torch.FloatTensor] | Hidden representations at cutoff layers |
| attention_masks | Tuple[torch.FloatTensor] | Final attention masks after compression |
Usage Examples
import torch
from mistral_model import CostWiseMistralForCausalLM
from mistral_config import CostWiseMistralConfig

# Load model for inference
config = CostWiseMistralConfig.from_pretrained(model_path)
config.layer_wise = True
config.start_layer = 4
config.layer_sep = 4
model = CostWiseMistralForCausalLM.from_pretrained(
    model_path,
    config=config,
    torch_dtype=torch.bfloat16
)
model.eval()
model.to('cuda')
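# Prepare a batch (illustrative sketch: the prompt template and the
# length bookkeeping below are assumptions, not the repository's
# preprocessing pipeline)
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_path)
query = "what is a matryoshka reranker?"
passage = "Matryoshka rerankers support early exit at intermediate layers ..."
prompt = "Predict whether the passage answers the query."
enc = tokenizer(f"{query} {passage} {prompt}", return_tensors='pt').to('cuda')
input_ids, attention_mask = enc['input_ids'], enc['attention_mask']
# Token lengths of the query prefix and trailing prompt, approximated
# here by tokenizing each part on its own (special tokens excluded)
query_lengths = torch.tensor(
    [len(tokenizer(query, add_special_tokens=False)['input_ids'])], device='cuda')
prompt_lengths = torch.tensor(
    [len(tokenizer(prompt, add_special_tokens=False)['input_ids'])], device='cuda')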
# Inference at different depths
with torch.no_grad():
    # Fast inference with layer 8
    outputs_fast = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        cutoff_layers=[8],
        compress_layer=[8],
        compress_ratio=4,
        query_lengths=query_lengths,
        prompt_lengths=prompt_lengths
    )
    scores_fast = outputs_fast.logits[0]

    # Accurate inference with layer 32
    outputs_accurate = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        cutoff_layers=[32],
        compress_layer=[8],
        compress_ratio=4,
        query_lengths=query_lengths,
        prompt_lengths=prompt_lengths
    )
    scores_accurate = outputs_accurate.logits[0]

    # Multi-layer inference for adaptive routing
    outputs_multi = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        cutoff_layers=[8, 16, 24, 32],
        compress_layer=[8],
        compress_ratio=4,
        query_lengths=query_lengths,
        prompt_lengths=prompt_lengths
    )
    # Select layer based on confidence or query complexity
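    # One possible routing rule (illustrative sketch, not from the
    # repository): keep the shallow layer-8 score when its sigmoid
    # confidence is decisive, otherwise fall back to the deepest layer.
    # The logits tuple is assumed to follow cutoff_layers order.
    shallow = outputs_multi.logits[0]    # layer 8 scores
    deep = outputs_multi.logits[-1]      # layer 32 scores
    confidence = torch.sigmoid(shallow.float())
    decisive = (confidence - 0.5).abs() > 0.4   # assumed threshold
    final_scores = torch.where(decisive, shallow, deep)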