Implementation:Predibase Lorax Flash Gemma2 Modeling

From Leeroopedia


Knowledge Sources
Domains: Model_Architecture, Inference
Last Updated: 2026-02-08 00:00 GMT

Overview

Optimized Gemma 2 transformer implementation for LoRAX inference serving, with flash attention, alternating sliding window attention, and LoRA adapter support.

Description

FlashGemma2ForCausalLM implements Google's Gemma 2 architecture with flash attention for efficient batched inference. The module features alternating sliding window and full causal attention across layers, a distinctive characteristic of the Gemma 2 model family.
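The alternation between sliding window and full causal attention can be illustrated with a small sketch. It is not taken from the source file: the helper names are hypothetical, the choice that even-indexed layers use the sliding window is an assumption, and the exact window boundary may differ by one position between implementations.

```python
from typing import List, Optional


def layer_window_size(layer_idx: int, sliding_window: int) -> Optional[int]:
    """Return the attention window for a layer, or None for full causal attention.

    Assumption: even-indexed layers use the sliding window and odd-indexed
    layers attend to the full causal context.
    """
    return sliding_window if layer_idx % 2 == 0 else None


def causal_mask(seq_len: int, window: Optional[int]) -> List[List[bool]]:
    """Build mask[q][k], True when query position q may attend to key position k."""
    mask = []
    for q in range(seq_len):
        row = []
        for k in range(seq_len):
            visible = k <= q  # causal: never look at future tokens
            if window is not None:
                # sliding window: only the most recent `window` positions
                visible = visible and (q - k) < window
            row.append(visible)
        mask.append(row)
    return mask
```

With a window of 2, the last query in a 4-token sequence sees only positions 2 and 3, whereas a full causal layer would see all four. A flash attention kernel applies this restriction inside the kernel rather than materializing a mask; the explicit mask here is only for illustration.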

The file contains seven classes organized as a layered architecture:

  • Gemma2Config -- Configuration class extending PretrainedConfig with Gemma 2-specific parameters including an explicit head_dim (default 256), attention_bias, and attention_dropout settings.
  • Gemma2FastRMSNorm -- Custom RMS normalization extending FastRMSNorm that adds 1 to learned weights and normalizes in float32 precision for numerical stability.
  • FlashGemma2Attention -- Multi-head attention with GQA support, rotary position embeddings, and attention logit soft-capping. Supports alternating between full causal attention and sliding window attention based on layer index.
  • Gemma2MLP -- Gated MLP with fused gate-up projections and adapter-aware gate/up/down projection layers.
  • FlashGemma2Layer -- Single transformer decoder layer with separate pre/post attention and pre/post feedforward RMS normalization layers (four norm layers per block).
  • FlashGemma2Model -- Full transformer model stacking N decoder layers with final normalization.
  • FlashGemma2ForCausalLM -- Top-level causal language model that wraps the model with token embeddings (scaled by hidden_size**0.5) and a language model head supporting LoRA adapters and tied embeddings.
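Three of the numerical details in the list above (the RMS norm weight offset, logit soft-capping, and embedding scaling) can be sketched in plain Python. This is an illustrative sketch rather than the module's code: the function names are hypothetical, Python floats stand in for the float32 upcast, and the soft-capping form `cap * tanh(x / cap)` is assumed to be the one the attention layer applies.

```python
import math
from typing import List


def gemma_rms_norm(x: List[float], weight: List[float], eps: float = 1e-6) -> List[float]:
    """RMS-normalize x, then scale each element by (1 + weight)."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    # The learned weight is stored as an offset from 1, so a zero weight
    # leaves the normalized activations unchanged.
    return [v * inv_rms * (1.0 + w) for v, w in zip(x, weight)]


def soft_cap(logits: List[float], cap: float) -> List[float]:
    """Squash logits smoothly into the open interval (-cap, cap)."""
    return [cap * math.tanh(v / cap) for v in logits]


def scale_embedding(embedding: List[float], hidden_size: int) -> List[float]:
    """Multiply a token embedding by hidden_size ** 0.5."""
    scale = hidden_size ** 0.5
    return [v * scale for v in embedding]
```

Soft-capping keeps small logits nearly unchanged while bounding large ones, which is why it is applied before the softmax rather than as a hard clip.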

The implementation supports tensor parallelism for multi-GPU serving and reuses shared layer components (FastRMSNorm) from the lorax_server.layers package.
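As a rough illustration of what tensor parallelism means for the attention layer, the sketch below splits query and key-value heads evenly across ranks. The function name is hypothetical, and the even-divisibility requirement is an assumption; the actual module may handle grouped-query heads differently when there are fewer key-value heads than ranks.

```python
from typing import Tuple


def heads_per_rank(num_heads: int, num_kv_heads: int, world_size: int) -> Tuple[int, int]:
    """Return (query heads, key-value heads) owned by each tensor-parallel rank.

    Assumption: both head counts divide evenly by the number of ranks.
    """
    if num_heads % world_size != 0:
        raise ValueError(f"num_heads={num_heads} is not divisible by world_size={world_size}")
    if num_kv_heads % world_size != 0:
        raise ValueError(f"num_kv_heads={num_kv_heads} is not divisible by world_size={world_size}")
    return num_heads // world_size, num_kv_heads // world_size
```

Each rank then computes attention only for its own heads, and the partial outputs are combined after the output projection.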

Usage

The LoRAX server uses this module internally when serving Google Gemma 2 models. The model registry loads it when the model's config type matches Gemma 2.
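The idea of dispatching on the config type can be sketched as follows. Everything here is hypothetical: the registry dictionary, the decorator, and the loader are illustrative stand-ins, not the structure of server/lorax_server/models/__init__.py, and the "gemma2" key is assumed to be the model type reported by a Gemma 2 config.

```python
from typing import Callable, Dict

# Hypothetical registry mapping a config's model_type to a loader function.
MODEL_REGISTRY: Dict[str, Callable[[dict], str]] = {}


def register(model_type: str) -> Callable[[Callable[[dict], str]], Callable[[dict], str]]:
    """Decorator that records a loader under the given model_type."""
    def decorator(loader: Callable[[dict], str]) -> Callable[[dict], str]:
        MODEL_REGISTRY[model_type] = loader
        return loader
    return decorator


@register("gemma2")
def load_gemma2(config: dict) -> str:
    # A real loader would construct the model; this stand-in returns its name.
    return "FlashGemma2ForCausalLM"


def resolve(config: dict) -> str:
    """Pick the loader whose key matches config['model_type']."""
    model_type = config.get("model_type")
    if model_type not in MODEL_REGISTRY:
        raise ValueError(f"Unsupported model type: {model_type!r}")
    return MODEL_REGISTRY[model_type](config)
```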

Code Reference

Source Location

  • Repository: Predibase_Lorax
  • File: server/lorax_server/models/custom_modeling/flash_gemma2_modeling.py
  • Lines: 1-563

Signature

class FlashGemma2ForCausalLM(torch.nn.Module):
    def __init__(self, prefix, config, weights, causal=True):
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
        adapter_data: AdapterBatchData,
        prefill_cache_indices: Optional[torch.Tensor] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
        skip_lm_head: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        ...

Import

from lorax_server.models.custom_modeling.flash_gemma2_modeling import FlashGemma2ForCausalLM

I/O Contract

Inputs

Name | Type | Required | Description
input_ids | torch.Tensor | Yes | Token IDs [batch_size, seq_len]
position_ids | torch.Tensor | Yes | Position indices for rotary embeddings
cu_seqlen_prefill | Optional[torch.Tensor] | Yes | Cumulative sequence lengths for flash attention prefill (None during decode)
kv_cache | List[Tuple[torch.Tensor, torch.Tensor]] | Yes | Key-value cache tensors per layer
block_tables | torch.Tensor | Yes | Block table indices for paged attention
slots | torch.Tensor | Yes | Slot indices for KV cache placement
seqlen | Seqlen | Yes | Sequence length metadata wrapper
max_s | int | Yes | Maximum sequence length in the batch
adapter_data | AdapterBatchData | Yes | LoRA adapter weights and indices for the batch
prefill_cache_indices | Optional[torch.Tensor] | No | Indices for selective KV cache population during prefill
lm_head_indices | Optional[torch.Tensor] | No | Indices to select specific positions for LM head output
skip_lm_head | bool | No | If True, return hidden states without applying the LM head
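To make the cumulative sequence lengths and LM head indices more concrete, here is a minimal sketch of how such values could be derived from per-request prompt lengths. It assumes the prefill sequences are packed back to back along a single token axis, as variable-length flash attention kernels typically expect; the helper names are hypothetical, and the server builds these as tensors rather than Python lists.

```python
from itertools import accumulate
from typing import List


def cumulative_seqlens(seq_lens: List[int]) -> List[int]:
    """Return boundaries of each sequence in the packed batch, starting at 0."""
    return [0] + list(accumulate(seq_lens))


def last_token_indices(seq_lens: List[int]) -> List[int]:
    """Return the packed position of each sequence's final token.

    These are the positions one would select when only the next-token
    logits of each request are needed.
    """
    return [end - 1 for end in cumulative_seqlens(seq_lens)[1:]]
```

For prompt lengths 3, 1, and 4, the boundaries are 0, 3, 4, 8 and the final tokens sit at packed positions 2, 3, and 7.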

Outputs

Name | Type | Description
logits | torch.Tensor | Next-token logits [batch_size, vocab_size] (or hidden states if skip_lm_head is True)
speculative_logits | Optional[torch.Tensor] | Speculative decoding logits from the multi-adapter head, or None

Usage Examples

# Internal usage within the LoRAX server
from lorax_server.models.custom_modeling.flash_gemma2_modeling import FlashGemma2ForCausalLM

# Model is instantiated by the model registry, not directly by users
# See server/lorax_server/models/__init__.py for registration
