Implementation:Turboderp org Exllamav2 ExLlamaV2RMSNorm
| Knowledge Sources | |
|---|---|
| Domains | Normalization, Model_Architecture |
| Last Updated | 2026-02-15 00:00 GMT |
Overview
ExLlamaV2RMSNorm implements Root Mean Square Layer Normalization with support for tensor parallelism and a Gemma-style constant bias. It offers both a native C extension path and a pure PyTorch path.
Description
RMS normalization divides the input tensor by the root mean square of its elements along the last dimension, then scales the result by a learned weight vector. Expressed as PyTorch operations:
variance = x.pow(2).mean(-1, keepdim=True)
output = x * rsqrt(variance + epsilon) * weight
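The same computation in conventional notation, as a sketch that assumes n is the size of the last (hidden) dimension, w is the weight vector, and ε is variance_epsilon (the symbols n, w, and ε are introduced here for illustration):

```latex
y_i = \frac{x_i}{\sqrt{\frac{1}{n}\sum_{j=1}^{n} x_j^{2} + \epsilon}} \; w_i
```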
Unlike standard LayerNorm, RMSNorm does not subtract the mean and does not use a bias term. This makes it computationally cheaper while retaining similar training stability properties.
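A dependency-free sketch of the formula may help make the arithmetic concrete. It works on a single vector held in a plain Python list instead of a tensor. The function name rms_norm and the default epsilon of 1e-6 are assumptions for illustration, not values taken from the module.

```python
import math

def rms_norm(x, weight, epsilon=1e-6):
    """Reference RMSNorm for one vector (hypothetical helper, not the library API)."""
    # Mean of the squared elements; note that no mean is subtracted first,
    # which is the difference from standard LayerNorm.
    mean_sq = sum(v * v for v in x) / len(x)
    # Reciprocal square root, matching rsqrt(variance + epsilon) in the formula.
    inv_rms = 1.0 / math.sqrt(mean_sq + epsilon)
    # Elementwise scale by the learned weight; there is no additive bias term.
    return [v * inv_rms * w for v, w in zip(x, weight)]

x = [3.0, 4.0]
y = rms_norm(x, weight=[1.0, 1.0], epsilon=0.0)
# With a unit weight and epsilon = 0, the output has a mean square of 1.
mean_sq_out = sum(v * v for v in y) / len(y)
```

Because the mean is never subtracted, adding a constant offset to every input element changes the result, unlike with LayerNorm.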
The module provides three forward paths:
- forward() - Default path using the C extension (ext_c.rms_norm) for maximum performance. Supports optional FP32 output for mixed-precision scenarios.
- forward_tp() - Tensor-parallel path that broadcasts hidden states across devices and applies the normalization using ext_c.rms_norm_tp. Returns a list of tensors, one per device.
- forward_torch() - Pure PyTorch fallback implementation for debugging or environments without the C extension.
The module also handles Gemma-style models, in which the effective scale is the stored weight plus a constant offset of 1.0. The offset, taken from archparams.norm_constant_bias, is added to the weight at load time and subtracted again in get_weight() so that the method returns the tensor as originally stored.
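The load-time offset and its reversal can be sketched as a round trip. The class below is a hypothetical stand-in that uses plain lists; only the names load, get_weight, and norm_constant_bias come from the description above, and the real module's internals may differ.

```python
class NormWeightSketch:
    """Hypothetical illustration of the constant-bias handling, not the real class."""

    def __init__(self, stored_weight, norm_constant_bias=0.0):
        self._stored = list(stored_weight)       # values as found in the checkpoint
        self.norm_constant_bias = norm_constant_bias
        self.weight = None                       # populated by load()

    def load(self):
        # Fold the constant offset into the weight once, so the forward pass
        # can multiply by self.weight without a separate addition.
        self.weight = [w + self.norm_constant_bias for w in self._stored]

    def get_weight(self):
        # Undo the offset to recover the tensor as stored on disk.
        return [w - self.norm_constant_bias for w in self.weight]

stored = [0.0, 0.25, -0.5]
norm = NormWeightSketch(stored, norm_constant_bias=1.0)
norm.load()
effective = norm.weight        # scale actually applied during normalization
recovered = norm.get_weight()  # matches the stored values
```

Folding the offset in at load time keeps the per-token forward path free of an extra addition, at the cost of having to reverse it whenever the original weights are requested.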
Tensor parallelism is activated via tp_split(), which replicates the weight and optional bias tensors across devices and sets broadcast_type, which controls how input states are distributed.
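As a rough mental model, the tensor-parallel path can be simulated without any GPUs. In the sketch below, "devices" are just list slots, and the num_devices parameter, the TPNormSketch class, and the rms_norm helper are all assumptions for illustration. The real tp_split() takes only broadcast_type and obtains device information elsewhere, and the effect of specific broadcast_type values is not modeled here.

```python
import math

def rms_norm(x, weight, epsilon=1e-6):
    # Reference RMSNorm for one vector (hypothetical helper).
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + epsilon)
    return [v * inv_rms * w for v, w in zip(x, weight)]

class TPNormSketch:
    """Hypothetical simulation of weight replication across devices."""

    def __init__(self, weight):
        self.weight = list(weight)
        self.is_tp = False
        self.broadcast_type = None

    def tp_split(self, broadcast_type, num_devices):
        # Replicate the weight once per simulated device (num_devices is an
        # illustrative parameter, not part of the real signature).
        self.weight = [list(self.weight) for _ in range(num_devices)]
        self.broadcast_type = broadcast_type
        self.is_tp = True

    def forward_tp(self, hidden_states):
        # Every simulated device receives a copy of the input and normalizes
        # it with its own weight replica; the result is one output per device.
        return [rms_norm(list(hidden_states), w) for w in self.weight]

norm = TPNormSketch([1.0, 1.0])
norm.tp_split(broadcast_type=0, num_devices=3)
outputs = norm.forward_tp([3.0, 4.0])
```

Since the weights are replicated and not sharded, each device computes the same normalized result from its copy of the input.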
Usage
Use ExLlamaV2RMSNorm in any model layer that requires RMS normalization. It is automatically instantiated as the input normalization for decoder blocks (including parallel decoders), attention layers, and the final output normalization layer. It is selected over LayerNorm based on the model's archparams.norm setting.
Code Reference
Source Location
- Repository: Turboderp_org_Exllamav2
- File: exllamav2/rmsnorm.py
- Lines: 1-239
Signature
class ExLlamaV2RMSNorm(ExLlamaV2Module):
name: str = "RMSNorm"
weight: nn.Parameter | None | list[nn.Parameter | None]
bias: nn.Parameter | None | list[nn.Parameter | None]
variance_epsilon: float
is_tp: bool
broadcast_type: int | None
def __init__(
self,
model,
key,
archparams=None,
): ...
def load(self, device_context=True): ...
def unload(self): ...
def get_weight(self) -> torch.Tensor: ...
def weight_footprint(self) -> int: ...
def numel(self): ...
def forward(
self,
hidden_states: torch.Tensor,
cache=None,
attn_params=None,
past_len=None,
intermediates: bool = False,
loras=None,
output_fp32=False,
**kwargs
) -> torch.Tensor | dict[str: torch.Tensor]: ...
def forward_tp(
self,
hidden_states: torch.Tensor,
cache=None,
attn_params=None,
past_len=None,
intermediates: bool = False,
loras=None,
output_fp32=False,
**kwargs
) -> torch.Tensor | dict[str: torch.Tensor]: ...
def forward_torch(
self,
hidden_states: torch.Tensor,
cache=None,
attn_params=None,
past_len=None,
intermediates: bool = False,
loras=None,
output_fp32=False,
**kwargs
) -> torch.Tensor | dict[str: torch.Tensor]: ...
def tp_split(self, broadcast_type: int): ...
Import
from exllamav2.rmsnorm import ExLlamaV2RMSNorm
I/O Contract
forward()
| Parameter | Type | Description |
|---|---|---|
| hidden_states | torch.Tensor | Input tensor of shape (batch, seq_len, hidden_size) |
| cache | None | Unused; present for API compatibility |
| attn_params | None | Unused; present for API compatibility |
| past_len | None | Unused; present for API compatibility |
| intermediates | bool | If True, return {"hidden_states": tensor} instead of raw tensor |
| loras | None | Unused; present for API compatibility |
| output_fp32 | bool | If True, output tensor is in float32; otherwise float16 |

| Return | Type | Description |
|---|---|---|
| hidden_states | torch.Tensor | Normalized tensor of same shape as input (when intermediates=False) |
| result dict | dict | {"hidden_states": tensor} (when intermediates=True) |
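Because the return type depends on the intermediates flag, calling code that may receive either form can normalize it with a small helper. This is a sketch only: unwrap_hidden_states is a hypothetical name and not part of exllamav2.

```python
def unwrap_hidden_states(result):
    """Return the normalized data for either return form of forward().

    Hypothetical caller-side helper: accepts the raw value
    (intermediates=False) or the {"hidden_states": ...} dict
    (intermediates=True).
    """
    if isinstance(result, dict):
        return result["hidden_states"]
    return result

# Plain lists stand in for tensors so the example runs without PyTorch.
from_dict = unwrap_hidden_states({"hidden_states": [0.5, 1.5]})
from_raw = unwrap_hidden_states([0.5, 1.5])
```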
forward_tp()
| Return | Type | Description |
|---|---|---|
| outputs | list[torch.Tensor] | List of normalized tensors, one per tensor-parallel device |
tp_split()
| Parameter | Type | Description |
|---|---|---|
| broadcast_type | int | Controls how input hidden states are broadcast across TP devices |
Usage Examples
from exllamav2.rmsnorm import ExLlamaV2RMSNorm
# Construction (typically handled by model/decoder layer)
norm = ExLlamaV2RMSNorm(model=model, key="model.layers.0.input_layernorm")
norm.load()
# Standard forward pass
normalized = norm.forward(hidden_states)
# Forward with FP32 output for numerical precision
normalized_fp32 = norm.forward(hidden_states, output_fp32=True)
# Pure PyTorch fallback (useful for debugging)
normalized_torch = norm.forward_torch(hidden_states)
# Tensor-parallel setup
norm.tp_split(broadcast_type=0)
tp_outputs = norm.forward(hidden_states) # returns list of tensors
# Retrieve the original stored weight (reverting Gemma bias if any)
raw_weight = norm.get_weight()
Related Pages
Implements Principle
Requires Environment
Related
- Turboderp_org_Exllamav2_ExLlamaV2ParallelDecoder - Parallel decoder block that uses RMSNorm as its shared input normalization
- Turboderp_org_Exllamav2_RoPE - Rotary embeddings applied after normalization in the attention path