Implementation:FlagOpen FlagEmbedding Matryoshka Mistral Model Self Distillation
| Knowledge Sources | |
|---|---|
| Domains | Information Retrieval, Neural Reranking, Knowledge Distillation, Matryoshka Learning |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
A custom Mistral-based model with token compression and layer-wise outputs for self-distillation training in Matryoshka rerankers.
Description
This implementation is architecturally identical to the compensation model but is used during the self-distillation phase of Matryoshka reranker training. The model combines Mistral's transformer architecture with cost-aware token compression and multiple output heads at different layers. During self-distillation, the deepest layer (teacher) supervises the shallower layers (students), enabling earlier layers to produce high-quality rankings without processing all transformer layers. The token compression mechanism uses attention weights to reduce passage length while keeping the query and prompt tokens intact. This lets the model maintain accuracy while significantly reducing inference cost through early-exit strategies.
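The compression step described above can be sketched as follows. This is a simplified illustration, not the repository's implementation: `compress_passage_tokens` is a hypothetical helper that keeps query and prompt tokens untouched and retains only the top-scoring fraction of passage tokens, ranked by a per-token attention score.

```python
import torch

def compress_passage_tokens(hidden, attn_scores, query_len, prompt_len, ratio):
    """Illustrative token compression (hypothetical, not FlagEmbedding's code):
    keep query and prompt tokens, and the top 1/ratio passage tokens ranked
    by attention score. `hidden` is (seq_len, dim); `attn_scores` is (seq_len,)."""
    seq_len = hidden.size(0)
    passage_start = query_len
    passage_end = seq_len - prompt_len
    passage_len = passage_end - passage_start
    keep = max(1, passage_len // ratio)
    # Rank passage tokens by attention mass, keep the strongest ones,
    # and restore their original order so positional structure survives.
    scores = attn_scores[passage_start:passage_end]
    top_idx = scores.topk(keep).indices.sort().values + passage_start
    kept = torch.cat([
        torch.arange(passage_start),          # query tokens, untouched
        top_idx,                              # compressed passage
        torch.arange(passage_end, seq_len),   # prompt tokens, untouched
    ])
    return hidden[kept], kept

hidden = torch.randn(20, 8)
attn = torch.rand(20)
out, kept = compress_passage_tokens(hidden, attn, query_len=4, prompt_len=2, ratio=4)
print(out.shape)  # torch.Size([9, 8]): 4 query + 14 // 4 passage + 2 prompt tokens
```

With `compress_ratio=4` the 14 passage tokens shrink to 3, while all 4 query tokens and both prompt tokens survive, mirroring the "preserve query and prompt information" behavior described above.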
Usage
Use this model architecture during the self-distillation fine-tuning phase of Matryoshka rerankers, where deeper layers teach shallower layers to produce accurate rankings, enabling efficient early-exit inference.
Code Reference
Source Location
- Repository: FlagOpen_FlagEmbedding
- File: research/Matroyshka_reranker/finetune/self_distillation/mistral_model.py
- Lines: 1-706
Signature
class CostWiseMistralModel(MistralPreTrainedModel):
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        compress_layer: Optional[List[int]] = None,
        compress_ratio: Optional[int] = None,
        cutoff_layers: Optional[List[int]] = None,
        **kwargs
    ) -> CostWiseModelOutputWithPast:
        # Forward with layer-wise outputs for distillation
        ...

class CostWiseMistralForCausalLM(MistralPreTrainedModel):
    def __init__(self, config):
        # One scoring head per extraction layer for multi-layer distillation
        self.lm_head = nn.ModuleList([
            CostWiseHead(config.hidden_size, 1)
            for _ in range(config.start_layer, config.num_hidden_layers + 1, config.layer_sep)
        ])

@dataclass
class CostWiseCausalLMOutputWithPast(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None  # Tuple of logits from multiple layers
    attention_masks: Optional[Tuple[torch.FloatTensor]] = None
Import
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from dataclasses import dataclass
from typing import List, Optional, Tuple
from transformers.modeling_outputs import ModelOutput
from transformers.models.mistral.modeling_mistral import (
    MistralPreTrainedModel, MistralDecoderLayer, MistralRMSNorm
)
from mistral_config import CostWiseMistralConfig
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.LongTensor | Yes | Tokenized input sequences |
| attention_mask | torch.Tensor | No | Mask for padding tokens |
| compress_layer | List[int] | No | Layer indices where compression should be applied |
| compress_ratio | int | No | Token compression ratio (e.g., 2, 4, 8) |
| cutoff_layers | List[int] | No | Layers to extract outputs for distillation |
| query_lengths | torch.Tensor | No | Number of query tokens per sample |
| prompt_lengths | torch.Tensor | No | Number of prompt tokens per sample |
Outputs
| Name | Type | Description |
|---|---|---|
| logits | Tuple[torch.FloatTensor] | Scores from each cutoff layer (student and teacher) |
| hidden_states | Tuple[torch.FloatTensor] | Normalized hidden states at cutoff layers |
| attention_masks | Tuple[torch.FloatTensor] | Attention masks after any compression |
| loss | torch.FloatTensor | Optional training loss (not used in layer-wise mode) |
Usage Examples
import torch
import torch.nn.functional as F
from mistral_model import CostWiseMistralForCausalLM
from mistral_config import CostWiseMistralConfig
# Configure for self-distillation
config = CostWiseMistralConfig.from_pretrained(model_path)
config.layer_wise = True
config.start_layer = 8 # Start extracting from layer 8
config.layer_sep = 4 # Extract every 4 layers
config.num_hidden_layers = 32
model = CostWiseMistralForCausalLM.from_pretrained(
model_path,
config=config,
torch_dtype=torch.bfloat16
)
# Forward pass for distillation training
cutoff_layers = [8, 12, 16, 20, 24, 28, 32] # Teacher is layer 32
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
compress_layer=[8],
compress_ratio=4,
cutoff_layers=cutoff_layers,
query_lengths=query_lengths,
prompt_lengths=prompt_lengths
)
# Teacher-student distillation
teacher_logits = outputs.logits[-1] # Deepest layer
student_logits = outputs.logits[:-1] # All earlier layers
# Compute distillation loss (MSE or KL divergence)
distillation_loss = 0
for student in student_logits:
distillation_loss += F.mse_loss(student, teacher_logits.detach())
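The loop above uses the MSE variant; the KL-divergence alternative it mentions treats each layer's scores over the candidate list as a distribution. The sketch below is a common listwise formulation, not necessarily the exact loss used in the repository; the temperature parameter is an assumption of this illustration.

```python
import torch
import torch.nn.functional as F

def listwise_kl(student_logits, teacher_logits, temperature=1.0):
    """KL divergence between softmax distributions over the candidate list,
    the standard listwise-distillation alternative to MSE. Shapes:
    (batch, num_candidates). Illustrative, not the repository's exact loss."""
    t = temperature
    log_p_student = F.log_softmax(student_logits / t, dim=-1)
    p_teacher = F.softmax(teacher_logits.detach() / t, dim=-1)  # teacher frozen
    # Scale by t^2 so gradient magnitudes stay comparable across temperatures.
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * (t * t)

student = torch.tensor([[2.0, 1.0, 0.5]])
teacher = torch.tensor([[3.0, 1.0, 0.2]])
loss = listwise_kl(student, teacher)
print(loss.item() >= 0)  # True: KL divergence is non-negative
```

Unlike per-score MSE, the KL loss is invariant to a constant shift in a layer's scores, so each student head only has to reproduce the teacher's relative ordering and margins rather than its absolute score scale.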