Implementation:Predibase Lorax Flash Llama Modeling
| Knowledge Sources | |
|---|---|
| Domains | Model_Architecture, Inference |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Optimized Llama transformer implementation for LoRAX inference serving, with flash attention and LoRA adapter support.
Description
FlashLlamaForCausalLM implements the Llama architecture with flash attention for efficient batched inference. The module defines the full Llama model stack including configuration, normalization, attention, MLP, and causal language model layers with integrated LoRA adapter hooks for dynamic adapter serving.
The file contains seven classes organized as a layered architecture:
- LlamaConfig -- Configuration class extending PretrainedConfig with Llama-specific parameters such as vocab_size, hidden_size, num_attention_heads, num_key_value_heads (for GQA), rope_scaling, and rope_theta.
- LlamaRMSNorm -- RMS normalization with a fused dropout_layer_norm kernel for hidden dimensions up to 8192, falling back to a manual implementation for larger dimensions.
- FlashLlamaAttention -- Multi-head attention with grouped query attention (GQA) support and rotary position embeddings, using flash attention for prefill and paged attention for decode. Supports adapter-aware Q/K/V/O projections via TensorParallelMultiAdapterLinear.
- LlamaMLP -- Gated MLP (SwiGLU) using fused gate-up projections with adapter-aware gate/up/down projection layers.
- FlashLlamaLayer -- Single transformer decoder layer combining attention and MLP with pre-norm RMS normalization.
- FlashLlamaModel -- Full transformer model stacking N decoder layers with embedding and final normalization.
- FlashLlamaForCausalLM -- Top-level causal language model class that wraps the model with token embeddings and a language model head supporting LoRA adapters.
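The per-layer math named in the list above can be sketched in plain Python, without the fused kernels or tensor types the real module relies on. This is only an illustration under stated assumptions: the function names `rms_norm`, `swiglu`, and `kv_head_for_query_head`, the `eps` default, and the contiguous grouping of query heads are hypothetical and not taken from the source file.

```python
import math
from typing import List


def rms_norm(x: List[float], weight: List[float], eps: float = 1e-6) -> List[float]:
    """Scale x by the reciprocal of its root mean square, then by a learned weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [w * v * inv_rms for v, w in zip(x, weight)]


def silu(v: float) -> float:
    """SiLU (swish) activation: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def swiglu(gate: List[float], up: List[float]) -> List[float]:
    """Gated activation used by the MLP: SiLU(gate) * up, elementwise."""
    return [silu(g) * u for g, u in zip(gate, up)]


def kv_head_for_query_head(q_head: int, num_heads: int, num_kv_heads: int) -> int:
    """GQA (assumed contiguous grouping): each group of query heads shares one KV head."""
    group_size = num_heads // num_kv_heads
    return q_head // group_size
```

In a gated MLP of this shape, the `gate` and `up` vectors would come from the two halves of the fused gate-up projection, and the result of `swiglu` would feed the down projection.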
The implementation supports FP8 KV cache quantization and tensor parallelism for multi-GPU serving.
Usage
Used internally by the LoRAX server when serving Llama-based models. This is the foundational model implementation that several other architectures extend, including Granite and Solar. It is loaded via the model registry when the model config type matches.
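The idea of dispatching on the model config type can be illustrated with a small, self-contained registry. This is a hypothetical sketch, not the LoRAX registry itself: `MODEL_REGISTRY`, `register`, `resolve`, and the `"llama"` key are names invented for illustration, and the loader returns a class name string instead of building a model.

```python
from typing import Callable, Dict

# Hypothetical registry: maps a config's model_type string to a loader function.
MODEL_REGISTRY: Dict[str, Callable[[dict], str]] = {}


def register(model_type: str):
    """Decorator that records a loader under the given model_type key."""
    def decorator(fn: Callable[[dict], str]) -> Callable[[dict], str]:
        MODEL_REGISTRY[model_type] = fn
        return fn
    return decorator


@register("llama")
def load_llama(config: dict) -> str:
    # A real loader would construct the model from sharded weights;
    # here we only return the name of the class that would be used.
    return "FlashLlamaForCausalLM"


def resolve(config: dict) -> str:
    """Pick the loader whose key matches the config's model_type."""
    model_type = config.get("model_type")
    if model_type not in MODEL_REGISTRY:
        raise ValueError(f"Unsupported model type: {model_type!r}")
    return MODEL_REGISTRY[model_type](config)
```

Architectures that extend the Llama implementation would, in a scheme like this, register their own key and reuse the shared layers.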
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/models/custom_modeling/flash_llama_modeling.py
- Lines: 1-624
Signature
class FlashLlamaForCausalLM(torch.nn.Module):
    def __init__(self, prefix: str, config, weights, create_layer_fn=None):
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
        adapter_data: AdapterBatchData,
        prefill_cache_indices: Optional[torch.Tensor] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
        cross_attention_states: Optional[torch.Tensor] = None,
        skip_lm_head: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        ...
Import
from lorax_server.models.custom_modeling.flash_llama_modeling import FlashLlamaForCausalLM
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.Tensor | Yes | Token IDs [batch_size, seq_len] |
| position_ids | torch.Tensor | Yes | Position indices for rotary embeddings |
| cu_seqlen_prefill | Optional[torch.Tensor] | Yes | Cumulative sequence lengths for flash attention prefill (None during decode) |
| kv_cache | List[Tuple[torch.Tensor, torch.Tensor]] | Yes | Key-value cache tensors per layer |
| block_tables | torch.Tensor | Yes | Block table indices for paged attention |
| slots | torch.Tensor | Yes | Slot indices for KV cache placement |
| seqlen | Seqlen | Yes | Sequence length metadata wrapper |
| max_s | int | Yes | Maximum sequence length in the batch |
| adapter_data | AdapterBatchData | Yes | LoRA adapter weights and indices for the batch |
| prefill_cache_indices | Optional[torch.Tensor] | No | Indices for selective KV cache population during prefill |
| lm_head_indices | Optional[torch.Tensor] | No | Indices to select specific positions for LM head output |
| cross_attention_states | Optional[torch.Tensor] | No | Cross-attention states for encoder-decoder setups |
| skip_lm_head | bool | No | If True, return hidden states without applying the LM head |
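To make the batching inputs in the table above more concrete, the following sketch derives cumulative sequence lengths, per-token positions, and cache slots from plain Python lists. It assumes a packed layout in which sequences are laid end to end along one token dimension, and a block table that lists the cache blocks owned by one sequence. The helper names and the `block_size` parameter are hypothetical; the server builds the equivalent values as tensors.

```python
from itertools import accumulate
from typing import List


def cumulative_seqlens(lengths: List[int]) -> List[int]:
    """Sequence boundaries in a packed token dimension: [0, l0, l0 + l1, ...]."""
    return [0] + list(accumulate(lengths))


def packed_position_ids(lengths: List[int]) -> List[int]:
    """Position indices restart at 0 for every sequence in the batch."""
    return [pos for n in lengths for pos in range(n)]


def slots_for_sequence(block_table: List[int], length: int, block_size: int) -> List[int]:
    """Map each token position to a flat cache slot through the block table."""
    return [
        block_table[pos // block_size] * block_size + pos % block_size
        for pos in range(length)
    ]


lengths = [3, 2]
cu_seqlens = cumulative_seqlens(lengths)    # boundaries for the prefill kernel
positions = packed_position_ids(lengths)    # input to the rotary embeddings
max_s = max(lengths)                        # longest sequence in the batch
```

Under these assumptions, a sequence of 5 tokens with blocks `[7, 2]` and a block size of 4 fills the four slots of block 7 first and then the first slot of block 2.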
Outputs
| Name | Type | Description |
|---|---|---|
| logits | torch.Tensor | Next-token logits [batch_size, vocab_size] (or hidden states if skip_lm_head is True) |
| speculative_logits | Optional[torch.Tensor] | Speculative decoding logits from the multi-adapter head, or None |
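One common use of an index argument like lm_head_indices is to keep only the final position of each sequence before the LM head runs, so that the output has one row per sequence. The sketch below illustrates that selection on nested lists. It is an assumption about how such indices are typically derived, not a description of the server's code, and both helper names are hypothetical.

```python
from typing import List, Optional


def last_token_indices(cu_seqlens: List[int]) -> List[int]:
    """Index of the final token of each sequence in a packed token dimension."""
    return [end - 1 for end in cu_seqlens[1:]]


def select_rows(hidden: List[List[float]], indices: Optional[List[int]]) -> List[List[float]]:
    """Keep only the requested rows; with no indices, pass everything through."""
    if indices is None:
        return hidden
    return [hidden[i] for i in indices]


# Five packed tokens from two sequences of lengths 3 and 2, hidden size 2.
hidden_states = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1], [3.0, 3.1], [4.0, 4.1]]
selected = select_rows(hidden_states, last_token_indices([0, 3, 5]))
```

Here `selected` holds one hidden-state row per sequence, which is the shape the LM head needs in order to produce one row of next-token logits per sequence.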
Usage Examples
# Internal usage within the LoRAX server
from lorax_server.models.custom_modeling.flash_llama_modeling import FlashLlamaForCausalLM
# Model is instantiated by the model registry, not directly by users
# See server/lorax_server/models/__init__.py for registration