Implementation:Predibase Lorax Flash Mistral Modeling

From Leeroopedia


Knowledge Sources
Domains: Model_Architecture, Inference
Last Updated: 2026-02-08 00:00 GMT

Overview

Optimized Mistral transformer implementation for LoRAX inference serving, with flash attention v2, sliding window attention, and LoRA adapter support.

Description

FlashMistralForCausalLM implements the Mistral architecture with flash attention v2 for efficient batched inference. The module requires flash attention v2 and raises ImportError at import time if it is unavailable. Unlike the Llama implementation, it supports sliding window attention.
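To make the sliding window idea concrete, here is a minimal pure-Python sketch of which key positions a causal query can see. It is illustrative only and is not taken from the LoRAX source: the helper names are hypothetical, and it assumes the window counts the current token (the real kernels may define the window boundary differently by one position).

```python
from typing import List


def visible_keys(query_pos: int, window: int) -> range:
    """Key positions a causal query may attend to under a sliding window.

    Assumption: the window includes the current token, so the query at
    position i sees keys max(0, i - window + 1) through i.
    """
    if window <= 0:
        raise ValueError("window must be positive")
    start = max(0, query_pos - window + 1)
    return range(start, query_pos + 1)


def sliding_window_mask(seq_len: int, window: int) -> List[List[bool]]:
    """Boolean mask[q][k]: True where query q is allowed to attend to key k."""
    return [
        [k in visible_keys(q, window) for k in range(seq_len)]
        for q in range(seq_len)
    ]


mask = sliding_window_mask(seq_len=5, window=3)
```

With a window of 3, the last query in a 5-token sequence sees only positions 2, 3, and 4, so older keys never need to be read for that token.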

The file contains seven classes organized as a layered architecture:

  • MistralConfig -- Configuration class extending PretrainedConfig with a sliding_window parameter (default 4096) and a default num_key_value_heads of 8 for grouped-query attention (GQA).
  • MistralRMSNorm -- RMS normalization with a fused dropout_layer_norm kernel for hidden dimensions up to 8192.
  • MistralAttention -- Multi-head attention with GQA support and rotary position embeddings. Uses flash attention for prefill and paged attention for decode, and supports adapter-aware Q/K/V/O projections.
  • MistralMLP -- Gated MLP (SwiGLU) with fused gate-up projections and adapter-aware layers.
  • MistralLayer -- Single transformer decoder layer combining attention and MLP with pre-norm RMS normalization.
  • MistralModel -- Full transformer model stacking N decoder layers with final normalization.
  • FlashMistralForCausalLM -- Top-level causal language model that wraps the model with token embeddings and a language model head. During decode, it applies sliding window clamping by limiting max_s to max_past and clamping seqlen to max_past.
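The math behind three of the components listed above can be sketched in plain Python: RMS normalization, the SwiGLU gating step, and the decode-time clamp. This is a reference sketch under stated assumptions, not the LoRAX code. The function names are hypothetical, inputs are plain lists rather than tensors, the fused kernels and the MLP's linear projections are omitted, and the clamp assumes max_past equals the configured sliding window.

```python
import math
from typing import List, Sequence, Tuple


def rms_norm(x: Sequence[float], weight: Sequence[float], eps: float = 1e-6) -> List[float]:
    """Scale x by the reciprocal of its root mean square, then by a learned weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [v * scale * w for v, w in zip(x, weight)]


def silu(v: float) -> float:
    """SiLU activation: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def swiglu(gate: Sequence[float], up: Sequence[float]) -> List[float]:
    """Gating step of a SwiGLU MLP: silu(gate) * up, elementwise.

    Assumption: gate and up are the two halves of an already computed
    fused gate-up projection; the down projection is not shown.
    """
    return [silu(g) * u for g, u in zip(gate, up)]


def clamp_for_decode(max_s: int, seq_lens: Sequence[int], max_past: int) -> Tuple[int, List[int]]:
    """Limit max_s and every per-sequence length to max_past during decode."""
    return min(max_s, max_past), [min(n, max_past) for n in seq_lens]


normed = rms_norm([3.0, 4.0], [1.0, 1.0], eps=0.0)
clamped_max_s, clamped_lens = clamp_for_decode(6000, [10, 5000, 6000], max_past=4096)
```

After clamping, sequences shorter than the window are unchanged, while longer ones are treated as exactly max_past tokens long.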

The implementation supports FP8 KV cache quantization and tensor parallelism for multi-GPU serving.
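As a rough illustration of how tensor parallelism interacts with GQA, the sketch below splits attention heads evenly across shards and maps each local query head to its key-value head. The helper names are hypothetical and the even-division requirement is an assumption of this sketch; the actual sharding logic in LoRAX may handle other cases. The 32 query heads in the example are an illustrative value, while 8 key-value heads matches the default noted above.

```python
from typing import Tuple


def shard_heads(num_heads: int, num_kv_heads: int, world_size: int) -> Tuple[int, int, int]:
    """Return (query heads per shard, KV heads per shard, query heads per KV head)."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    if num_heads % world_size != 0 or num_kv_heads % world_size != 0:
        raise ValueError("head counts must divide evenly across shards")
    local_q = num_heads // world_size
    local_kv = num_kv_heads // world_size
    return local_q, local_kv, local_q // local_kv


def kv_head_for_query_head(query_head: int, group_size: int) -> int:
    """Index of the KV head shared by a given query head under GQA."""
    return query_head // group_size


local_q, local_kv, group = shard_heads(num_heads=32, num_kv_heads=8, world_size=2)
```

On two GPUs this gives each shard 16 query heads and 4 key-value heads, with every key-value head shared by 4 query heads.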

Usage

The LoRAX server uses this module internally when serving Mistral-based models. It also serves as the text model backbone for multimodal models loaded via the VLM module. The model registry loads it when the model config type matches.
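The general pattern of selecting a model class from the config type can be sketched as follows. This is a hypothetical registry written for illustration: MODEL_REGISTRY, register, and resolve are not LoRAX names, the real registration code may be organized differently, and the loader returns a class name string so the example stays self-contained.

```python
from typing import Callable, Dict

# Hypothetical registry mapping a config's model_type to a loader function.
MODEL_REGISTRY: Dict[str, Callable[[dict], str]] = {}


def register(model_type: str):
    """Decorator that records a loader under the given model_type."""
    def decorator(loader: Callable[[dict], str]) -> Callable[[dict], str]:
        MODEL_REGISTRY[model_type] = loader
        return loader
    return decorator


@register("mistral")
def load_mistral(config: dict) -> str:
    # A real loader would build the model; here we only name the class.
    return "FlashMistralForCausalLM"


def resolve(config: dict) -> str:
    """Pick the loader whose key matches config['model_type']."""
    model_type = config.get("model_type")
    if model_type not in MODEL_REGISTRY:
        raise ValueError(f"unsupported model type: {model_type!r}")
    return MODEL_REGISTRY[model_type](config)


selected = resolve({"model_type": "mistral"})
```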

Code Reference

Source Location

  • Repository: Predibase_Lorax
  • File: server/lorax_server/models/custom_modeling/flash_mistral_modeling.py
  • Lines: 1-644

Signature

class FlashMistralForCausalLM(torch.nn.Module):
    def __init__(self, prefix, config, weights, name=None):
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
        adapter_data: AdapterBatchData,
        prefill_cache_indices: Optional[torch.Tensor] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
        skip_lm_head: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        ...

Import

from lorax_server.models.custom_modeling.flash_mistral_modeling import FlashMistralForCausalLM

I/O Contract

Inputs

Name | Type | Required | Description
input_ids | torch.Tensor | Yes | Token IDs [batch_size, seq_len]
position_ids | torch.Tensor | Yes | Position indices for rotary embeddings
cu_seqlen_prefill | Optional[torch.Tensor] | Yes | Cumulative sequence lengths for flash attention prefill (None during decode)
kv_cache | List[Tuple[torch.Tensor, torch.Tensor]] | Yes | Key-value cache tensors per layer
block_tables | torch.Tensor | Yes | Block table indices for paged attention
slots | torch.Tensor | Yes | Slot indices for KV cache placement
seqlen | Seqlen | Yes | Sequence length metadata wrapper
max_s | int | Yes | Maximum sequence length in the batch
adapter_data | AdapterBatchData | Yes | LoRA adapter weights and indices for the batch
prefill_cache_indices | Optional[torch.Tensor] | No | Indices for selective KV cache population during prefill; also triggers slot slicing for sliding window
lm_head_indices | Optional[torch.Tensor] | No | Indices to select specific positions for LM head output
skip_lm_head | bool | No | If True, return hidden states without applying the LM head
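Cumulative sequence lengths are easiest to understand with a small example. The sketch below assumes the sequences of a prefill batch are packed back to back along one token dimension, which is how variable-length flash attention kernels consume cumulative lengths. The helper names are hypothetical, and using the last token of each sequence as the LM head position is an assumption about a typical prefill, not something confirmed by this page.

```python
from itertools import accumulate
from typing import List, Sequence


def cumulative_seqlens(lengths: Sequence[int]) -> List[int]:
    """Offsets where each packed sequence starts, plus the total token count."""
    return [0] + list(accumulate(lengths))


def last_token_indices(cu_seqlens: Sequence[int]) -> List[int]:
    """Index of the final token of each sequence in the packed layout."""
    return [end - 1 for end in cu_seqlens[1:]]


cu = cumulative_seqlens([3, 1, 4])
head_positions = last_token_indices(cu)
```

Here three sequences of lengths 3, 1, and 4 occupy token indices 0 to 7, and sequence i spans cu[i] up to but not including cu[i + 1].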

Outputs

Name | Type | Description
logits | torch.Tensor | Next-token logits [batch_size, vocab_size] (or hidden states if skip_lm_head is True)
speculative_logits | Optional[torch.Tensor] | Speculative decoding logits from the multi-adapter head, or None

Usage Examples

# Internal usage within the LoRAX server
from lorax_server.models.custom_modeling.flash_mistral_modeling import FlashMistralForCausalLM

# Model is instantiated by the model registry, not directly by users
# See server/lorax_server/models/__init__.py for registration
