Implementation:Turboderp org Exllamav2 ExLlamaV2RMSNorm

From Leeroopedia
Knowledge Sources
Domains: Normalization, Model_Architecture
Last Updated: 2026-02-15 00:00 GMT

Overview

ExLlamaV2RMSNorm implements Root Mean Square Layer Normalization (RMSNorm) with support for tensor parallelism and Gemma-style constant bias. It provides both a native C-extension execution path and a pure PyTorch one.

Description

RMS normalization divides each hidden-state vector by the root mean square of its elements, computed over the last dimension. It then scales the result elementwise by a learned weight vector. The formula is:

variance = x.pow(2).mean(-1, keepdim=True)

output = x * rsqrt(variance + epsilon) * weight
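To make the formula concrete, here is a minimal dependency-free sketch in plain Python, with lists standing in for tensors. It illustrates the math only. It is not the ext_c.rms_norm kernel, and the function names are hypothetical.

```python
import math


def rms_norm(row, weight, eps=1e-6):
    """RMS-normalize one hidden-state vector and scale it by `weight`."""
    # variance = x.pow(2).mean(-1, keepdim=True)
    variance = sum(v * v for v in row) / len(row)
    # rsqrt(variance + epsilon)
    inv_rms = 1.0 / math.sqrt(variance + eps)
    # output = x * rsqrt(variance + epsilon) * weight
    return [v * inv_rms * w for v, w in zip(row, weight)]


def rms_norm_rows(rows, weight, eps=1e-6):
    """Normalize each row independently, as mean(-1) does for a 2-D tensor."""
    return [rms_norm(r, weight, eps) for r in rows]


out = rms_norm([3.0, 4.0], [1.0, 1.0], eps=0.0)
print(out)  # roughly [0.8485, 1.1314]
```

With eps=0 and unit weights, the output has a root mean square of 1, up to floating-point rounding. A nonzero eps keeps the division safe for all-zero rows.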

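Unlike standard LayerNorm, RMSNorm neither subtracts the mean nor adds a bias term in the normalization. It therefore needs one reduction per row (the mean of squares) instead of two (the mean, then the variance around it). This makes it cheaper to compute while retaining similar training stability.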
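A quick numerical check of that relationship follows, as a plain-Python sketch with hypothetical helper names. The learned scale (and LayerNorm's learned shift) is left out. The two normalizations agree when the input already has zero mean and diverge otherwise.

```python
import math


def rms_norm(row, eps=1e-6):
    # One reduction: the mean of squares.
    ms = sum(v * v for v in row) / len(row)
    return [v / math.sqrt(ms + eps) for v in row]


def layer_norm(row, eps=1e-6):
    # Two reductions: the mean, then the (biased) variance around it.
    mean = sum(row) / len(row)
    var = sum((v - mean) ** 2 for v in row) / len(row)
    return [(v - mean) / math.sqrt(var + eps) for v in row]


zero_mean = [1.0, -1.0, 2.0, -2.0]
shifted = [1.0, 2.0, 3.0]

same = all(math.isclose(a, b) for a, b in zip(rms_norm(zero_mean), layer_norm(zero_mean)))
diff = all(math.isclose(a, b) for a, b in zip(rms_norm(shifted), layer_norm(shifted)))
print(same, diff)  # True False
```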

The module provides three forward paths:

  • forward() - Default path using the C extension (ext_c.rms_norm) for maximum performance. Output is float16 unless output_fp32=True, which is useful in mixed-precision code. When tensor parallelism is active, forward() delegates to forward_tp().
  • forward_tp() - Tensor-parallel path that broadcasts hidden states across devices and applies the normalization using ext_c.rms_norm_tp. Returns a list of tensors, one per device.
  • forward_torch() - Pure PyTorch fallback implementation for debugging or environments without the C extension.
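The relationship between a fused default path and a step-by-step fallback can be sketched as two functions that compute the same quantity. This is a plain-Python illustration under stated assumptions. `rms_norm_fused` is a hypothetical stand-in for forward() that rounds its result to float16 unless output_fp32 is set. Because Python floats are 64-bit, "fp32" in the sketch simply means unrounded. `rms_norm_reference` is a hypothetical stand-in for forward_torch().

```python
import math
import struct


def to_half(v):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", v))[0]


def rms_norm_fused(rows, weight, eps=1e-6, output_fp32=False):
    """Stand-in for forward(): one pass per row, result stored as float16
    unless output_fp32 is set."""
    out = []
    for r in rows:
        inv = 1.0 / math.sqrt(sum(v * v for v in r) / len(r) + eps)
        vals = [v * inv * w for v, w in zip(r, weight)]
        out.append(vals if output_fp32 else [to_half(v) for v in vals])
    return out


def rms_norm_reference(rows, weight, eps=1e-6):
    """Stand-in for forward_torch(): the formula spelled out step by step."""
    out = []
    for r in rows:
        variance = sum(v * v for v in r) / len(r)
        normed = [v / math.sqrt(variance + eps) for v in r]
        out.append([n * w for n, w in zip(normed, weight)])
    return out


def max_abs_diff(a, b):
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))


rows = [[0.5, -1.25, 2.0, 0.75], [3.0, 1.0, -2.0, 0.1]]
weight = [1.0, 0.9, 1.1, 1.0]
ref = rms_norm_reference(rows, weight)
print(max_abs_diff(rms_norm_fused(rows, weight), ref))                    # small, from fp16 rounding
print(max_abs_diff(rms_norm_fused(rows, weight, output_fp32=True), ref))  # essentially zero
```

When debugging, the usual way to use a reference path is to compare both paths on the same input, with a tolerance that matches the output precision.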

The module also handles Gemma-style models. Their RMSNorm scales by (1 + weight), so the checkpoint stores the weight offset by a constant 1.0. At load time the module adds archparams.norm_constant_bias to the weight, which lets the kernels multiply by it directly. get_weight() subtracts the constant again to return the tensor as stored.
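The load-time offset and its reversal can be sketched as follows. Everything here is hypothetical plain Python: `GemmaStyleNorm` and its attributes are illustration names, and the constant 1.0 stands in for archparams.norm_constant_bias. The point is that folding the constant into the weight once lets the ordinary x * rsqrt(variance + epsilon) * weight formula reproduce Gemma's (1 + weight) scaling.

```python
import math

NORM_CONSTANT_BIAS = 1.0  # assumed stand-in for archparams.norm_constant_bias


class GemmaStyleNorm:
    """Hypothetical sketch of the load-time offset and its reversal."""

    def __init__(self, stored_weight, constant_bias=NORM_CONSTANT_BIAS, eps=1e-6):
        self.constant_bias = constant_bias
        self.eps = eps
        # "load": fold the constant into the weight once, so forward() can use
        # the plain x * rsqrt(variance + eps) * weight formula.
        self.weight = [w + constant_bias for w in stored_weight]

    def get_weight(self):
        # Undo the load-time offset to recover the tensor as stored.
        return [w - self.constant_bias for w in self.weight]

    def forward(self, row):
        inv = 1.0 / math.sqrt(sum(v * v for v in row) / len(row) + self.eps)
        return [v * inv * w for v, w in zip(row, self.weight)]


stored = [0.0, 0.5, -0.25]
norm = GemmaStyleNorm(stored)
x = [1.0, 2.0, 3.0]

# Gemma's own formulation: normalize, then scale by (1 + stored weight).
inv = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + 1e-6)
gemma_ref = [v * inv * (1.0 + w) for v, w in zip(x, stored)]

print(norm.get_weight())  # [0.0, 0.5, -0.25]
```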

Tensor parallelism is enabled via tp_split(). It replicates the weight (and the bias, if present) onto every tensor-parallel device and records broadcast_type, which controls how the input hidden states are distributed to those devices.
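A rough model of what replication means for the outputs follows. Each device holds a full copy of the weight and receives the full input, so each per-device result equals the single-device result. This is a hypothetical plain-Python sketch. It takes a device count instead of the real broadcast_type argument, and it does not model how broadcast_type chooses the distribution. One plausible reason to replicate rather than shard is that the normalization reduces over the entire hidden dimension, so every device needs the whole vector anyway.

```python
import math


def rms_norm_row(row, weight, eps=1e-6):
    inv = 1.0 / math.sqrt(sum(v * v for v in row) / len(row) + eps)
    return [v * inv * w for v, w in zip(row, weight)]


class ReplicatedNorm:
    """Hypothetical sketch of a tensor-parallel norm with replicated weights."""

    def __init__(self, weight):
        self.weight = list(weight)
        self.is_tp = False

    def tp_split(self, num_devices):
        # Replicate, don't shard: every device keeps the whole weight vector.
        self.weight = [list(self.weight) for _ in range(num_devices)]
        self.is_tp = True

    def forward_tp(self, rows):
        # "Broadcast": each device gets its own copy of the same hidden states.
        per_device = [[list(r) for r in rows] for _ in self.weight]
        # One output per device, mirroring forward_tp()'s list return value.
        return [[rms_norm_row(r, w) for r in inp]
                for inp, w in zip(per_device, self.weight)]


norm = ReplicatedNorm([1.0, 0.5, 2.0])
norm.tp_split(num_devices=2)
outputs = norm.forward_tp([[1.0, 2.0, 3.0]])
print(len(outputs))  # 2
```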

Usage

Use ExLlamaV2RMSNorm in any model layer that requires RMS normalization. It is automatically instantiated as the input normalization for decoder blocks (including parallel decoders), attention layers, and the final output normalization layer. It is selected over LayerNorm based on the model's archparams.norm setting.
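Selecting the norm class from a config field amounts to a lookup. In the sketch below, the stub classes, the `make_norm` helper, and the string values "rmsnorm" and "layernorm" are all assumptions for illustration. The real archparams.norm values and constructor wiring in ExLlamaV2 may differ.

```python
from dataclasses import dataclass


class RMSNormStub:
    """Placeholder for an RMS norm module (hypothetical)."""
    def __init__(self, key):
        self.key = key


class LayerNormStub:
    """Placeholder for a LayerNorm module (hypothetical)."""
    def __init__(self, key):
        self.key = key


# Hypothetical mapping from an archparams.norm value to a norm class.
NORM_CLASSES = {"rmsnorm": RMSNormStub, "layernorm": LayerNormStub}


@dataclass
class ArchParams:
    norm: str  # assumed to hold a string such as "rmsnorm" or "layernorm"


def make_norm(archparams, key):
    """Build the norm module that the architecture asks for."""
    try:
        cls = NORM_CLASSES[archparams.norm]
    except KeyError:
        raise ValueError(f"unsupported norm type: {archparams.norm!r}") from None
    return cls(key)


norm = make_norm(ArchParams(norm="rmsnorm"), "model.layers.0.input_layernorm")
print(type(norm).__name__)  # RMSNormStub
```

Raising on unknown values, instead of silently defaulting to one norm, makes a misconfigured architecture fail at construction time.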

Code Reference

Source Location

Signature

class ExLlamaV2RMSNorm(ExLlamaV2Module):

    name: str = "RMSNorm"

    weight: nn.Parameter | None | list[nn.Parameter | None]
    bias: nn.Parameter | None | list[nn.Parameter | None]
    variance_epsilon: float
    is_tp: bool
    broadcast_type: int | None

    def __init__(
        self,
        model,
        key,
        archparams=None,
    ): ...

    def load(self, device_context=True): ...
    def unload(self): ...
    def get_weight(self) -> torch.Tensor: ...
    def weight_footprint(self) -> int: ...
    def numel(self): ...

    def forward(
        self,
        hidden_states: torch.Tensor,
        cache=None,
        attn_params=None,
        past_len=None,
        intermediates: bool = False,
        loras=None,
        output_fp32=False,
        **kwargs
    ) -> torch.Tensor | dict[str, torch.Tensor]: ...

    def forward_tp(
        self,
        hidden_states: torch.Tensor,
        cache=None,
        attn_params=None,
        past_len=None,
        intermediates: bool = False,
        loras=None,
        output_fp32=False,
        **kwargs
    ) -> torch.Tensor | dict[str, torch.Tensor]: ...

    def forward_torch(
        self,
        hidden_states: torch.Tensor,
        cache=None,
        attn_params=None,
        past_len=None,
        intermediates: bool = False,
        loras=None,
        output_fp32=False,
        **kwargs
    ) -> torch.Tensor | dict[str, torch.Tensor]: ...

    def tp_split(self, broadcast_type: int): ...

Import

from exllamav2.rmsnorm import ExLlamaV2RMSNorm

I/O Contract

forward()

| Parameter | Type | Description |
| --- | --- | --- |
| hidden_states | torch.Tensor | Input tensor of shape (batch, seq_len, hidden_size) |
| cache | None | Unused; present for API compatibility |
| attn_params | None | Unused; present for API compatibility |
| past_len | None | Unused; present for API compatibility |
| intermediates | bool | If True, return {"hidden_states": tensor} instead of the raw tensor |
| loras | None | Unused; present for API compatibility |
| output_fp32 | bool | If True, the output tensor is float32; otherwise float16 |

| Return | Type | Description |
| --- | --- | --- |
| hidden_states | torch.Tensor | Normalized tensor of the same shape as the input (when intermediates=False) |
| result | dict | {"hidden_states": tensor} (when intermediates=True) |
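This return contract can be mirrored in a few lines. The hypothetical plain-Python sketch below takes input shaped (batch, seq_len, hidden_size) as nested lists and normalizes each hidden vector. It returns either the bare result or {"hidden_states": result}. Flattening to rows and restoring the shape is an assumption about how a row-wise kernel would be fed, not a description of the extension's internals.

```python
import math


def forward(hidden_states, weight, eps=1e-6, intermediates=False):
    """Hypothetical mirror of forward()'s return contract."""
    batch, seq_len = len(hidden_states), len(hidden_states[0])
    # Flatten (batch, seq_len, hidden) into rows of length hidden.
    rows = [vec for seq in hidden_states for vec in seq]
    normed = []
    for r in rows:
        inv = 1.0 / math.sqrt(sum(v * v for v in r) / len(r) + eps)
        normed.append([v * inv * w for v, w in zip(r, weight)])
    # Restore the original (batch, seq_len, hidden) nesting.
    out = [normed[b * seq_len:(b + 1) * seq_len] for b in range(batch)]
    return {"hidden_states": out} if intermediates else out


x = [[[1.0, 2.0], [3.0, 4.0], [0.5, 0.5]]]  # batch=1, seq_len=3, hidden=2
plain = forward(x, [1.0, 1.0])
wrapped = forward(x, [1.0, 1.0], intermediates=True)
print(list(wrapped.keys()))  # ['hidden_states']
```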

forward_tp()

| Return | Type | Description |
| --- | --- | --- |
| outputs | list[torch.Tensor] | List of normalized tensors, one per tensor-parallel device |

tp_split()

| Parameter | Type | Description |
| --- | --- | --- |
| broadcast_type | int | Controls how input hidden states are broadcast across TP devices |

Usage Examples

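from exllamav2.rmsnorm import ExLlamaV2RMSNorm

# Assumes `model` is an already-configured ExLlamaV2 model and `hidden_states`
# is a float16 tensor of shape (batch, seq_len, hidden_size) on the same device
# as the norm weight.

# Construction (typically handled by the model/decoder layer)
norm = ExLlamaV2RMSNorm(model=model, key="model.layers.0.input_layernorm")
norm.load()

# Standard forward pass (float16 output)
normalized = norm.forward(hidden_states)

# Forward with FP32 output for numerical precision
normalized_fp32 = norm.forward(hidden_states, output_fp32=True)

# Pure PyTorch fallback (useful for debugging)
normalized_torch = norm.forward_torch(hidden_states)

# Retrieve the original stored weight (reverting the Gemma bias, if any).
# Do this before tp_split(), after which `weight` holds a list of per-device copies.
raw_weight = norm.get_weight()

# Tensor-parallel setup (assumes the model was loaded with tensor parallelism;
# broadcast_type=0 is an illustrative value)
norm.tp_split(broadcast_type=0)
tp_outputs = norm.forward(hidden_states)  # delegates to forward_tp(); returns a list of tensors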

