Implementation:Predibase Lorax Flash Phi3 Modeling
| Knowledge Sources | |
|---|---|
| Domains | Model_Architecture, Inference |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Optimized Phi-3 transformer implementation for LoRAX inference serving, with flash attention, RMS normalization, fused gate-up projections, and LoRA adapter support.
Description
This module implements the Microsoft Phi-3 architecture, adapted for high-throughput inference in the LoRAX serving framework. Phi-3 uses a Llama-style architecture with RMS normalization, grouped-query attention, and a gated MLP, which is a significant architectural departure from Phi-1.5/2. The main components are:
- Phi3Config -- Configuration class extending PretrainedConfig with parameters including vocab_size, hidden_size, num_attention_heads, num_key_value_heads, rms_norm_eps, rope_theta, and rope_scaling.
- Phi3RMSNorm -- RMS normalization with an optimized fast path using dropout_layer_norm for hidden dimensions up to 8192, and a fallback PyTorch implementation for larger dimensions. Supports fused residual addition.
- FlashPhi3Attention -- Self-attention with grouped-query attention (GQA) support, full rotary positional embeddings (RoPE), and a fused qkv_proj projection. Supports LoRA adapters on qkv_proj and o_proj. Provides a get_query_key_value_weights helper method for weight inspection.
- Phi3MLP -- Gated feed-forward network with a fused gate_up_proj projection split into gate and up components, SiLU activation on the gate, and down_proj. Supports LoRA adapters on gate_up_proj and down_proj.
- FlashPhi3Layer -- A single transformer layer: RMS normalization before self-attention, a second RMS normalization after attention, and the MLP, with residual connections around the attention and MLP sub-blocks.
- FlashPhi3Model -- The full transformer model with token embedding, stacked FlashPhi3Layer instances, and final RMS normalization. Computes rotary cos/sin once for all layers.
- FlashPhi3ForCausalLM -- The top-level causal language model wrapping FlashPhi3Model with a MultiAdapterHead LM head supporting LoRA on the output projection.
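The arithmetic behind three of the components above (RMS normalization with a fused residual add, the split of the fused gate_up_proj output, and the grouped-query head mapping) can be sketched in a few lines. This is a dependency-free illustration on plain Python lists, not the module's torch code: the function names are hypothetical, the default eps is illustrative, and the real implementation operates on batched tensors with fused kernels.

```python
import math


def rms_norm(hidden, weight, eps=1e-5, residual=None):
    """RMS-normalize one hidden vector, optionally adding a residual first.

    Returns (normalized, new_residual); the pre-normalization sum is handed
    back so the next sub-block can reuse it as its residual.
    """
    if residual is not None:
        hidden = [h + r for h, r in zip(hidden, residual)]
    new_residual = list(hidden)
    mean_square = sum(h * h for h in hidden) / len(hidden)
    scale = 1.0 / math.sqrt(mean_square + eps)
    normalized = [w * h * scale for w, h in zip(weight, hidden)]
    return normalized, new_residual


def silu(x):
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


def gated_mlp_activation(gate_up):
    """Split a fused gate_up output into (gate, up) halves: SiLU(gate) * up."""
    half = len(gate_up) // 2
    gate, up = gate_up[:half], gate_up[half:]
    return [silu(g) * u for g, u in zip(gate, up)]


def qkv_split_sizes(num_heads, num_kv_heads, head_dim):
    """Widths of the query, key, and value slices of a fused qkv projection."""
    kv_width = num_kv_heads * head_dim
    return num_heads * head_dim, kv_width, kv_width


def kv_head_for_query_head(q_head, num_heads, num_kv_heads):
    """Under GQA, consecutive groups of query heads share one KV head."""
    group_size = num_heads // num_kv_heads
    return q_head // group_size
```

With num_heads equal to num_kv_heads the mapping degenerates to ordinary multi-head attention, so the same code path covers both cases.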
Usage
Used internally by the LoRAX server when serving Phi-3-based models. The model registry loads this class when the model config's type identifies a Phi-3 model; users do not instantiate it directly.
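A registry lookup of this kind can be pictured as a dictionary keyed on the config's model_type string. The sketch below is a hypothetical illustration of that pattern, not the LoRAX registry code: MODEL_REGISTRY, register, and resolve_model are made-up names, and the loader returns a class name as a string instead of building a model. The one grounded detail is that Hugging Face Phi-3 configs carry model_type "phi3".

```python
from typing import Callable, Dict

# Hypothetical registry: maps a config's model_type to a loader function.
MODEL_REGISTRY: Dict[str, Callable[[dict], str]] = {}


def register(model_type: str):
    """Decorator that records a loader under the given model_type."""
    def decorator(loader):
        MODEL_REGISTRY[model_type] = loader
        return loader
    return decorator


@register("phi3")
def load_flash_phi3(config: dict) -> str:
    # A real server would construct FlashPhi3ForCausalLM(prefix, config, weights)
    # here; this sketch only reports which class would be chosen.
    return "FlashPhi3ForCausalLM"


def resolve_model(config: dict) -> str:
    """Pick the loader matching config['model_type'], or fail loudly."""
    model_type = config.get("model_type")
    if model_type not in MODEL_REGISTRY:
        raise ValueError(f"Unsupported model type: {model_type!r}")
    return MODEL_REGISTRY[model_type](config)
```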
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/models/custom_modeling/flash_phi3_modeling.py
- Lines: 1-529
Signature
class FlashPhi3ForCausalLM(torch.nn.Module):
    def __init__(self, prefix: str, config, weights):
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
        adapter_data: AdapterBatchData,
        prefill_cache_indices: Optional[torch.Tensor] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
        skip_lm_head: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        ...
Import
from lorax_server.models.custom_modeling.flash_phi3_modeling import FlashPhi3ForCausalLM
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.Tensor | Yes | Token IDs for the input sequence |
| position_ids | torch.Tensor | Yes | Position indices for rotary embeddings |
| cu_seqlen_prefill | Optional[torch.Tensor] | Yes | Cumulative sequence lengths for flash attention prefill; None during decode |
| kv_cache | List[Tuple[torch.Tensor, torch.Tensor]] | Yes | Key-value cache tensors for each layer |
| block_tables | torch.Tensor | Yes | Block tables for paged attention |
| slots | torch.Tensor | Yes | Slot indices for KV cache storage |
| seqlen | Seqlen | Yes | Sequence length information for the batch |
| max_s | int | Yes | Maximum sequence length in the batch |
| adapter_data | AdapterBatchData | Yes | LoRA adapter configuration for the batch |
| prefill_cache_indices | Optional[torch.Tensor] | No | Indices for selective cache prefilling |
| lm_head_indices | Optional[torch.Tensor] | No | Indices to select specific positions for LM head |
| skip_lm_head | bool | No | If True, return hidden states without applying the LM head |
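To make the batch-layout inputs concrete, the sketch below derives cu_seqlen_prefill, position_ids, lm_head_indices, and max_s for a prefill batch from a list of prompt lengths, and shows how slot indices relate to a block table. It assumes the usual flattened, variable-length flash attention convention (all sequences concatenated, cumulative offsets starting at 0) and the usual paged-attention convention (slot = block id * block size + offset). The function names and the plain-list representation are hypothetical; the server passes torch tensors.

```python
from itertools import accumulate


def build_prefill_metadata(prompt_lengths):
    """Derive flattened-batch metadata from per-sequence prompt lengths."""
    # Cumulative offsets: sequence i occupies [cu[i], cu[i + 1]) in the batch.
    cu_seqlen_prefill = [0] + list(accumulate(prompt_lengths))
    # Positions restart at 0 for every sequence.
    position_ids = [pos for n in prompt_lengths for pos in range(n)]
    # Only the last token of each prompt needs logits during prefill.
    lm_head_indices = [end - 1 for end in cu_seqlen_prefill[1:]]
    return {
        "cu_seqlen_prefill": cu_seqlen_prefill,
        "position_ids": position_ids,
        "lm_head_indices": lm_head_indices,
        "max_s": max(prompt_lengths),
    }


def slots_for_sequence(block_table, seq_len, block_size):
    """Map token positions 0..seq_len-1 to KV-cache slots via a block table."""
    return [
        block_table[pos // block_size] * block_size + pos % block_size
        for pos in range(seq_len)
    ]
```

For prompt lengths [3, 1, 2], the flattened batch holds six tokens and the LM head would be applied at flattened indices 2, 3, and 5.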
Outputs
| Name | Type | Description |
|---|---|---|
| logits | torch.Tensor | Next-token logits over the vocabulary (or hidden states if skip_lm_head is True) |
| speculative_logits | Optional[torch.Tensor] | Speculative decoding logits, or None |
Usage Examples
# Internal usage within the LoRAX server
from lorax_server.models.custom_modeling.flash_phi3_modeling import FlashPhi3ForCausalLM
# The model is instantiated by the model registry, not directly by users