Implementation:Predibase Lorax Flash Mixtral Modeling
| Knowledge Sources | |
|---|---|
| Domains | Model_Architecture, Inference |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Optimized Mixtral transformer implementation for LoRAX inference serving, with flash attention, LoRA adapter support, sliding window attention, and Mixture of Experts (MoE) feed-forward layers built on the Megablocks library.
Description
This module implements the Mixtral Mixture of Experts architecture for high-throughput inference in the LoRAX serving framework. Mixtral uses a sparse MoE design: a router sends each token to only its top-k experts (num_experts_per_tok out of num_local_experts), so only a fraction of the feed-forward weights is evaluated per token. The main components are:
- MixtralConfig -- Configuration class extending PretrainedConfig with Mixtral-specific parameters including sliding_window, num_experts_per_tok, and num_local_experts.
- MixtralRMSNorm -- RMS normalization with an optimized path using dropout_layer_norm for hidden dimensions up to 8192, and a fallback PyTorch path for larger dimensions.
- MixtralAttention -- Multi-head attention with grouped-query attention (GQA), rotary positional embeddings (RoPE), sliding window support, and paged attention. Supports LoRA adapters on Q, K, V, and O projections.
- BlockSparseMoE -- The primary MoE feed-forward layer using Megablocks for block-sparse matrix operations. Routes tokens to top-k experts efficiently without dropping any tokens. Uses sparse forward for batches larger than 256 tokens and dense forward otherwise.
- DenseMoE -- A fallback dense MoE implementation used when quantization is enabled, loading individual expert weights as separate TensorParallel layers.
- MixtralLayer -- A single transformer layer combining self-attention with RMS normalization and the MoE block.
- MixtralModel -- The full transformer model stacking embedding, multiple MixtralLayer instances, and final RMS normalization.
- FlashMixtralForCausalLM -- The top-level causal language model that wraps MixtralModel with an LM head, handles sliding window clamping during decode, and supports multi-adapter inference.
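The top-k routing that the MoE layers perform can be illustrated for a single token. The following is a minimal pure-Python sketch, not the module's code: the function names `top_k_routing` and `moe_forward` are hypothetical, experts are modeled as plain callables, and the real layers operate on batched tensors with Megablocks kernels. It assumes the commonly used Mixtral scheme of a softmax over router logits, selection of the k largest probabilities, and renormalization of the kept weights.

```python
import math
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def top_k_routing(router_logits: Sequence[float], k: int) -> List[Tuple[int, float]]:
    """Return (expert_index, weight) pairs for the k highest-scoring experts."""
    # Numerically stable softmax over all experts.
    peak = max(router_logits)
    exps = [math.exp(v - peak) for v in router_logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Keep the k most probable experts (ties resolve to the lower index).
    chosen = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    # Renormalize so the kept weights sum to 1.
    kept = sum(probs[i] for i in chosen)
    return [(i, probs[i] / kept) for i in chosen]


def moe_forward(
    token: Vector,
    router_logits: Sequence[float],
    experts: Sequence[Callable[[Vector], Vector]],
    k: int,
) -> Vector:
    """Combine the outputs of the selected experts, weighted by the router."""
    out = [0.0] * len(token)
    for idx, weight in top_k_routing(router_logits, k):
        expert_out = experts[idx](token)  # only the selected experts run
        out = [o + weight * e for o, e in zip(out, expert_out)]
    return out


if __name__ == "__main__":
    # Four toy experts that scale their input by 1, 2, 3 and 4 (hypothetical).
    toy_experts = [lambda v, s=s: [s * x for x in v] for s in (1.0, 2.0, 3.0, 4.0)]
    print(top_k_routing([0.0, 2.0, 1.0, -1.0], k=2))
    print(moe_forward([1.0, -1.0], [0.0, 2.0, 1.0, -1.0], toy_experts, k=2))
```

Because the unselected experts are never called, compute per token scales with k rather than with the total number of experts, which is the property the block-sparse and dense forward paths both preserve.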
Usage
Used internally by the LoRAX server when serving Mixtral-based models. The model registry loads this class when the model type in the checkpoint's config identifies a Mixtral model. Requires the flash_attn (v2), megablocks, and stk libraries.
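Before pointing the server at a Mixtral checkpoint, it can help to confirm that the libraries listed above are importable. The sketch below is a hypothetical helper, not part of LoRAX; it assumes the import names are `flash_attn`, `megablocks`, and `stk`, and it only checks that each module can be located, not which version is installed.

```python
from importlib.util import find_spec
from typing import Iterable, List

# Assumed top-level import names for the libraries listed above.
REQUIRED_MODULES = ("flash_attn", "megablocks", "stk")


def missing_modules(names: Iterable[str] = REQUIRED_MODULES) -> List[str]:
    """Return the top-level module names that cannot be located."""
    # find_spec returns None when a top-level module is not installed,
    # and it does not import the module itself.
    return [name for name in names if find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    if missing:
        print("Missing dependencies:", ", ".join(missing))
    else:
        print("All listed dependencies were found.")
```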
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/models/custom_modeling/flash_mixtral_modeling.py
- Lines: 1-996
Signature
class FlashMixtralForCausalLM(torch.nn.Module):
    def __init__(self, prefix: str, config, weights):
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
        adapter_data: AdapterBatchData,
        prefill_cache_indices: Optional[torch.Tensor] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
        skip_lm_head: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        ...
Import
from lorax_server.models.custom_modeling.flash_mixtral_modeling import FlashMixtralForCausalLM
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.Tensor | Yes | Token IDs for the input sequence |
| position_ids | torch.Tensor | Yes | Position indices for rotary embeddings |
| cu_seqlen_prefill | Optional[torch.Tensor] | Yes | Cumulative sequence lengths for flash attention prefill; None during decode |
| kv_cache | List[Tuple[torch.Tensor, torch.Tensor]] | Yes | Key-value cache tensors for each layer |
| block_tables | torch.Tensor | Yes | Block tables for paged attention |
| slots | torch.Tensor | Yes | Slot indices for KV cache storage |
| seqlen | Seqlen | Yes | Sequence length information for the batch |
| max_s | int | Yes | Maximum sequence length in the batch |
| adapter_data | AdapterBatchData | Yes | LoRA adapter configuration for the batch |
| prefill_cache_indices | Optional[torch.Tensor] | No | Indices for selective cache prefilling with sliding window |
| lm_head_indices | Optional[torch.Tensor] | No | Indices to select specific positions for LM head |
| skip_lm_head | bool | No | If True, return hidden states without applying the LM head |
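To make the shapes of cu_seqlen_prefill and prefill_cache_indices more concrete, here is a small sketch of how such values could be derived for a flattened prefill batch. It is an illustration under assumptions, not the server's batching code: the helper name `build_prefill_metadata` is hypothetical, plain Python lists stand in for tensors, and it assumes that with a sliding window only the trailing window of each sequence is selected for the cache.

```python
from itertools import accumulate
from typing import List, Optional, Tuple


def build_prefill_metadata(
    input_lengths: List[int],
    sliding_window: Optional[int] = None,
) -> Tuple[List[int], List[int]]:
    """Return (cu_seqlens, cache_indices) for sequences packed back to back."""
    # Sequence i occupies flat positions [cu_seqlens[i], cu_seqlens[i + 1]).
    cu_seqlens = [0] + list(accumulate(input_lengths))
    cache_indices: List[int] = []
    for start, length in zip(cu_seqlens, input_lengths):
        # Assumption: tokens older than the window are not written to the cache.
        skip = 0 if sliding_window is None else max(0, length - sliding_window)
        cache_indices.extend(range(start + skip, start + length))
    return cu_seqlens, cache_indices


if __name__ == "__main__":
    # Two prompts of 3 and 5 tokens with a hypothetical window of 4 tokens.
    print(build_prefill_metadata([3, 5], sliding_window=4))
```

In this sketch the first prompt fits inside the window and keeps all of its positions, while the second drops its oldest token; without a window every flat position is selected.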
Outputs
| Name | Type | Description |
|---|---|---|
| logits | torch.Tensor | Next-token logits over the vocabulary (or hidden states if skip_lm_head is True) |
| speculative_logits | Optional[torch.Tensor] | Speculative decoding logits, or None |
Usage Examples
# Internal usage within the LoRAX server
from lorax_server.models.custom_modeling.flash_mixtral_modeling import FlashMixtralForCausalLM
# Model instantiated by model registry, not directly by users