Implementation:Predibase Lorax Flash Gemma Modeling
| Knowledge Sources | |
|---|---|
| Domains | Model_Architecture, Inference |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Optimized Gemma transformer implementation for LoRAX inference serving, with flash attention and LoRA adapter support.
Description
GemmaForCausalLM implements Google's Gemma architecture with flash attention for efficient batched inference. The module provides attention layers, MLP blocks, and the full model stack with integrated LoRA adapter hooks for dynamic adapter serving.
The file contains seven classes organized as a layered architecture:
- GemmaConfig -- Configuration class extending PretrainedConfig with standard Llama-like parameters including rope_scaling and rope_theta.
- GemmaRMSNorm -- Custom RMS normalization that applies (1 + weight) scaling, differing from the standard Llama RMSNorm by adding 1 to the learned weight parameter.
- GemmaAttention -- Multi-head attention with grouped-query attention (GQA) support. It reads head_dim from the config, applies rotary position embeddings, and uses flash attention for prefill and paged attention for decode. The Q/K/V/O projections are adapter-aware.
- GemmaMLP -- Gated MLP with fused gate-up projections and adapter-aware gate/up/down projection layers.
- GemmaDecoderLayer -- Single transformer decoder layer combining attention and MLP in a pre-norm arrangement: RMS normalization is applied to the input before attention and again, as post-attention normalization, before the MLP.
- GemmaModel -- Full transformer model stacking N decoder layers with token embeddings (scaled by hidden_size**0.5) and final normalization.
- GemmaForCausalLM -- Top-level causal language model that wraps the model and reuses the embedding weights for the language model head (tied embeddings), computing logits via a transposed matrix multiply.
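The (1 + weight) scaling described for GemmaRMSNorm can be illustrated with a minimal sketch. It uses plain Python lists rather than torch tensors so that it runs without dependencies; the function names and the eps value are assumptions made for illustration and are not taken from the LoRAX source.

```python
import math


def rms_normalize(x, eps=1e-6):
    """Divide x by its root-mean-square; eps (an assumed value) avoids division by zero."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms for v in x]


def llama_rms_norm(x, weight, eps=1e-6):
    """Standard variant: normalized input times the learned weight."""
    return [n * w for n, w in zip(rms_normalize(x, eps), weight)]


def gemma_rms_norm(x, weight, eps=1e-6):
    """Gemma variant: normalized input times (1 + learned weight)."""
    return [n * (1.0 + w) for n, w in zip(rms_normalize(x, eps), weight)]


x = [1.0, -2.0, 3.0, -4.0]
zero_weight = [0.0] * len(x)

# With an all-zero weight the standard variant zeroes the signal,
# while the Gemma variant returns the normalized input unchanged.
llama_out = llama_rms_norm(x, zero_weight)
gemma_out = gemma_rms_norm(x, zero_weight)
```

One consequence of this parameterization is that a zero-valued weight acts as the identity scale, so a checkpoint saved for one convention cannot be loaded into the other without shifting the weight by 1.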
The implementation supports FP8 KV cache quantization and tensor parallelism for multi-GPU serving.
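The embedding scaling in GemmaModel and the tied language model head in GemmaForCausalLM, both listed in the class overview above, can be sketched together. This is a toy illustration in plain Python: the table values, `embed`, and `tied_logits` are hypothetical names chosen here, and the decoder layers that sit between the two steps in the real model are omitted.

```python
import math

hidden_size = 4

# Toy embedding table: one row of length hidden_size per vocabulary token.
embed_tokens = [
    [0.5, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 2.0, 0.0],
]


def embed(token_id):
    """Look up a token embedding and scale it by hidden_size ** 0.5."""
    scale = math.sqrt(hidden_size)
    return [v * scale for v in embed_tokens[token_id]]


def tied_logits(hidden_state):
    """Logits as hidden_state @ embed_tokens^T: one dot product per vocabulary row."""
    return [sum(h * e for h, e in zip(hidden_state, row)) for row in embed_tokens]


# Decoder layers are skipped in this sketch; the scaled embedding is fed
# straight to the tied head to show that the same table is used on both ends.
scaled = embed(1)             # [0.0, 2.0, 0.0, 0.0]
logits = tied_logits(scaled)  # [0.0, 2.0, 0.0]
```

Because the head reuses the embedding table, no separate lm_head weight needs to be stored, and the output dimension always equals the number of rows in the table.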
Usage
Used internally by the LoRAX server when serving Google Gemma models. Also used as the text backbone in PaliGemma vision-language models. The class is loaded via the model registry when the model config type matches.
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/models/custom_modeling/flash_gemma_modeling.py
- Lines: 1-560
Signature
class GemmaForCausalLM(torch.nn.Module):
    def __init__(self, prefix: str, config, weights):
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
        adapter_data: AdapterBatchData,
        prefill_cache_indices: Optional[torch.Tensor] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
        skip_lm_head: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        ...
Import
from lorax_server.models.custom_modeling.flash_gemma_modeling import GemmaForCausalLM
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.Tensor | Yes | Token IDs for all sequences in the batch, flattened without padding [total_tokens] |
| position_ids | torch.Tensor | Yes | Position indices for rotary embeddings |
| cu_seqlen_prefill | Optional[torch.Tensor] | Yes | Cumulative sequence lengths for flash attention prefill (None during decode) |
| kv_cache | List[Tuple[torch.Tensor, torch.Tensor]] | Yes | Key-value cache tensors per layer |
| block_tables | torch.Tensor | Yes | Block table indices for paged attention |
| slots | torch.Tensor | Yes | Slot indices for KV cache placement |
| seqlen | Seqlen | Yes | Sequence length metadata wrapper |
| max_s | int | Yes | Maximum sequence length in the batch |
| adapter_data | AdapterBatchData | Yes | LoRA adapter weights and indices for the batch |
| prefill_cache_indices | Optional[torch.Tensor] | No | Indices for selective KV cache population during prefill |
| lm_head_indices | Optional[torch.Tensor] | No | Indices to select specific positions for LM head output |
| skip_lm_head | bool | No | If True, return hidden states without applying the LM head |
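To make the batching inputs more concrete, the sketch below shows how cumulative sequence lengths, last-token indices, and paged-cache slots are conventionally derived for variable-length flash attention and block-based KV caches. It assumes a flattened, unpadded token layout; the helper names, the block size, and the block table contents are hypothetical, and the exact tensors LoRAX builds may differ.

```python
from itertools import accumulate


def cumulative_seqlens(input_lengths):
    """Boundaries of each sequence inside the flattened token array."""
    return [0] + list(accumulate(input_lengths))


def last_token_indices(cu_seqlens):
    """Index of the final token of each sequence (a typical lm_head_indices choice)."""
    return [end - 1 for end in cu_seqlens[1:]]


def slot_for_position(block_table, position, block_size):
    """Map a token position to a flat KV cache slot through its block table."""
    block = block_table[position // block_size]
    return block * block_size + position % block_size


# Three sequences of lengths 3, 5 and 2 packed into 10 flattened tokens.
input_lengths = [3, 5, 2]
cu_seqlens = cumulative_seqlens(input_lengths)  # [0, 3, 8, 10]
head_indices = last_token_indices(cu_seqlens)   # [2, 7, 9]
max_s = max(input_lengths)                      # 5

# One sequence of 5 tokens whose cache lives in physical blocks 7 and 2,
# with an assumed block size of 4 slots per block.
block_table = [7, 2]
slots = [slot_for_position(block_table, p, 4) for p in range(5)]  # [28, 29, 30, 31, 8]
```

Selecting only the last token of each sequence during prefill keeps the language model head from computing logits for positions whose next token is already known.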
Outputs
| Name | Type | Description |
|---|---|---|
| logits | torch.Tensor | Next-token logits [batch_size, vocab_size] computed via tied embedding weights; when skip_lm_head is True, the hidden states are returned instead |
| speculative_logits | Optional[torch.Tensor] | Always None for Gemma (no speculative decoding via multi-adapter head) |
Usage Examples
# Internal usage within the LoRAX server
from lorax_server.models.custom_modeling.flash_gemma_modeling import GemmaForCausalLM
# Model is instantiated by the model registry, not directly by users
# See server/lorax_server/models/__init__.py for registration