Implementation:Predibase Lorax Flash Gemma2 Modeling
| Knowledge Sources | |
|---|---|
| Domains | Model_Architecture, Inference |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Optimized Gemma 2 transformer implementation for LoRAX inference serving, with flash attention, alternating sliding window attention, and LoRA adapter support.
Description
FlashGemma2ForCausalLM implements Google's Gemma 2 architecture with flash attention for efficient batched inference. The module features alternating sliding window and full causal attention across layers, a distinctive characteristic of the Gemma 2 model family.
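The alternation between sliding window and full causal attention can be illustrated with a minimal sketch. The helper names, the choice that even-indexed layers are the windowed ones, and the exact window boundary convention are assumptions made for illustration, not details confirmed by this file.

```python
from typing import Optional


def window_for_layer(layer_idx: int, sliding_window: int) -> Optional[int]:
    """Return the attention window for a layer, or None for full causal attention.

    Assumption: even-indexed layers use the sliding window and odd-indexed
    layers attend to the full causal prefix.
    """
    return sliding_window if layer_idx % 2 == 0 else None


def can_attend(query_pos: int, key_pos: int, window: Optional[int]) -> bool:
    """Decide whether a query position may attend to a key position."""
    if key_pos > query_pos:
        return False  # causal mask: never look at future tokens
    if window is None:
        return True  # full causal attention
    # Assumed convention: the window counts the query token itself.
    return query_pos - key_pos < window
```

With a window of 4, a query at position 10 in a windowed layer sees keys 7 through 10 only, while the same query in a full-attention layer sees keys 0 through 10.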
The file contains seven classes organized as a layered architecture:
- Gemma2Config -- Configuration class extending PretrainedConfig with Gemma 2-specific parameters including an explicit head_dim (default 256), attention_bias, and attention_dropout settings.
- Gemma2FastRMSNorm -- Custom RMS normalization extending FastRMSNorm that adds 1 to learned weights and normalizes in float32 precision for numerical stability.
- FlashGemma2Attention -- Multi-head attention with GQA support, rotary position embeddings, and attention logit soft-capping. Supports alternating between full causal attention and sliding window attention based on layer index.
- Gemma2MLP -- Gated MLP with fused gate-up projections and adapter-aware gate/up/down projection layers.
- FlashGemma2Layer -- Single transformer decoder layer with separate pre/post attention and pre/post feedforward RMS normalization layers (four norm layers per block).
- FlashGemma2Model -- Full transformer model stacking N decoder layers with final normalization.
- FlashGemma2ForCausalLM -- Top-level causal language model that wraps the model with token embeddings (scaled by hidden_size**0.5) and a language model head supporting LoRA adapters and tied embeddings.
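Two of the numeric details in the list above, attention logit soft-capping and embedding scaling, are easy to show in isolation. The sketch below uses plain Python lists instead of tensors. The function names are hypothetical, and the cap value in the usage note is only an illustrative choice.

```python
import math
from typing import List


def soft_cap(scores: List[float], cap: float) -> List[float]:
    """Squash attention logits smoothly into [-cap, cap] via cap * tanh(x / cap)."""
    return [cap * math.tanh(s / cap) for s in scores]


def scale_embedding(embedding: List[float], hidden_size: int) -> List[float]:
    """Multiply a token embedding by hidden_size ** 0.5."""
    normalizer = hidden_size ** 0.5
    return [v * normalizer for v in embedding]
```

For example, `soft_cap([-1000.0, 1.0, 1000.0], 50.0)` leaves the small logit almost unchanged while keeping the extreme values within the cap.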
The implementation supports tensor parallelism for multi-GPU serving and reuses FastRMSNorm from the lorax_server.layers package as the base class for its normalization layers.
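The normalization behavior described for Gemma2FastRMSNorm (adding 1 to the learned weight) can be sketched as follows. This is a simplified scalar version under stated assumptions: the function name and the `eps` default are hypothetical, and Python floats stand in for the float32 upcast that the real layer performs on tensors.

```python
import math
from typing import List


def gemma_rms_norm(x: List[float], weight: List[float], eps: float = 1e-6) -> List[float]:
    """RMS-normalize x, then scale each element by (1 + weight)."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    # A zero weight leaves the normalized value unchanged because of the +1 offset.
    return [v * inv_rms * (1.0 + w) for v, w in zip(x, weight)]
```

Because of the offset, zero-initialized weights yield an identity scale, whereas a standard RMS norm would need weights of 1 for the same effect.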
Usage
Used internally by the LoRAX server when serving Google Gemma 2 models. The class is loaded through the model registry when the model config type matches.
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py
- Lines: 1-563
Signature
class FlashGemma2ForCausalLM(torch.nn.Module):
    def __init__(self, prefix, config, weights, causal=True):
        ...
    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
        adapter_data: AdapterBatchData,
        prefill_cache_indices: Optional[torch.Tensor] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
        skip_lm_head: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        ...
Import
from lorax_server.models.custom_modeling.flash_gemma2_modeling import FlashGemma2ForCausalLM
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.Tensor | Yes | Token IDs for the batch, packed into a single flattened sequence [total_tokens] |
| position_ids | torch.Tensor | Yes | Position indices for rotary embeddings |
| cu_seqlen_prefill | Optional[torch.Tensor] | Yes | Cumulative sequence lengths for flash attention prefill (None during decode) |
| kv_cache | List[Tuple[torch.Tensor, torch.Tensor]] | Yes | Key-value cache tensors per layer |
| block_tables | torch.Tensor | Yes | Block table indices for paged attention |
| slots | torch.Tensor | Yes | Slot indices for KV cache placement |
| seqlen | Seqlen | Yes | Sequence length metadata wrapper |
| max_s | int | Yes | Maximum sequence length in the batch |
| adapter_data | AdapterBatchData | Yes | LoRA adapter weights and indices for the batch |
| prefill_cache_indices | Optional[torch.Tensor] | No | Indices for selective KV cache population during prefill |
| lm_head_indices | Optional[torch.Tensor] | No | Indices to select specific positions for LM head output |
| skip_lm_head | bool | No | If True, return hidden states without applying the LM head |
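
The relationship between `input_ids`, `position_ids`, `cu_seqlen_prefill`, and `max_s` during prefill can be sketched as below. This assumes the packed layout commonly used with variable-length flash attention kernels, where all sequences are concatenated into one token stream. The helper name is hypothetical and plain lists stand in for tensors.

```python
from itertools import accumulate
from typing import List, Tuple


def pack_prefill_batch(
    sequences: List[List[int]],
) -> Tuple[List[int], List[int], List[int], int]:
    """Pack variable-length sequences into flat prefill inputs."""
    # Concatenate all token IDs into a single stream.
    input_ids = [tok for seq in sequences for tok in seq]
    # Positions restart at 0 for every sequence.
    position_ids = [pos for seq in sequences for pos in range(len(seq))]
    # Cumulative lengths mark where each sequence starts and ends.
    cu_seqlens = [0] + list(accumulate(len(seq) for seq in sequences))
    max_s = max(len(seq) for seq in sequences)
    return input_ids, position_ids, cu_seqlens, max_s
```

For two prompts of lengths 3 and 2, the cumulative lengths are `[0, 3, 5]`, so sequence `i` occupies the slice `cu_seqlens[i]:cu_seqlens[i + 1]` of the packed stream.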
Outputs
| Name | Type | Description |
|---|---|---|
| logits | torch.Tensor | Next-token logits [batch_size, vocab_size] (or hidden states if skip_lm_head is True) |
| speculative_logits | Optional[torch.Tensor] | Speculative decoding logits from the multi-adapter head, or None |
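
One way to read the `lm_head_indices` input is as a row selection applied to the hidden states before the LM head, so that only the positions that need next-token logits are projected. The sketch below is an assumption about that pattern, not a copy of the file's logic. Both helper names are hypothetical, and list elements stand in for hidden-state rows.

```python
from typing import List, Optional, Sequence, TypeVar

T = TypeVar("T")


def last_token_indices(cu_seqlens: Sequence[int]) -> List[int]:
    """Index of the final token of each sequence in a packed batch."""
    return [end - 1 for end in cu_seqlens[1:]]


def select_for_lm_head(
    hidden_states: Sequence[T], lm_head_indices: Optional[Sequence[int]]
) -> List[T]:
    """Keep only the rows that should be passed to the LM head."""
    if lm_head_indices is None:
        return list(hidden_states)  # no selection requested: keep every row
    return [hidden_states[i] for i in lm_head_indices]
```

With cumulative lengths `[0, 3, 5]`, the selected rows are 2 and 4, giving one row of logits per sequence.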
Usage Examples
# Internal usage within the LoRAX server
from lorax_server.models.custom_modeling.flash_gemma2_modeling import FlashGemma2ForCausalLM
# Model is instantiated by the model registry, not directly by users
# See server/lorax_server/models/__init__.py for registration