Implementation:LLMBook zh LLMBook zh github io LlamaRMSNorm
| Knowledge Sources | |
|---|---|
| Domains | Deep_Learning, Model_Architecture |
| Last Updated | 2026-02-08 04:29 GMT |
Overview
Concrete tool for RMS normalization of hidden states, implemented as a custom PyTorch nn.Module.
Description
The LlamaRMSNorm class implements Root Mean Square (RMS) normalization as a PyTorch module. It computes the RMS of the input hidden states over the last (hidden_size) dimension, divides the input by that value (with an epsilon term for numerical stability), and multiplies the result by a learnable scale parameter. The normalization is computed in float32 for numerical stability, and the result is then cast back to the input dtype. LLaMA uses this layer as its normalization throughout: before self-attention and before the MLP in each decoder layer, and once more after the final decoder layer.
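The computation described above can be written compactly as a formula. The symbols are assumptions introduced for illustration: x is one hidden-state vector along the last dimension, d is hidden_size, w is the learnable scale, and ε is eps. Adding ε inside the square root follows the common RMSNorm formulation and is not stated on this page.

```latex
\mathrm{RMS}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^{2} + \epsilon},
\qquad
y_i = w_i \cdot \frac{x_i}{\mathrm{RMS}(x)}, \quad i = 1, \dots, d
```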
Usage
Use this class when implementing or studying the LLaMA model architecture. It is instantiated with the model's hidden size and applied twice within each decoder layer (before self-attention and before the MLP) and once at the model output, after the final decoder layer.
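The placement within a decoder layer can be illustrated with a small wiring sketch. It rests on assumptions: `DecoderLayerSketch` and its argument names are hypothetical, attention is reduced to a single-argument callable (a real layer also takes masks and position information), and the residual additions follow the usual pre-norm layout rather than anything stated on this page.

```python
import torch
from torch import nn


class DecoderLayerSketch(nn.Module):
    """Pre-norm residual wiring of one decoder layer (hypothetical class)."""

    def __init__(self, input_norm, attention, post_attention_norm, mlp):
        super().__init__()
        self.input_norm = input_norm                    # e.g. LlamaRMSNorm(hidden_size)
        self.attention = attention                      # stand-in for self-attention
        self.post_attention_norm = post_attention_norm  # e.g. LlamaRMSNorm(hidden_size)
        self.mlp = mlp                                  # stand-in for the feed-forward block

    def forward(self, hidden_states):
        # Normalize, run self-attention, then add the residual
        hidden_states = hidden_states + self.attention(self.input_norm(hidden_states))
        # Normalize again, run the MLP, then add the residual
        hidden_states = hidden_states + self.mlp(self.post_attention_norm(hidden_states))
        return hidden_states
```

A third norm instance would then be applied once to the output of the last layer in the stack.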
Code Reference
Source Location
- Repository: LLMBook-zh
- File: code/5.1 RMSNorm.py
- Lines: 1-14
Signature
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Args:
            hidden_size: Dimension of the hidden states to normalize.
            eps: Small constant for numerical stability (default 1e-6).
        """

    def forward(self, hidden_states):
        """
        Args:
            hidden_states: Input tensor of shape (..., hidden_size).
        Returns:
            Normalized tensor of the same shape, scaled by learnable weight.
        """
Import
from torch import nn
# LlamaRMSNorm is defined locally in code/5.1 RMSNorm.py
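Since the class is defined locally rather than imported from a package, a definition has to be in scope before it can be used. The following is a minimal sketch of what such a definition might look like, based only on the steps in the Description. It is not a copy of the repository file: the class name `RMSNormSketch`, the attribute name `variance_epsilon`, the initialization of the weight to ones, and the placement of `eps` inside the square root are all assumptions.

```python
import torch
from torch import nn


class RMSNormSketch(nn.Module):
    """Illustrative RMS normalization layer (hypothetical name)."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Learnable per-feature scale; initial value of ones is an assumption
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps  # assumed attribute name

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # Upcast so the mean of squares is computed in float32
        hidden_states = hidden_states.to(torch.float32)
        # Mean of squares over the last (feature) dimension
        mean_square = hidden_states.pow(2).mean(-1, keepdim=True)
        # Multiply by 1 / sqrt(mean_square + eps), i.e. divide by the RMS
        hidden_states = hidden_states * torch.rsqrt(mean_square + self.variance_epsilon)
        # Cast back to the input dtype, then apply the learnable scale
        return self.weight * hidden_states.to(input_dtype)
```

With the weight left at ones, each output vector in this sketch has an RMS of approximately 1 along the last dimension. If the module is cast to a lower precision such as bfloat16 together with its input, the output shares that dtype.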
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| hidden_size | int | Yes | Dimension of hidden states (constructor argument) |
| eps | float | No | Numerical stability epsilon (constructor argument, default 1e-6) |
| hidden_states | torch.Tensor | Yes | Input tensor of shape (..., hidden_size) (forward argument) |
Outputs
| Name | Type | Description |
|---|---|---|
| return | torch.Tensor | Normalized and scaled tensor with the same shape as the input |
Usage Examples
import torch
from torch import nn
# LlamaRMSNorm must already be defined in scope (see code/5.1 RMSNorm.py)
# Instantiate RMSNorm for a hidden size of 4096
norm = LlamaRMSNorm(hidden_size=4096, eps=1e-6)
# Apply to hidden states (batch_size=2, seq_len=128, hidden_size=4096)
hidden_states = torch.randn(2, 128, 4096)
normalized = norm(hidden_states)
# normalized.shape == (2, 128, 4096)