
Implementation:FlagOpen FlagEmbedding Matryoshka Mistral Model Self Distillation

From Leeroopedia


Knowledge Sources
Domains Information Retrieval, Neural Reranking, Knowledge Distillation, Matryoshka Learning
Last Updated 2026-02-09 00:00 GMT

Overview

A custom Mistral-based model with token compression and layer-wise outputs for self-distillation training in Matryoshka rerankers.

Description

This implementation is identical in architecture to the compensation model but is used during the self-distillation phase of Matryoshka reranker training. The model combines Mistral's transformer architecture with cost-aware token compression and multiple output heads at different layers. During self-distillation, the deepest layer (teacher) supervises shallower layers (students), enabling earlier layers to produce high-quality rankings without processing all transformer layers. The token compression mechanism uses attention weights to reduce passage length while preserving query and prompt tokens. This allows the model to maintain accuracy while significantly reducing inference cost through early-exit strategies.
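The compression step described above can be sketched as follows. This is a hedged illustration, not the FlagEmbedding implementation: the function name `compress_tokens`, its arguments, and the top-k selection rule are assumptions, but they capture the idea of ranking passage tokens by attention weight while leaving the query prefix and prompt suffix untouched.

```python
# Illustrative sketch only (names and selection rule are assumptions):
# drop low-attention passage tokens, keep query and prompt tokens intact.
def compress_tokens(tokens, attn, query_len, prompt_len, ratio):
    # Passage tokens sit between the query prefix and the prompt suffix.
    passage = list(range(query_len, len(tokens) - prompt_len))
    keep_n = max(1, len(passage) // ratio)
    # Keep the passage positions with the highest attention weight,
    # restoring their original order afterwards.
    kept = sorted(sorted(passage, key=lambda i: attn[i], reverse=True)[:keep_n])
    prompt_start = len(tokens) - prompt_len
    idx = list(range(query_len)) + kept + list(range(prompt_start, len(tokens)))
    return [tokens[i] for i in idx]

tokens = ["q1", "q2", "p1", "p2", "p3", "p4", "r1", "r2"]
attn = [0.0, 0.0, 0.1, 0.9, 0.5, 0.2, 0.0, 0.0]
print(compress_tokens(tokens, attn, query_len=2, prompt_len=2, ratio=2))
# ['q1', 'q2', 'p2', 'p3', 'r1', 'r2']
```

With `ratio=2`, the four passage tokens are halved: the two with the highest attention weights survive, and the query and prompt tokens pass through unchanged.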

Usage

Use this model architecture during the self-distillation fine-tuning phase of Matryoshka rerankers, where deeper layers teach shallower layers to produce accurate rankings, enabling efficient early-exit inference.

Code Reference

Source Location

Signature

class CostWiseMistralModel(MistralPreTrainedModel):
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        compress_layer: Optional[List[int]] = None,
        compress_ratio: Optional[int] = None,
        cutoff_layers: Optional[List[int]] = None,
        **kwargs
    ) -> CostWiseModelOutputWithPast:
        # Forward with layer-wise outputs for distillation

class CostWiseMistralForCausalLM(MistralPreTrainedModel):
    def __init__(self, config):
        # Multiple heads for multi-layer distillation
        self.lm_head = nn.ModuleList([
            CostWiseHead(config.hidden_size, 1)
            for _ in range(config.start_layer, config.num_hidden_layers + 1, config.layer_sep)
        ])

@dataclass
class CostWiseCausalLMOutputWithPast(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None  # Tuple of logits from multiple layers
    attention_masks: Optional[Tuple[torch.FloatTensor]] = None
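Given the constructor above, the number of heads and the layers they attach to follow directly from `config.start_layer`, `config.layer_sep`, and `config.num_hidden_layers`. A quick check with the values used in the usage example (8, 4, 32):

```python
# Layers that receive a CostWiseHead, per the range() in __init__ above.
start_layer, layer_sep, num_hidden_layers = 8, 4, 32
head_layers = list(range(start_layer, num_hidden_layers + 1, layer_sep))
print(head_layers)  # [8, 12, 16, 20, 24, 28, 32] -> 7 heads
```

Note the `num_hidden_layers + 1` upper bound: it ensures the final layer (the teacher) always gets a head when `layer_sep` divides the remaining depth evenly.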

Import

from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import ModelOutput
from transformers.models.mistral.modeling_mistral import (
    MistralPreTrainedModel, MistralDecoderLayer, MistralRMSNorm
)
from mistral_config import CostWiseMistralConfig

I/O Contract

Inputs

Name Type Required Description
input_ids torch.LongTensor Yes Tokenized input sequences
attention_mask torch.Tensor No Mask for padding tokens
compress_layer List[int] No Layer indices where compression should be applied
compress_ratio int No Token compression ratio (e.g., 2, 4, 8)
cutoff_layers List[int] No Layers to extract outputs for distillation
query_lengths torch.Tensor No Number of query tokens per sample
prompt_lengths torch.Tensor No Number of prompt tokens per sample

Outputs

Name Type Description
logits Tuple[torch.FloatTensor] Scores from each cutoff layer (student and teacher)
hidden_states Tuple[torch.FloatTensor] Normalized hidden states at cutoff layers
attention_masks Tuple[torch.FloatTensor] Attention masks after any compression
loss torch.FloatTensor Optional training loss (not used in layer-wise mode)
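At inference time, the layer-wise outputs enable early exit: run the transformer only up to a chosen cutoff layer and score with that layer's head, skipping all deeper layers. The following is a minimal sketch with toy stand-in layers; `early_exit_score` and its arguments are illustrative, not the library API.

```python
# Hedged early-exit sketch: `layers` are per-layer transforms, `heads`
# score hidden states, `head_layers` lists which depths have a head.
def early_exit_score(hidden, layers, heads, head_layers, exit_layer):
    for depth, layer in enumerate(layers, start=1):
        hidden = layer(hidden)           # run transformer layers in order
        if depth == exit_layer:          # stop early: skip deeper layers
            return heads[head_layers.index(exit_layer)](hidden)
    raise ValueError("exit_layer has no matching head")

# Toy model: 4 "layers" that each add 1; heads at depths 2 and 4.
layers = [lambda h: h + 1] * 4
heads = [lambda h: h * 10, lambda h: h * 100]
print(early_exit_score(0, layers, heads, head_layers=[2, 4], exit_layer=2))
# 20
```

Exiting at depth 2 means the last two layers are never executed, which is where the inference savings come from.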

Usage Examples

import torch
import torch.nn.functional as F

from mistral_model import CostWiseMistralForCausalLM
from mistral_config import CostWiseMistralConfig

# Configure for self-distillation
config = CostWiseMistralConfig.from_pretrained(model_path)
config.layer_wise = True
config.start_layer = 8  # Start extracting from layer 8
config.layer_sep = 4    # Extract every 4 layers
config.num_hidden_layers = 32

model = CostWiseMistralForCausalLM.from_pretrained(
    model_path,
    config=config,
    torch_dtype=torch.bfloat16
)

# Forward pass for distillation training
cutoff_layers = [8, 12, 16, 20, 24, 28, 32]  # Teacher is layer 32
outputs = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    compress_layer=[8],
    compress_ratio=4,
    cutoff_layers=cutoff_layers,
    query_lengths=query_lengths,
    prompt_lengths=prompt_lengths
)

# Teacher-student distillation
teacher_logits = outputs.logits[-1]  # Deepest layer
student_logits = outputs.logits[:-1]  # All earlier layers

# Compute distillation loss (MSE or KL divergence)
distillation_loss = 0
for student in student_logits:
    distillation_loss += F.mse_loss(student, teacher_logits.detach())
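The loop above uses MSE; the KL-divergence alternative mentioned in the comment treats the in-batch relevance scores as a listwise distribution. A plain-Python sketch (the softmax-over-batch framing is an assumption about how scores are grouped; in training the teacher side would be detached):

```python
import math

def softmax(xs):
    m = max(xs)                        # subtract max for numerical stability
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def kl_distill(student_scores, teacher_scores):
    # KL(teacher || student) over softmax-normalized in-batch scores.
    p = softmax(teacher_scores)
    q = softmax(student_scores)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

print(round(kl_distill([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]), 6))  # 0.0
```

The loss is zero when the student ranks passages exactly like the teacher and grows as the two distributions diverge, which makes it a drop-in replacement for the MSE term above.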
