
Implementation:FlagOpen FlagEmbedding Matryoshka Mistral Model Compensation

From Leeroopedia


Knowledge Sources
Domains: Information Retrieval, Neural Reranking, Token Compression, Matryoshka Learning
Last Updated: 2026-02-09 00:00 GMT

Overview

A custom Mistral-based model implementation with token compression and layer-wise output capabilities for Matryoshka reranker compensation training.

Description

This implementation extends the standard Mistral transformer architecture with cost-aware token compression and layer-wise output heads for Matryoshka reranking. Its token_compress function shortens the sequence at runtime by compressing passage tokens according to attention weights, while leaving query and prompt tokens intact. The model also supports layer-wise training: classification heads are attached at several layers (selected by start_layer and layer_sep), so a ranking score can be read out at different depths. The implementation targets the compensation phase of Matryoshka reranker fine-tuning, in which the earlier layers are trained to maintain ranking quality when inference stops at a shallower layer.
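To make the compression step concrete, here is a minimal PyTorch sketch of what a function like token_compress could do. It is not the FlagEmbedding code: the name token_compress_sketch is hypothetical, and it assumes that each row is laid out as query, then passage, then prompt, then right padding, and that weights supplies one non-negative score per token (for example, derived from attention). Passage tokens are merged in consecutive groups of compress_ratio by a weighted average; query and prompt tokens pass through unchanged.

```python
import torch


def token_compress_sketch(compress_ratio, hidden_states, attention_mask,
                          query_lengths, prompt_lengths, weights=None):
    """Hypothetical stand-in for token_compress.

    Assumes every row is laid out as [query | passage | prompt | right padding]
    and that `weights` holds one non-negative score per token (bsz, seq_len).
    Consecutive passage tokens are merged in groups of `compress_ratio` by a
    weighted average; query and prompt tokens are copied unchanged.
    """
    bsz, _, hidden_size = hidden_states.shape
    if weights is None:
        # No weights given: each group collapses to a plain mean.
        weights = torch.ones(hidden_states.shape[:2], dtype=hidden_states.dtype,
                             device=hidden_states.device)
    rows = []
    for i in range(bsz):
        total = int(attention_mask[i].sum())
        q_len, p_len = int(query_lengths[i]), int(prompt_lengths[i])
        end = total - p_len                      # passage occupies [q_len, end)
        passage = hidden_states[i, q_len:end]
        w = weights[i, q_len:end]
        merged = []
        for start in range(0, passage.shape[0], compress_ratio):
            group = passage[start:start + compress_ratio]
            gw = w[start:start + compress_ratio]
            gw = gw / gw.sum().clamp_min(1e-6)   # normalize within the group
            merged.append((gw.unsqueeze(-1) * group).sum(dim=0))
        merged = torch.stack(merged) if merged else passage[:0]
        rows.append(torch.cat([hidden_states[i, :q_len], merged,
                               hidden_states[i, end:total]], dim=0))
    # Re-pad the shortened rows and rebuild the attention mask.
    max_len = max(r.shape[0] for r in rows)
    new_hidden = hidden_states.new_zeros(bsz, max_len, hidden_size)
    new_mask = attention_mask.new_zeros(bsz, max_len)
    for i, r in enumerate(rows):
        new_hidden[i, :r.shape[0]] = r
        new_mask[i, :r.shape[0]] = 1
    return new_hidden, new_mask
```

Under this sketch, compress_ratio=4 turns a 300-token passage into ceil(300 / 4) = 75 positions, so every layer after the compression point runs on a much shorter sequence.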

Usage

Use this model architecture during the compensation fine-tuning phase of Matryoshka rerankers, where you need to train multiple output layers simultaneously while applying dynamic token compression to reduce computational costs.
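The layer-wise heads can be pictured as follows. This sketch assumes that start_layer and layer_sep select layers start_layer, start_layer + layer_sep, and so on up to the last layer, and that each head is a single linear projection to one score; both are assumptions, and head_layers and LayerWiseHeads are hypothetical names. With start_layer = 4 and layer_sep = 1 on a 32-layer model such as Mistral-7B, this reading yields 29 heads.

```python
import torch
from torch import nn


def head_layers(start_layer, num_hidden_layers, layer_sep):
    """Layers assumed to receive a head: start_layer, start_layer + layer_sep, ..."""
    return list(range(start_layer, num_hidden_layers + 1, layer_sep))


class LayerWiseHeads(nn.Module):
    """Hypothetical container: one linear scoring head per selected layer."""

    def __init__(self, hidden_size, start_layer, num_hidden_layers, layer_sep):
        super().__init__()
        self.layer_ids = head_layers(start_layer, num_hidden_layers, layer_sep)
        self.heads = nn.ModuleList(
            [nn.Linear(hidden_size, 1, bias=False) for _ in self.layer_ids])

    def forward(self, all_hidden_states, cutoff_layers):
        # all_hidden_states[k] is assumed to be the output of layer k
        # (index 0 = embeddings), like a Hugging Face `hidden_states` tuple.
        scores = []
        for layer in cutoff_layers:
            head = self.heads[self.layer_ids.index(layer)]
            # Per-token scores of shape (bsz, seq_len); the caller picks the position.
            scores.append(head(all_hidden_states[layer]).squeeze(-1))
        return scores
```

Training all heads in one pass is what lets a single forward call supervise every depth that might later be used for shallower inference.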

Code Reference

Source Location

Signature

class CostWiseMistralModel(MistralPreTrainedModel):
    def __init__(self, config: CostWiseMistralConfig):
        # Initialize Mistral model with custom config
        ...

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        compress_layer: Optional[int] = None,  # passed as a list of layer indices, e.g. [8]
        compress_ratio: Optional[int] = None,
        cutoff_layers: Optional[List[int]] = None,
        query_lengths: Optional[int] = None,  # per-sample lengths, passed as a tensor
        prompt_lengths: Optional[int] = None,  # per-sample lengths, passed as a tensor
        **kwargs
    ) -> Union[Tuple, CostWiseModelOutputWithPast]:
        # Forward pass with optional compression
        ...

class CostWiseMistralForCausalLM(MistralPreTrainedModel):
    def __init__(self, config):
        # Initialize with layer-wise heads
        ...

def token_compress(compress_ratio, hidden_states, attention_mask,
                   query_lengths, prompt_lengths,
                   weights: torch.Tensor = None):
    # Compress passage tokens using attention weights
    ...
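The control flow implied by the forward() signature (compress at compress_layer, read out at cutoff_layers) can be illustrated with a toy stack. In this sketch nn.Linear layers stand in for MistralDecoderLayer, pool_passage is a hypothetical uniform-weight stand-in for token_compress that assumes an unpadded batch with equal lengths per row, and layer ids are 1-based so that layer k lines up with hidden_states[k]. Whether compression happens before or after the listed layer runs is an assumption made here, not a confirmed detail of the source.

```python
import torch
from torch import nn


def pool_passage(hidden, query_len, prompt_len, ratio):
    """Uniform stand-in for token_compress: unpadded batch, equal lengths per row."""
    seq_len = hidden.shape[1]
    query = hidden[:, :query_len]
    passage = hidden[:, query_len:seq_len - prompt_len]
    prompt = hidden[:, seq_len - prompt_len:]
    # Average each run of `ratio` consecutive passage tokens (last run may be shorter).
    pooled = torch.stack([c.mean(dim=1) for c in passage.split(ratio, dim=1)], dim=1)
    return torch.cat([query, pooled, prompt], dim=1)


class ToyCostWiseStack(nn.Module):
    """Toy decoder stack showing where compression and cutoffs sit in forward()."""

    def __init__(self, num_layers, hidden_size):
        super().__init__()
        # nn.Linear stands in for MistralDecoderLayer.
        self.layers = nn.ModuleList(
            [nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)])

    def forward(self, hidden, compress_layer=None, compress_ratio=None,
                cutoff_layers=None, query_len=0, prompt_len=0):
        collected = {}
        for idx, layer in enumerate(self.layers, start=1):  # 1-based layer ids
            # Assumed ordering: shorten the sequence before layer `idx` runs.
            compress_now = (compress_layer is not None and compress_ratio is not None
                            and compress_ratio > 1 and idx in compress_layer)
            if compress_now:
                hidden = pool_passage(hidden, query_len, prompt_len, compress_ratio)
            hidden = torch.tanh(layer(hidden))
            if cutoff_layers is not None and idx in cutoff_layers:
                collected[idx] = hidden  # a scoring head would read this layer
        return collected
```

In this toy, every cutoff layer at or after the compression point sees the shortened sequence, which is why a per-layer attention mask has to travel with the hidden states.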

Import

from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from transformers.models.mistral.modeling_mistral import (
    MistralRMSNorm, MistralDecoderLayer, MistralPreTrainedModel
)
from mistral_config import CostWiseMistralConfig

I/O Contract

Inputs

Name Type Required Description
input_ids torch.LongTensor Yes Input token IDs
attention_mask torch.Tensor No Attention mask for valid tokens
compress_layer List[int] No Layers at which to apply token compression
compress_ratio int No Compression ratio (1, 2, 4, or 8)
cutoff_layers List[int] No Layers at which to extract hidden states for output
query_lengths torch.Tensor No Length of query tokens per sample
prompt_lengths torch.Tensor No Length of prompt tokens per sample
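The length inputs only make sense together with a fixed token layout. The helper below is a hypothetical way to produce input_ids, attention_mask, query_lengths and prompt_lengths from pre-tokenized pieces, assuming the order query, passage, prompt with right padding; real prompts may add BOS tokens or separators that are not modeled here. Under this layout the passage length is implied as attention_mask.sum(dim=1) - query_lengths - prompt_lengths.

```python
import torch


def build_reranker_batch(pairs, pad_id=0):
    """Pack (query_ids, passage_ids, prompt_ids) triples into padded tensors.

    Hypothetical helper: assumes the layout [query | passage | prompt] with
    right padding; special tokens and separators are not modeled.
    """
    rows = [q + p + pr for q, p, pr in pairs]
    max_len = max(len(r) for r in rows)
    input_ids = torch.full((len(rows), max_len), pad_id, dtype=torch.long)
    attention_mask = torch.zeros((len(rows), max_len), dtype=torch.long)
    for i, r in enumerate(rows):
        input_ids[i, :len(r)] = torch.tensor(r, dtype=torch.long)
        attention_mask[i, :len(r)] = 1  # mark real tokens
    query_lengths = torch.tensor([len(q) for q, _, _ in pairs])
    prompt_lengths = torch.tensor([len(pr) for _, _, pr in pairs])
    return input_ids, attention_mask, query_lengths, prompt_lengths
```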

Outputs

Name Type Description
logits torch.FloatTensor or Tuple Classification scores from each cutoff layer
hidden_states Tuple[torch.FloatTensor] Hidden states from specified layers
attention_masks Tuple[torch.FloatTensor] Attention masks after compression
past_key_values Optional[Tuple] Cached key-value pairs for generation
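Because compression shortens sequences partway through the network, each cutoff layer's logits should be read with that layer's own mask from attention_masks. A minimal sketch, assuming each logits entry has shape (bsz, seq_len) or (bsz, seq_len, 1), that masks are right-padded, and that the relevance score sits at the last real token; all three are assumptions, and per_layer_scores is a hypothetical name.

```python
import torch


def per_layer_scores(logits, attention_masks, cutoff_layers):
    """Pick one score per sample from each cutoff layer's logits (hypothetical)."""
    scores = {}
    for layer, layer_logits, mask in zip(cutoff_layers, logits, attention_masks):
        if layer_logits.dim() == 3:
            layer_logits = layer_logits.squeeze(-1)
        # Index of the last real token in this layer's (possibly compressed) mask.
        last = mask.sum(dim=1).long() - 1
        scores[layer] = layer_logits.gather(1, last.unsqueeze(1)).squeeze(1).float()
    return scores
```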

Usage Examples

import torch

from mistral_model import CostWiseMistralForCausalLM
from mistral_config import CostWiseMistralConfig

model_path = "path/to/mistral-checkpoint"  # placeholder: local path or hub ID

# Initialize config with layer-wise settings
config = CostWiseMistralConfig.from_pretrained(model_path)
config.layer_wise = True
config.start_layer = 4
config.layer_sep = 1

# Load model
model = CostWiseMistralForCausalLM.from_pretrained(
    model_path,
    config=config,
    torch_dtype=torch.bfloat16
)

# input_ids, attention_mask, query_lengths and prompt_lengths are assumed to
# come from the tokenizer / data collator (not shown here)
cutoff_layers = [4, 8, 12, 16, 20, 24, 28, 32]

# Forward pass with compression at layer 8, ratio 4
outputs = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    compress_layer=[8],
    compress_ratio=4,
    cutoff_layers=cutoff_layers,
    query_lengths=query_lengths,
    prompt_lengths=prompt_lengths,
    output_hidden_states=True
)

# Extract scores from multiple layers (one logits entry per cutoff layer)
for layer, layer_logits in zip(cutoff_layers, outputs.logits):
    scores = layer_logits.cpu().float()
    print(f"Layer {layer} scores: {scores}")
