Implementation:Predibase Lorax Fused LayerNorm
| Knowledge Sources | |
|---|---|
| Domains | Model_Architecture, Inference |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Provides fused LayerNorm and RMSNorm implementations that combine residual addition with normalization, in a single kernel call where the platform provides one, for efficient transformer inference on CUDA, ROCm, and XPU.
Description
This module defines optimized normalization layers used throughout transformer model implementations. It also monkey-patches torch.nn.LayerNorm with load and load_no_bias classmethods for convenient weight loading from model checkpoints.
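The monkey-patching described above can be sketched as follows. This is a minimal illustration, not the LoRAX source. It assumes a weights object exposing get_tensor(name) and tensors stored under "{prefix}.weight" and "{prefix}.bias". The DictWeights class and the two function names are hypothetical.

```python
import torch
from torch import nn


class DictWeights:
    """Hypothetical stand-in for a checkpoint loader: maps tensor names to tensors."""

    def __init__(self, tensors):
        self.tensors = tensors

    def get_tensor(self, name):
        return self.tensors[name]


@classmethod
def load_layer_norm(cls, prefix, weights, eps):
    # Assumed naming convention: "<prefix>.weight" and "<prefix>.bias".
    weight = weights.get_tensor(f"{prefix}.weight")
    bias = weights.get_tensor(f"{prefix}.bias")
    ln = cls(weight.shape, eps=eps)  # allocates default params, replaced below
    ln.weight = nn.Parameter(weight)
    ln.bias = nn.Parameter(bias)
    return ln


@classmethod
def load_layer_norm_no_bias(cls, prefix, weights, eps):
    weight = weights.get_tensor(f"{prefix}.weight")
    ln = cls(weight.shape, eps=eps)
    ln.weight = nn.Parameter(weight)
    ln.bias = None  # LayerNorm without an additive bias term
    return ln


# Monkey-patch: nn.LayerNorm and every subclass now expose .load / .load_no_bias.
nn.LayerNorm.load = load_layer_norm
nn.LayerNorm.load_no_bias = load_layer_norm_no_bias
```

Because the functions are attached as classmethods, a subclass such as FastLayerNorm inherits them, and cls resolves to the subclass. For simplicity, this sketch allocates default parameters and then replaces them.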
FastLayerNorm: A platform-specific subclass of nn.LayerNorm that adds the residual and applies layer normalization, using a fused kernel where the platform provides one:
- On CUDA: Uses the dropout_add_ln_fwd kernel from the dropout_layer_norm extension, which fuses dropout, residual addition, and layer normalization, when the hidden size (last dimension) is at most 8192. For larger hidden sizes it falls back to the standard PyTorch LayerNorm.
- On ROCm: Performs residual addition in Python and delegates to the standard PyTorch super().forward().
- On XPU: Uses Intel Extension for PyTorch (ipex.llm.functional.add_layer_norm).
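The contract shared by the three FastLayerNorm branches can be illustrated with a pure-PyTorch reference that mirrors the ROCm path: add the residual in Python, then call nn.LayerNorm's forward. ReferenceFastLayerNorm is a hypothetical name. Returning the raw input as the residual when none is passed is an assumption about the contract, not something stated above. The sketch uses out-of-place addition, so the caller's tensors are left unmodified.

```python
import torch
from torch import nn


class ReferenceFastLayerNorm(nn.LayerNorm):
    """Hypothetical pure-PyTorch reference for the FastLayerNorm contract."""

    def forward(self, hidden_states, residual=None):
        if residual is not None:
            # Out-of-place add: the caller's tensors are not modified.
            hidden_states = hidden_states + residual
        # Assumption: the pre-norm sum (or the raw input) becomes the new residual.
        residual = hidden_states
        return super().forward(hidden_states), residual
```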
FastRMSNorm: A custom nn.Module implementing Root Mean Square Layer Normalization (RMSNorm) with fused residual addition:
- On CUDA: Uses dropout_layer_norm.dropout_add_ln_fwd with its RMSNorm flag set to True when the hidden size (last dimension) is at most 8192. For larger hidden sizes, it falls back to a manual computation that uses torch.rsqrt.
- On ROCm: Uses vLLM's ops.rms_norm kernel.
- On XPU: Uses ipex.llm.functional.add_rms_norm.
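The manual fallback computation can be sketched in PyTorch as y = weight * x * rsqrt(mean(x^2) + eps). Here, x is the input after residual addition, and the mean runs over the last dimension. ReferenceRMSNorm is a hypothetical name. The float32 upcast before normalization and the residual-return behavior are assumptions of this sketch.

```python
import torch
from torch import nn


class ReferenceRMSNorm(nn.Module):
    """Hypothetical reference for FastRMSNorm's manual (large hidden size) path."""

    def __init__(self, weight: torch.Tensor, eps: float):
        super().__init__()
        self.weight = nn.Parameter(weight)
        self.variance_epsilon = eps

    def forward(self, hidden_states, residual=None):
        if residual is not None:
            hidden_states = hidden_states + residual
        residual = hidden_states

        # RMSNorm: x * rsqrt(mean(x^2) + eps), computed in float32 (assumption).
        x = hidden_states.to(torch.float32)
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        # Cast back to the weight's dtype before applying the learnable scale.
        return self.weight * x.to(self.weight.dtype), residual
```

Unlike LayerNorm, RMSNorm does not subtract the mean and has no bias. This is why FastRMSNorm's constructor takes only a weight and an epsilon.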
- Weight loading: the load classmethod reads the weight tensor stored under the given checkpoint prefix and constructs the module from it.
Both classes accept an optional residual tensor in their forward methods. When provided, the residual is added to the hidden states before normalization. Both return a tuple of (normed_hidden_states, residual) to enable residual stream propagation through the model.
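To see why both layers return the residual, consider a hypothetical pre-norm stack. In it, each sublayer (attention or an MLP in a real model) reads a normalized input, and the un-normalized running sum travels alongside it as the residual. add_rms_norm and decoder_stack are illustrative names, not LoRAX functions, and the sublayers here are arbitrary callables. The fused form gives the same result as the textbook unfused loop h = h + f(norm(h)) followed by a final norm. The difference is that each residual addition is folded into the next normalization call.

```python
import torch


def add_rms_norm(hidden_states, residual, weight, eps=1e-6):
    """Hypothetical functional stand-in for FastRMSNorm.forward."""
    if residual is not None:
        hidden_states = hidden_states + residual
    residual = hidden_states
    normed = hidden_states * torch.rsqrt(hidden_states.pow(2).mean(-1, keepdim=True) + eps)
    return weight * normed, residual


def decoder_stack(hidden_states, sublayers, weight):
    """Pre-norm residual stream: each sublayer sees a normalized input, and the
    un-normalized running sum is carried forward as `residual`."""
    residual = None
    for sublayer in sublayers:
        normed, residual = add_rms_norm(hidden_states, residual, weight)
        # Sublayer output, not yet added to the residual stream.
        hidden_states = sublayer(normed)
    # A final norm folds in the last sublayer's output.
    normed, _ = add_rms_norm(hidden_states, residual, weight)
    return normed
```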
Usage
These layers are used throughout LoRAX's transformer model implementations in place of the standard LayerNorm and RMSNorm. Unlike nn.LayerNorm, their forward methods return a (normed_hidden_states, residual) tuple rather than a single tensor. FastLayerNorm is used by models such as GPT-2 and BLOOM, while FastRMSNorm is used by LLaMA, Mistral, and similar architectures.
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/layers/layernorm.py
- Lines: 1-182
Signature
```python
class FastLayerNorm(nn.LayerNorm):
    def forward(self, hidden_states, residual=None): ...

class FastRMSNorm(nn.Module):
    def __init__(self, weight: torch.Tensor, eps: float): ...
```
Import
```python
from lorax_server.layers.layernorm import FastLayerNorm, FastRMSNorm
```
I/O Contract
Inputs (FastLayerNorm.forward and FastRMSNorm.forward)
| Name | Type | Required | Description |
|---|---|---|---|
| hidden_states | torch.Tensor | Yes | Input hidden states to normalize |
| residual | torch.Tensor or None | No | Optional residual tensor to add before normalization |
Inputs (FastRMSNorm.__init__)
| Name | Type | Required | Description |
|---|---|---|---|
| weight | torch.Tensor | Yes | Learnable scale parameter for RMSNorm |
| eps | float | Yes | Epsilon added to the variance for numerical stability (stored as variance_epsilon) |
Outputs (forward, both classes)
| Name | Type | Description |
|---|---|---|
| normed_hidden_states | torch.Tensor | Normalized hidden states |
| residual | torch.Tensor | Updated residual tensor for propagation to next layer |
Usage Examples
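```python
# Used internally by transformer model layers.
from lorax_server.layers.layernorm import FastRMSNorm

# `weights` is the checkpoint weights object the model is loaded with; `hidden_states`
# is the layer input and `residual` is the running residual (may be None).
rms_norm = FastRMSNorm.load(prefix="model.layers.0.input_layernorm", weights=weights, eps=1e-6)
normed_hidden_states, residual = rms_norm(hidden_states, residual)
```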