Implementation: FlagOpen FlagEmbedding Matryoshka Mistral Model Compensation
| Knowledge Sources | |
|---|---|
| Domains | Information Retrieval, Neural Reranking, Token Compression, Matryoshka Learning |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
A custom Mistral-based model implementation with token compression and layer-wise output capabilities for Matryoshka reranker compensation training.
Description
This implementation extends the standard Mistral transformer architecture with cost-aware token compression and multi-layer output heads for Matryoshka reranking. The model introduces a token_compress function that dynamically shortens the sequence based on attention weights, preserving query and prompt tokens while compressing passage content. It also supports layer-wise training, in which classification heads can be attached at multiple depths (controlled by start_layer and layer_sep), enabling the model to produce rankings from intermediate layers. This architecture is designed for the compensation phase of Matryoshka reranker fine-tuning, where the earlier layers are additionally fine-tuned (compensated) so that ranking quality holds up when inference exits at a shallower layer.
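The source file is not reproduced here, so as an illustration only, the following is a minimal sketch of attention-weight-based token compression for a single sequence. It assumes the layout query tokens, then passage tokens, then prompt tokens, and keeps the passage tokens that receive the most attention; the function name and exact mechanics are hypothetical, not the library's `token_compress` implementation.

```python
import torch

def token_compress_sketch(hidden_states, attention_mask, query_len, prompt_len,
                          weights, compress_ratio):
    """Hypothetical sketch: preserve query and prompt tokens, keep only the
    top passage tokens ranked by attention weight, shrinking the passage
    span by roughly compress_ratio."""
    # hidden_states: (seq, dim); weights: (seq,) attention received per token
    seq_len = hidden_states.size(0)
    passage_len = seq_len - query_len - prompt_len
    keep_n = max(passage_len // compress_ratio, 1)
    # assumed layout: [query | passage | prompt]
    passage_w = weights[query_len:seq_len - prompt_len]
    # highest-attention passage positions, restored to original order
    top = torch.topk(passage_w, keep_n).indices.sort().values + query_len
    keep = torch.cat([
        torch.arange(query_len),                       # query preserved
        top,                                           # compressed passage
        torch.arange(seq_len - prompt_len, seq_len),   # prompt preserved
    ])
    new_mask = attention_mask[..., keep] if attention_mask is not None else None
    return hidden_states[keep], new_mask
```

With `compress_ratio=4`, a 480-token passage shrinks to 120 tokens while the query and prompt spans pass through untouched.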
Usage
Use this model architecture during the compensation fine-tuning phase of Matryoshka rerankers, where you need to train multiple output layers simultaneously while applying dynamic token compression to reduce computational costs.
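The source does not show the compensation objective itself. As a hypothetical sketch only, one way to train multiple output layers simultaneously is to combine a per-layer supervised loss with distillation from the deepest head, so shallower exits are pulled toward the full-depth ranking; the function below is an assumption, not the repository's loss.

```python
import torch
import torch.nn.functional as F

def layerwise_compensation_loss(layer_logits, labels):
    """Hypothetical sketch: average per-layer cross-entropy plus KL
    distillation from the deepest (last) layer to every layer."""
    teacher = layer_logits[-1].detach()  # deepest head as fixed teacher
    loss = 0.0
    for logits in layer_logits:
        loss = loss + F.cross_entropy(logits, labels)
        loss = loss + F.kl_div(F.log_softmax(logits, dim=-1),
                               F.softmax(teacher, dim=-1),
                               reduction="batchmean")
    return loss / len(layer_logits)
```

In this sketch `layer_logits` would be the per-cutoff-layer scores returned by the model's forward pass, ordered shallow to deep.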
Code Reference
Source Location
- Repository: FlagOpen_FlagEmbedding
- File: research/Matroyshka_reranker/finetune/compensation/mistral_model.py
- Lines: 1-706
Signature
class CostWiseMistralModel(MistralPreTrainedModel):
    def __init__(self, config: CostWiseMistralConfig):
        # Initialize Mistral model with custom config

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        compress_layer: Optional[List[int]] = None,
        compress_ratio: Optional[int] = None,
        cutoff_layers: Optional[List[int]] = None,
        query_lengths: Optional[torch.Tensor] = None,
        prompt_lengths: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Union[Tuple, CostWiseModelOutputWithPast]:
        # Forward pass with optional token compression and layer cutoffs

class CostWiseMistralForCausalLM(MistralPreTrainedModel):
    def __init__(self, config):
        # Initialize with layer-wise classification heads

def token_compress(compress_ratio, hidden_states, attention_mask,
                   query_lengths, prompt_lengths,
                   weights: torch.Tensor = None):
    # Compress passage tokens using attention weights
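Given the config fields start_layer and layer_sep, the set of layers that receive classification heads can be sketched as a simple arithmetic range. This helper is illustrative, assuming heads run from start_layer to the final layer at a stride of layer_sep; it is not a function from the source file.

```python
def head_layers(num_hidden_layers, start_layer, layer_sep):
    """Hypothetical sketch: indices of layers carrying a classification
    head, from start_layer up to the last layer, every layer_sep layers."""
    return list(range(start_layer, num_hidden_layers + 1, layer_sep))
```

For a 32-layer model with start_layer=8 and layer_sep=4 this yields heads at layers 8, 12, 16, 20, 24, 28, and 32; cutoff_layers at inference time would then select a subset of these.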
Import
import torch
from torch import nn
from typing import List, Optional, Tuple, Union
from transformers.models.mistral.modeling_mistral import (
    MistralRMSNorm, MistralDecoderLayer, MistralPreTrainedModel
)
from mistral_config import CostWiseMistralConfig
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.LongTensor | Yes | Input token IDs |
| attention_mask | torch.Tensor | No | Attention mask for valid tokens |
| compress_layer | List[int] | No | Layers at which to apply token compression |
| compress_ratio | int | No | Compression ratio (1, 2, 4, or 8) |
| cutoff_layers | List[int] | No | Layers at which to extract hidden states for output |
| query_lengths | torch.Tensor | No | Length of query tokens per sample |
| prompt_lengths | torch.Tensor | No | Length of prompt tokens per sample |
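Assuming compression preserves the query and prompt tokens and shrinks only the passage span (as the description above states), the post-compression sequence length can be estimated with simple arithmetic; this helper is illustrative, not part of the source file.

```python
def compressed_length(query_len, prompt_len, passage_len, ratio):
    # query and prompt tokens pass through; only the passage span shrinks
    return query_len + prompt_len + passage_len // ratio
```

For example, a 512-token input with a 16-token query, 16-token prompt, and 480-token passage compresses to 16 + 16 + 480 // 4 = 152 tokens at ratio 4.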
Outputs
| Name | Type | Description |
|---|---|---|
| logits | torch.FloatTensor or Tuple | Classification scores from each cutoff layer |
| hidden_states | Tuple[torch.FloatTensor] | Hidden states from specified layers |
| attention_masks | Tuple[torch.FloatTensor] | Attention masks after compression |
| past_key_values | Optional[Tuple] | Cached key-value pairs for generation |
Usage Examples
import torch
from mistral_model import CostWiseMistralForCausalLM
from mistral_config import CostWiseMistralConfig
# Initialize config with layer-wise settings
config = CostWiseMistralConfig.from_pretrained(model_path)
config.layer_wise = True
config.start_layer = 4
config.layer_sep = 1
# Load model
model = CostWiseMistralForCausalLM.from_pretrained(
model_path,
config=config,
torch_dtype=torch.bfloat16
)
# Forward pass with compression at layer 8, ratio 4
cutoff_layers = [4, 8, 12, 16, 20, 24, 28, 32]
outputs = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    compress_layer=[8],
    compress_ratio=4,
    cutoff_layers=cutoff_layers,
    query_lengths=query_lengths,
    prompt_lengths=prompt_lengths,
    output_hidden_states=True
)
# Extract scores from each cutoff layer
for i, layer_logits in enumerate(outputs.logits):
    scores = layer_logits.cpu().float()
    print(f"Layer {cutoff_layers[i]} scores: {scores}")