Implementation:Predibase Lorax Flash Mistral Modeling
| Knowledge Sources | |
|---|---|
| Domains | Model_Architecture, Inference |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Optimized Mistral transformer implementation for LoRax inference serving with flash attention v2, sliding window attention, and LoRA adapter support.
Description
FlashMistralForCausalLM implements the Mistral architecture on top of flash attention v2 for efficient batched inference. The module requires flash attention v2 and raises ImportError at import time if it is unavailable. Its main difference from the Llama implementation is support for sliding window attention.
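Sliding window attention restricts each query to a fixed number of the most recent key positions instead of the full causal prefix. The following plain-Python sketch shows which positions are visible. It assumes the window counts the query's own position, so each query sees at most `window` keys. That boundary convention is an assumption made for illustration, and the helper names are hypothetical rather than part of LoRax.

```python
from typing import List


def visible_keys(query_pos: int, window: int) -> List[int]:
    """Key positions a query may attend to under causal sliding-window attention.

    Assumes the window includes the query's own position (at most `window` keys).
    """
    start = max(0, query_pos - window + 1)
    return list(range(start, query_pos + 1))


def sliding_window_mask(seq_len: int, window: int) -> List[List[int]]:
    """Dense 0/1 mask: row i marks the key positions query i can see."""
    return [
        [1 if 0 <= i - j < window else 0 for j in range(seq_len)]
        for i in range(seq_len)
    ]


if __name__ == "__main__":
    for row in sliding_window_mask(4, 2):
        print(row)
```

With a window of 2, each row of the mask has at most two ones, ending at the diagonal. With the default `sliding_window` of 4096, short prompts behave exactly like ordinary causal attention.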
The file contains seven classes organized as a layered architecture:
- MistralConfig -- Configuration class extending PretrainedConfig with a sliding_window parameter (default 4096) and num_key_value_heads defaulting to 8 for grouped-query attention (GQA).
- MistralRMSNorm -- RMS normalization with a fused dropout_layer_norm kernel for hidden dimensions up to 8192.
- MistralAttention -- Multi-head attention with GQA support, rotary position embeddings, flash attention for prefill and paged attention for decode. Supports adapter-aware Q/K/V/O projections.
- MistralMLP -- Gated MLP (SwiGLU) with fused gate-up projections and adapter-aware layers.
- MistralLayer -- Single transformer decoder layer combining attention and MLP with pre-norm RMS normalization.
- MistralModel -- Full transformer model stacking N decoder layers with final normalization.
- FlashMistralForCausalLM -- Top-level causal language model that wraps the model with token embeddings and a language model head. During decode it applies sliding window clamping: max_s is limited to max_past (the sliding window size), and seqlen is clamped to max_past.
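Three of the building blocks listed above can be written out in a few lines of plain Python. The code covers the GQA mapping from query heads to shared KV heads, the RMS normalization that the fused kernel computes, and the SwiGLU gating in the MLP. This is a reference sketch over Python lists, not the LoRax tensor code. The function names are hypothetical, and the `eps` default of 1e-6 is an assumed value.

```python
import math
from typing import List


def kv_head_for_query_head(q_head: int, num_heads: int, num_kv_heads: int) -> int:
    """GQA: consecutive groups of num_heads // num_kv_heads query heads share one KV head."""
    assert num_heads % num_kv_heads == 0, "query heads must split evenly into KV groups"
    group_size = num_heads // num_kv_heads
    return q_head // group_size


def rms_norm(x: List[float], weight: List[float], eps: float = 1e-6) -> List[float]:
    """y_i = x_i / sqrt(mean(x^2) + eps) * w_i (no mean subtraction, no bias)."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def silu(v: float) -> float:
    """SiLU / swish activation: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def swiglu(gate: List[float], up: List[float]) -> List[float]:
    """Elementwise SiLU(gate) * up.

    In the real module, gate and up are produced by one fused gate-up
    projection, and the result feeds a down projection.
    """
    return [silu(g) * u for g, u in zip(gate, up)]
```

For example, with 32 query heads and 8 KV heads, query heads 0 to 3 all read KV head 0 and query heads 28 to 31 read KV head 7. This grouping is why GQA shrinks the KV cache by a factor of four relative to full multi-head attention.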
The implementation supports FP8 KV cache quantization and tensor parallelism for multi-GPU serving.
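FP8 KV cache quantization and tensor parallelism both reduce the KV cache memory that each GPU needs per token. The arithmetic below is a rough sketch. It assumes one 1-byte element per value for FP8, ignores any scale metadata, and assumes KV heads are split evenly across tensor-parallel ranks. The example shapes (32 layers, 8 KV heads, head dimension 128) follow the publicly released Mistral-7B configuration and are used only for illustration.

```python
def kv_cache_bytes_per_token(
    num_layers: int,
    num_kv_heads: int,
    head_dim: int,
    bytes_per_elem: int,
    world_size: int = 1,
) -> int:
    """Approximate KV cache bytes per token on one tensor-parallel rank.

    Assumes KV heads are sharded evenly across ranks and ignores any
    per-tensor scale metadata used by FP8 quantization.
    """
    assert num_kv_heads % world_size == 0, "assumes KV heads divide evenly across ranks"
    kv_heads_per_rank = num_kv_heads // world_size
    # Factor 2: one key vector and one value vector per KV head per layer.
    return 2 * num_layers * kv_heads_per_rank * head_dim * bytes_per_elem


if __name__ == "__main__":
    fp16 = kv_cache_bytes_per_token(32, 8, 128, bytes_per_elem=2)
    fp8 = kv_cache_bytes_per_token(32, 8, 128, bytes_per_elem=1)
    fp8_tp2 = kv_cache_bytes_per_token(32, 8, 128, bytes_per_elem=1, world_size=2)
    print(fp16, fp8, fp8_tp2)
```

Under these assumptions, a 16-bit cache needs 128 KiB per token and FP8 halves that to 64 KiB. Splitting the KV heads over two ranks halves the per-GPU figure again.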
Usage
Used internally by the LoRax server when serving Mistral-based models. Also serves as the text model backbone for multimodal models loaded via the VLM module. The model registry selects this class when the type in the model's config matches Mistral.
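The registry lookup can be pictured as a dispatch table keyed by the `model_type` string in a Hugging Face `config.json`, which is `"mistral"` for Mistral checkpoints. The sketch below is a hypothetical toy and not LoRax's registry code. The names `register`, `load_model`, and `_REGISTRY` are invented for illustration, and the loader returns a description string instead of building the model.

```python
import json
from typing import Callable, Dict

# Hypothetical registry mapping a config's model_type to a loader function.
_REGISTRY: Dict[str, Callable[[dict], str]] = {}


def register(model_type: str) -> Callable[[Callable[[dict], str]], Callable[[dict], str]]:
    """Decorator that records a loader under the given model_type."""
    def decorator(fn: Callable[[dict], str]) -> Callable[[dict], str]:
        _REGISTRY[model_type] = fn
        return fn
    return decorator


@register("mistral")
def load_flash_mistral(config: dict) -> str:
    # Stand-in for constructing FlashMistralForCausalLM from config and weights.
    return f"FlashMistralForCausalLM(sliding_window={config.get('sliding_window')})"


def load_model(config_json: str) -> str:
    """Pick a loader based on the model_type field of a config.json string."""
    config = json.loads(config_json)
    model_type = config["model_type"]
    if model_type not in _REGISTRY:
        raise ValueError(f"unsupported model_type: {model_type}")
    return _REGISTRY[model_type](config)
```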
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/models/custom_modeling/flash_mistral_modeling.py
- Lines: 1-644
Signature
```python
class FlashMistralForCausalLM(torch.nn.Module):
    def __init__(self, prefix, config, weights, name=None):
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
        adapter_data: AdapterBatchData,
        prefill_cache_indices: Optional[torch.Tensor] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
        skip_lm_head: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        ...
```
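The sliding-window handling inside `forward` has two branches. If `prefill_cache_indices` is given, `slots` is sliced with those indices. Otherwise, if a window size is configured, `max_s` and the sequence lengths are clamped to it. The sketch mirrors that control flow on plain lists. `seqlens` stands in for the `Seqlen` wrapper and `max_past` for the sliding window size; both are simplifications, and the function name is hypothetical.

```python
from typing import List, Optional, Sequence, Tuple


def sliding_window_inputs(
    slots: List[int],
    seqlens: List[int],
    max_s: int,
    max_past: Optional[int],
    prefill_cache_indices: Optional[Sequence[int]] = None,
) -> Tuple[List[int], List[int], int]:
    """Apply the sliding-window adjustments described for forward().

    Plain-list simplification: `seqlens` replaces the Seqlen wrapper and
    `max_past` is the sliding window size (None means no window).
    """
    if prefill_cache_indices is not None:
        # Prefill: keep only the slots selected by prefill_cache_indices,
        # i.e. the positions whose K/V entries are written to the cache.
        slots = [slots[i] for i in prefill_cache_indices]
    elif max_past is not None:
        # Decode: cap max_s and every per-sequence length at the window size.
        max_s = min(max_past, max_s)
        seqlens = [min(n, max_past) for n in seqlens]
    return slots, seqlens, max_s
```

For a decode step with lengths `[5000, 100]` and a 4096-token window, the lengths become `[4096, 100]` and `max_s` becomes 4096. Without a window, all inputs pass through unchanged.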
Import
```python
from lorax_server.models.custom_modeling.flash_mistral_modeling import FlashMistralForCausalLM
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.Tensor | Yes | Token IDs of all sequences in the batch, flattened into one dimension: [total_tokens] |
| position_ids | torch.Tensor | Yes | Position index of each token for rotary embeddings: [total_tokens] |
| cu_seqlen_prefill | Optional[torch.Tensor] | Yes (may be None) | Cumulative sequence lengths for flash attention prefill; None during decode |
| kv_cache | List[Tuple[torch.Tensor, torch.Tensor]] | Yes | Key-value cache tensors per layer |
| block_tables | torch.Tensor | Yes | Block table indices for paged attention |
| slots | torch.Tensor | Yes | Slot indices for KV cache placement |
| seqlen | Seqlen | Yes | Sequence length metadata wrapper |
| max_s | int | Yes | Maximum sequence length in the batch |
| adapter_data | AdapterBatchData | Yes | LoRA adapter weights and indices for the batch |
| prefill_cache_indices | Optional[torch.Tensor] | No | Indices for selective KV cache population during prefill; also triggers slot slicing for sliding window |
| lm_head_indices | Optional[torch.Tensor] | No | Indices to select specific positions for LM head output |
| skip_lm_head | bool | No | If True, return hidden states without applying the LM head |
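The flattened `input_ids`, `position_ids`, and `cu_seqlen_prefill` inputs fit together as follows. Sequences are concatenated, and `cu_seqlen_prefill` records where each one starts and ends. The plain-Python sketch assumes a fresh prefill with no cached prefix, so positions restart at 0 for every sequence. The real inputs are tensors (the cumulative lengths are typically int32), and the helper names are hypothetical.

```python
from itertools import accumulate
from typing import List, Tuple


def cu_seqlens(lengths: List[int]) -> List[int]:
    """Cumulative lengths [0, l0, l0+l1, ...]; sequence i spans tokens [cu[i], cu[i+1])."""
    return [0] + list(accumulate(lengths))


def flatten_batch(sequences: List[List[int]]) -> Tuple[List[int], List[int], List[int]]:
    """Flatten a ragged batch into (input_ids, position_ids, cu_seqlen_prefill).

    Assumes a fresh prefill: positions restart at 0 for each sequence.
    """
    input_ids = [tok for seq in sequences for tok in seq]
    position_ids = [pos for seq in sequences for pos in range(len(seq))]
    return input_ids, position_ids, cu_seqlens([len(seq) for seq in sequences])


if __name__ == "__main__":
    print(flatten_batch([[5, 6, 7], [8, 9]]))
```

Packing the batch this way avoids padding. The flash attention varlen kernels use the cumulative lengths to keep sequences from attending to one another.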
Outputs
| Name | Type | Description |
|---|---|---|
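| logits | torch.Tensor | Logits of shape [num_tokens, vocab_size], with one row per position selected by lm_head_indices, or one row per input token if it is None (typically one per sequence during decode). Hidden states are returned instead if skip_lm_head is True |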
| speculative_logits | Optional[torch.Tensor] | Speculative decoding logits from the multi-adapter head, or None |
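During prefill, a caller that only needs next-token predictions can pass `lm_head_indices` that point to the last token of each sequence. The logits then have one row per sequence instead of one per prompt token. The sketch below derives those indices from per-sequence lengths and applies them to a list of rows. It assumes that calling pattern (no prompt logprobs requested), and both helper names are hypothetical.

```python
from itertools import accumulate
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def last_token_indices(seq_lengths: Sequence[int]) -> List[int]:
    """Index of each sequence's final token in the flattened batch (cu[i+1] - 1)."""
    return [end - 1 for end in accumulate(seq_lengths)]


def select_rows(rows: Sequence[T], indices: Sequence[int]) -> List[T]:
    """List equivalent of hidden_states[lm_head_indices] on a tensor."""
    return [rows[i] for i in indices]


if __name__ == "__main__":
    hidden = ["h0", "h1", "h2", "h3", "h4"]  # stand-ins for per-token hidden states
    idx = last_token_indices([3, 2])
    print(idx, select_rows(hidden, idx))
```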
Usage Examples
```python
# Internal usage within LoRax server
from lorax_server.models.custom_modeling.flash_mistral_modeling import FlashMistralForCausalLM

# The model is instantiated by the model registry, not directly by users.
# See server/lorax_server/models/__init__.py for registration.
```