Implementation:Predibase Lorax Fused LayerNorm

From Leeroopedia


Knowledge Sources
Domains Model_Architecture, Inference
Last Updated 2026-02-08 00:00 GMT

Overview

Provides fused LayerNorm and RMSNorm implementations that combine normalization with residual addition in a single kernel call for efficient transformer inference across CUDA, ROCm, and XPU platforms.

Description

This module defines optimized normalization layers used throughout transformer model implementations. It also monkey-patches torch.nn.LayerNorm with load and load_no_bias classmethods for convenient weight loading from model checkpoints.
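The monkey-patching described above can be sketched roughly as follows. This is an illustration, not the LoRAX source: the DictWeights class and its get_tensor method are hypothetical stand-ins for the checkpoint reader, and the loader bodies are an assumption about what such classmethods would plausibly do.

```python
import torch
from torch import nn


class DictWeights:
    """Hypothetical stand-in for a checkpoint reader (not the LoRAX class)."""

    def __init__(self, tensors):
        self.tensors = tensors

    def get_tensor(self, name):
        return self.tensors[name]


def load_layer_norm(cls, prefix, weights, eps):
    # Read both parameters stored under the given checkpoint prefix.
    weight = weights.get_tensor(f"{prefix}.weight")
    bias = weights.get_tensor(f"{prefix}.bias")
    ln = cls(weight.shape, eps=eps)
    ln.weight = nn.Parameter(weight)
    ln.bias = nn.Parameter(bias)
    return ln


def load_layer_norm_no_bias(cls, prefix, weights, eps):
    # Same as above, for checkpoints that store no bias tensor.
    weight = weights.get_tensor(f"{prefix}.weight")
    ln = cls(weight.shape, eps=eps)
    ln.weight = nn.Parameter(weight)
    ln.bias = None
    return ln


# Attach the loaders to the existing class; subclasses inherit them.
nn.LayerNorm.load = classmethod(load_layer_norm)
nn.LayerNorm.load_no_bias = classmethod(load_layer_norm_no_bias)
```

Because the loaders are attached as classmethods, a subclass of nn.LayerNorm can call them and receive an instance of the subclass rather than of the base class.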

FastLayerNorm: A subclass of nn.LayerNorm whose forward pass adds the residual and applies layer normalization, using a fused dropout/residual/LayerNorm kernel where the platform provides one:

  • On CUDA: Uses the dropout_layer_norm extension's dropout_add_ln_fwd kernel when the hidden dimension is at most 8192. Falls back to standard PyTorch LayerNorm for larger dimensions.
  • On ROCm: Performs residual addition in Python and delegates to the standard PyTorch super().forward().
  • On XPU: Uses Intel Extension for PyTorch (ipex.llm.functional.add_layer_norm).
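The unfused behaviour described for the ROCm path (residual addition in Python, then the standard LayerNorm forward) might look like the minimal sketch below. The class name AddLayerNorm is hypothetical and the code is plain PyTorch; it shows the semantics the fused kernels are meant to reproduce, not the kernels themselves.

```python
import torch
from torch import nn


class AddLayerNorm(nn.LayerNorm):
    """Illustrative unfused variant: add the residual, then normalize."""

    def forward(self, hidden_states, residual=None):
        if residual is not None:
            hidden_states = hidden_states + residual
        # The pre-normalization sum becomes the residual for the next layer.
        residual = hidden_states
        return super().forward(hidden_states), residual


layer = AddLayerNorm(8)
h = torch.randn(2, 8)
r = torch.randn(2, 8)
normed, new_residual = layer(h, r)
```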

FastRMSNorm: A custom nn.Module implementing Root Mean Square Layer Normalization with fused residual connections:

  • On CUDA: Uses dropout_layer_norm.dropout_add_ln_fwd with the RMSNorm flag set to True when the hidden dimension is at most 8192. Falls back to a manual computation using torch.rsqrt for larger dimensions.
  • On ROCm: Uses vLLM's ops.rms_norm kernel.
  • On XPU: Uses ipex.llm.functional.add_rms_norm.
  • The load classmethod loads the weight tensor from a checkpoint prefix.
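A manual RMSNorm computation with torch.rsqrt, of the kind the CUDA fallback above refers to, can be sketched as follows. The function name and the float32 upcast are assumptions made for illustration; the actual fallback in layernorm.py may differ in detail.

```python
import torch


def rms_norm_with_residual(hidden_states, weight, eps, residual=None):
    """Unfused RMSNorm with optional residual add (illustrative helper)."""
    if residual is not None:
        hidden_states = hidden_states + residual
    # The pre-normalization sum is what the next layer receives as residual.
    residual = hidden_states

    # Compute in float32 for stability with half-precision inputs (assumption).
    x = hidden_states.to(torch.float32)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)

    # Cast back to the weight dtype before applying the learned scale.
    return weight * x.to(weight.dtype), residual
```

Unlike LayerNorm, this subtracts no mean and adds no bias: each vector is divided by its root mean square and scaled by weight.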

Both classes accept an optional residual tensor in their forward methods. When it is provided, the residual is added to the hidden states before normalization. Both return a tuple (normed_hidden_states, residual), where residual holds the pre-normalization hidden states (including the added residual, if any), so the residual stream can be passed on to the next layer.
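To show how the returned tuple threads the residual stream through a pre-norm transformer block, here is a small sketch. The add_norm function is a hypothetical unfused stand-in for either layer's forward, and mixer and mlp are placeholders for the attention and feed-forward sublayers.

```python
import torch
import torch.nn.functional as F


def add_norm(hidden_states, residual=None):
    """Hypothetical unfused stand-in for a fused norm layer's forward."""
    if residual is not None:
        hidden_states = hidden_states + residual
    residual = hidden_states
    normed = F.layer_norm(hidden_states, hidden_states.shape[-1:])
    return normed, residual


def block(hidden_states, residual, mixer, mlp):
    """One pre-norm block: each norm call also performs the residual add."""
    normed, residual = add_norm(hidden_states, residual)
    mixer_out = mixer(normed)
    normed, residual = add_norm(mixer_out, residual)
    return mlp(normed), residual


x = torch.randn(2, 8)
identity = lambda t: t  # placeholder sublayers for the demo
out, res = block(x, None, mixer=identity, mlp=identity)
```

The block never adds the residual explicitly; the sublayer output and the running residual are handed to the next norm call, which performs the addition.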

Usage

These layers are used by all transformer model implementations in LoRAX as drop-in replacements for standard LayerNorm and RMSNorm. FastLayerNorm is used by models like GPT-2 and BLOOM, while FastRMSNorm is used by LLaMA, Mistral, and similar architectures.

Code Reference

Source Location

  • Repository: Predibase_Lorax
  • File: server/lorax_server/layers/layernorm.py
  • Lines: 1-182

Signature

class FastLayerNorm(nn.LayerNorm):
    def forward(self, hidden_states, residual=None):

class FastRMSNorm(nn.Module):
    def __init__(self, weight: torch.Tensor, eps: float):

Import

from lorax_server.layers.layernorm import FastLayerNorm, FastRMSNorm

I/O Contract

Inputs (FastLayerNorm.forward)

Name            Type                   Required   Description
hidden_states   torch.Tensor           Yes        Input hidden states to normalize
residual        torch.Tensor or None   No         Optional residual tensor to add before normalization

Inputs (FastRMSNorm.__init__)

Name     Type           Required   Description
weight   torch.Tensor   Yes        Learnable scale parameter for RMSNorm
eps      float          Yes        Epsilon for numerical stability (variance_epsilon)

Outputs

Name                   Type           Description
normed_hidden_states   torch.Tensor   Normalized hidden states
residual               torch.Tensor   Updated residual tensor for propagation to next layer

Usage Examples

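# Used internally by transformer model layers
from lorax_server.layers.layernorm import FastRMSNorm

# `weights` and `hidden_states` are assumed to come from the surrounding model code.
rms_norm = FastRMSNorm.load(prefix="model.layers.0.input_layernorm", weights=weights, eps=1e-6)
# `residual` may be None on the first call; the returned residual feeds the next norm call.
normed_hidden_states, residual = rms_norm(hidden_states, residual)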
