

Implementation:Predibase Lorax Flash Phi3 Modeling

From Leeroopedia
Revision as of 16:20, 16 February 2026 by Admin (Auto-imported from implementations/Predibase_Lorax_Flash_Phi3_Modeling.md)


Knowledge Sources
Domains: Model_Architecture, Inference
Last Updated: 2026-02-08 00:00 GMT

Overview

Optimized Phi-3 transformer implementation for LoRax inference serving with flash attention, RMS normalization, fused gate-up projections, and LoRA adapter support.
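With LoRA, a projection's output is the frozen base projection plus a low-rank update, and a multi-adapter server can serve requests that use different adapters within one batch. The following is a minimal plain-Python sketch of that idea. It assumes the standard LoRA formulation y = xW + scale * (xA)B and a per-token adapter id. `lora_projection`, its arguments, and the adapter ids are hypothetical, and a real server would batch this on the GPU rather than loop over tokens.

```python
from typing import Dict, List, Optional, Tuple

Matrix = List[List[float]]


def matmul(x: Matrix, w: Matrix) -> Matrix:
    """Plain (rows x k) @ (k x cols) matrix product."""
    cols = list(zip(*w))
    return [[sum(a * b for a, b in zip(row, col)) for col in cols] for row in x]


def lora_projection(
    x: Matrix,
    base_w: Matrix,
    adapters: Dict[int, Tuple[Matrix, Matrix, float]],
    adapter_ids: List[Optional[int]],
) -> Matrix:
    """Base projection plus a per-token LoRA delta (hypothetical helper).

    adapters maps an adapter id to (A, B, scale) with A: (in x r) and
    B: (r x out). adapter_ids[t] picks the adapter for token t, or None
    to use the base weights only.
    """
    out = matmul(x, base_w)
    for t, adapter_id in enumerate(adapter_ids):
        if adapter_id is None:
            continue
        a, b, scale = adapters[adapter_id]
        delta = matmul(matmul([x[t]], a), b)[0]  # (x_t A) B, shape (out,)
        out[t] = [o + scale * d for o, d in zip(out[t], delta)]
    return out


# Two tokens in one batch: token 0 uses adapter 7, token 1 uses the base model.
x = [[1.0, 2.0], [3.0, 4.0]]
identity = [[1.0, 0.0], [0.0, 1.0]]
adapters = {7: ([[1.0], [1.0]], [[0.5, -0.5]], 2.0)}  # rank-1 adapter
print(lora_projection(x, identity, adapters, [7, None]))  # [[4.0, -1.0], [3.0, 4.0]]
```

Because the base weights are shared and only the small A and B matrices differ per adapter, many adapters can be served on top of a single copy of the base model.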

Description

This module implements the Microsoft Phi-3 architecture, adapted for high-throughput inference in the LoRax serving framework. Phi-3 uses a Llama-style architecture with RMS normalization, attention that supports grouped-query key/value heads, and a gated MLP. This is a significant departure from Phi-1.5/2, which used LayerNorm, partial rotary embeddings, and parallel attention/MLP blocks. The main components are:

  • Phi3Config -- Configuration class extending PretrainedConfig with parameters including vocab_size, hidden_size, num_attention_heads, num_key_value_heads, rms_norm_eps, rope_theta, and rope_scaling.
  • Phi3RMSNorm -- RMS normalization with a fast path that uses the fused dropout_layer_norm kernel for hidden dimensions up to 8192, and a plain PyTorch fallback for larger dimensions. Optionally fuses the residual addition, returning both the normalized output and the updated residual.
  • FlashPhi3Attention -- Attention with grouped-query attention (GQA) support, rotary positional embeddings (RoPE) applied across the full head dimension, and a single fused qkv_proj projection. Supports LoRA adapters on qkv_proj and o_proj. Provides a get_query_key_value_weights helper method for weight inspection.
  • Phi3MLP -- Gated feed-forward network. A fused gate_up_proj projection is split into gate and up components, SiLU is applied to the gate and the result is multiplied elementwise with the up component, and down_proj maps the product back to the hidden size. Supports LoRA adapters on gate_up_proj and down_proj.
  • FlashPhi3Layer -- A single transformer layer with pre-norm RMS normalization, self-attention, post-attention normalization, and MLP with residual connections.
  • FlashPhi3Model -- The full transformer model with token embedding, stacked FlashPhi3Layer instances, and final RMS normalization. Computes rotary cos/sin once for all layers.
  • FlashPhi3ForCausalLM -- The top-level causal language model wrapping FlashPhi3Model with a MultiAdapterHead LM head supporting LoRA on the output projection.
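The pre-norm layer structure in the list above can be sketched in plain Python. This is a minimal, unoptimized sketch that rests on two assumptions. First, it follows the Llama-style convention in which the norm fuses the residual addition and returns both the normalized output and the updated residual. Second, it assumes that gate_up_proj's output holds the gate half before the up half, which is the order Hugging Face's Phi-3 MLP uses when it chunks the fused output. The helper names and the params keys are hypothetical; the real module operates on batched torch tensors with fused kernels.

```python
import math
from typing import Callable, Dict, List, Optional, Tuple

Vector = List[float]
Matrix = List[List[float]]


def matvec(x: Vector, w: Matrix) -> Vector:
    """x (in,) @ w (in x out) -> (out,)."""
    return [sum(xi * wi for xi, wi in zip(x, col)) for col in zip(*w)]


def rms_norm(
    x: Vector, weight: Vector, eps: float = 1e-5, residual: Optional[Vector] = None
) -> Tuple[Vector, Vector]:
    """RMSNorm with an optional fused residual add.

    Returns (normalized, new_residual), where new_residual is x + residual
    (or x itself when no residual is given).
    """
    if residual is not None:
        x = [a + b for a, b in zip(x, residual)]
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v * inv_rms for w, v in zip(weight, x)], x


def silu(v: float) -> float:
    # v * sigmoid(v), written with tanh so large negative inputs do not overflow
    return v * 0.5 * (1.0 + math.tanh(0.5 * v))


def gated_mlp(x: Vector, gate_up_w: Matrix, down_w: Matrix) -> Vector:
    """One fused gate_up matmul, split in half: down(silu(gate) * up)."""
    gate_up = matvec(x, gate_up_w)
    half = len(gate_up) // 2
    gate, up = gate_up[:half], gate_up[half:]  # gate half first (assumed layout)
    return matvec([silu(g) * u for g, u in zip(gate, up)], down_w)


def decoder_layer(
    hidden: Vector,
    residual: Optional[Vector],
    attn: Callable[[Vector], Vector],
    params: Dict[str, list],
) -> Tuple[Vector, Vector]:
    """Pre-norm layer; returns (mlp_out, residual) for the next layer to fuse."""
    normed, residual = rms_norm(hidden, params["input_norm"], residual=residual)
    attn_out = attn(normed)
    normed, residual = rms_norm(attn_out, params["post_attn_norm"], residual=residual)
    return gated_mlp(normed, params["gate_up_w"], params["down_w"]), residual


params = {
    "input_norm": [1.0, 1.0],
    "post_attn_norm": [1.0, 1.0],
    "gate_up_w": [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]],  # 2 -> 2 gate + 2 up
    "down_w": [[1.0, 0.0], [0.0, 1.0]],
}
# Identity stands in for attention; the first layer has no incoming residual.
mlp_out, residual = decoder_layer([3.0, 4.0], None, lambda v: v, params)
# The final model norm fuses the last residual add, as the next layer would.
final_hidden, _ = rms_norm(mlp_out, [1.0, 1.0], residual=residual)
```

The layer returns the MLP output without adding it to the residual. This lets the next layer's input norm, or the model's final norm, perform that addition in the same fused step.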
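The attention features above (rotary embeddings over the full head dimension, GQA, and cos/sin tables computed once and shared by every layer) can be illustrated with a small plain-Python sketch. It assumes the non-interleaved "rotate half" RoPE layout used by Llama-style Hugging Face models, a rope_theta of 10000, and no rope_scaling. The function names are hypothetical, and the real module works on batched tensors.

```python
import math
from typing import List, Tuple

Vector = List[float]


def rope_cos_sin(position: int, head_dim: int, theta: float = 10000.0) -> Tuple[Vector, Vector]:
    """cos/sin tables for one position (computed once, reused by every layer)."""
    half = head_dim // 2
    inv_freq = [theta ** (-2.0 * i / head_dim) for i in range(half)]
    angles = [position * f for f in inv_freq]
    return [math.cos(a) for a in angles], [math.sin(a) for a in angles]


def apply_rope(x: Vector, cos: Vector, sin: Vector) -> Vector:
    """Rotate the pairs (x[i], x[i + half]) across the full head dimension."""
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    first = [a * c - b * s for a, b, c, s in zip(x1, x2, cos, sin)]
    second = [b * c + a * s for a, b, c, s in zip(x1, x2, cos, sin)]
    return first + second


def kv_head_for(query_head: int, num_heads: int, num_kv_heads: int) -> int:
    """GQA: each group of num_heads // num_kv_heads query heads shares one KV head."""
    return query_head // (num_heads // num_kv_heads)


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


# The rotated q.k score depends only on the distance between positions.
q, k = [0.3, -1.2, 0.8, 0.5], [1.1, 0.4, -0.7, 0.2]
s1 = dot(apply_rope(q, *rope_cos_sin(7, 4)), apply_rope(k, *rope_cos_sin(3, 4)))
s2 = dot(apply_rope(q, *rope_cos_sin(107, 4)), apply_rope(k, *rope_cos_sin(103, 4)))
print(abs(s1 - s2) < 1e-9)  # True
print([kv_head_for(h, 8, 2) for h in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]
```

When num_kv_heads equals num_heads (plain multi-head attention) the mapping is the identity. Smaller values shrink the KV cache proportionally, because only num_kv_heads key/value heads are stored per token.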

Usage

Used internally by the LoRax server when serving Phi-3-based models. The model registry selects this implementation when the model's Hugging Face config reports the Phi-3 model type (model_type "phi3").
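The registry lookup can be pictured as a mapping from the config's model_type string to a loader. This is an illustrative sketch only. MODEL_REGISTRY, register, load_phi3, and get_model are hypothetical names, the loader returns a class name instead of a model, and the actual LoRax dispatch code may be organized differently.

```python
from typing import Callable, Dict

Loader = Callable[[dict], str]

# Hypothetical registry mapping a Hugging Face config "model_type" to a loader.
MODEL_REGISTRY: Dict[str, Loader] = {}


def register(model_type: str) -> Callable[[Loader], Loader]:
    def wrap(loader: Loader) -> Loader:
        MODEL_REGISTRY[model_type] = loader
        return loader
    return wrap


@register("phi3")
def load_phi3(config: dict) -> str:
    # Stand-in: a real loader would build FlashPhi3ForCausalLM from weights.
    return "FlashPhi3ForCausalLM"


def get_model(config: dict) -> str:
    model_type = config.get("model_type")
    if model_type not in MODEL_REGISTRY:
        raise ValueError(f"Unsupported model type: {model_type!r}")
    return MODEL_REGISTRY[model_type](config)


print(get_model({"model_type": "phi3"}))  # FlashPhi3ForCausalLM
```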

Code Reference

Source Location

  • Repository: Predibase_Lorax
  • File: server/lorax_server/models/custom_modeling/flash_phi3_modeling.py
  • Lines: 1-529

Signature

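from typing import List, Optional, Tuple

import torch

# Seqlen and AdapterBatchData are LoRax-internal types defined in lorax_server.

class FlashPhi3ForCausalLM(torch.nn.Module):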
    def __init__(self, prefix: str, config, weights):
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
        adapter_data: AdapterBatchData,
        prefill_cache_indices: Optional[torch.Tensor] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
        skip_lm_head: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        ...

Import

from lorax_server.models.custom_modeling.flash_phi3_modeling import FlashPhi3ForCausalLM

I/O Contract

Inputs

  • input_ids (torch.Tensor, required) -- Token IDs for all sequences in the batch, flattened along a single token axis.
  • position_ids (torch.Tensor, required) -- Position index of each token, used for the rotary embeddings.
  • cu_seqlen_prefill (Optional[torch.Tensor], required; may be None) -- Cumulative sequence lengths marking sequence boundaries for flash-attention prefill; None during decode.
  • kv_cache (List[Tuple[torch.Tensor, torch.Tensor]], required) -- Key and value cache tensors for each layer.
  • block_tables (torch.Tensor, required) -- Block tables mapping each sequence to its KV-cache blocks for paged attention.
  • slots (torch.Tensor, required) -- Slot indices where the new keys and values are written in the KV cache.
  • seqlen (Seqlen, required) -- Sequence length information for the batch.
  • max_s (int, required) -- Maximum sequence length in the batch.
  • adapter_data (AdapterBatchData, required) -- LoRA adapter configuration for the batch.
  • prefill_cache_indices (Optional[torch.Tensor], optional) -- Indices for selective cache prefilling.
  • lm_head_indices (Optional[torch.Tensor], optional) -- Positions whose hidden states are passed to the LM head (for example, only the last token of each sequence during prefill).
  • skip_lm_head (bool, optional) -- If True, return the hidden states without applying the LM head.
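In a flash-attention prefill, the sequences of a batch are concatenated along one token axis without padding, and cu_seqlen_prefill marks where each sequence starts and ends. The sketch below shows how these inputs relate. It assumes a fresh prefill with no cached prefix, and an lm_head_indices that keeps only each sequence's last token (the usual case when only next-token logits are needed). prefill_indices is a hypothetical helper, and plain lists stand in for torch tensors.

```python
from itertools import accumulate
from typing import List, Tuple


def prefill_indices(seq_lens: List[int]) -> Tuple[List[int], List[int], List[int], int]:
    """For sequences packed back to back on one token axis, return
    (cu_seqlen_prefill, position_ids, lm_head_indices, max_s)."""
    cu_seqlen_prefill = [0] + list(accumulate(seq_lens))  # sequence boundaries
    position_ids = [p for n in seq_lens for p in range(n)]  # restart at 0 per sequence
    lm_head_indices = [end - 1 for end in cu_seqlen_prefill[1:]]  # last token of each
    return cu_seqlen_prefill, position_ids, lm_head_indices, max(seq_lens)


cu, pos, heads, max_s = prefill_indices([3, 2, 4])
print(cu)     # [0, 3, 5, 9]
print(pos)    # [0, 1, 2, 0, 1, 0, 1, 2, 3]
print(heads)  # [2, 4, 8]
print(max_s)  # 4
```

During decode each sequence contributes a single new token, so there are no prefill boundaries to describe and cu_seqlen_prefill is passed as None.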

Outputs

  • logits (torch.Tensor) -- Next-token logits over the vocabulary, or the hidden states if skip_lm_head is True.
  • speculative_logits (Optional[torch.Tensor]) -- Speculative decoding logits, or None.

Usage Examples

# Internal usage within LoRax server
from lorax_server.models.custom_modeling.flash_phi3_modeling import FlashPhi3ForCausalLM
# Model instantiated by model registry, not directly by users
