Implementation:NVIDIA TransformerEngine Ops RMSNorm
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Optimization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Fusible Root Mean Square Layer Normalization operation with CUDA-accelerated kernels, configurable SM margins, and ONNX export support.
Description
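RMSNorm is a BasicOperation implementing RMS normalization: y = x / sqrt(mean(x^2) + eps) * gamma, where the mean is taken over the normalized_shape dimensions. It uses CUDA kernels (rmsnorm_fwd and rmsnorm_bwd from transformer_engine_torch) for efficient computation. Unlike LayerNorm, it has no bias parameter and no mean subtraction, so the denominator is the root mean square of x rather than its standard deviation. Features include zero_centered_gamma mode (the weight is stored as an offset from 1, so the applied scale is 1 + gamma), configurable SM margins for overlapping with other kernels, quantizer integration for the output, ONNX export support, and CPU activation offloading.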
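The formula above can be illustrated with a minimal pure-Python sketch for a single row. This is not the CUDA kernel and makes no claim about its internals; the function name `rmsnorm_reference` is hypothetical, and the zero-centered behavior (scale equals 1 + gamma) is assumed from the description of that mode.

```python
import math

def rmsnorm_reference(x, gamma, eps=1e-5, zero_centered_gamma=False):
    """Normalize one row `x` (a list of floats) by its root mean square."""
    # Mean of squares over the normalized dimension; no mean subtraction.
    mean_sq = sum(v * v for v in x) / len(x)
    rstd = 1.0 / math.sqrt(mean_sq + eps)
    # Assumption: in zero-centered mode the stored weight is an offset from 1.
    scale = [1.0 + g for g in gamma] if zero_centered_gamma else list(gamma)
    return [v * rstd * s for v, s in zip(x, scale)]

row = [1.0, -2.0, 3.0, -4.0]
y = rmsnorm_reference(row, gamma=[1.0] * 4)
y_zero_centered = rmsnorm_reference(row, gamma=[0.0] * 4, zero_centered_gamma=True)
```

With a unit scale the output row has a mean square close to 1, and a zero-valued weight in zero-centered mode gives the same result as a weight of ones in the default mode.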
Usage
Used as the fundamental normalization layer in LLaMA-style and other modern transformer architectures that prefer RMSNorm over LayerNorm.
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/ops/basic/rmsnorm.py
- Lines: 1-254
Signature
class RMSNorm(BasicOperation):
    def __init__(self, normalized_shape, *, eps=1e-5, device=None, dtype=None, zero_centered_gamma=False, sm_margin=0) -> None: ...
    def reset_parameters(self) -> None: ...
    def op_forward(self, ctx, input_, prev_op_grad_output_quantizer, next_op_input_quantizer) -> torch.Tensor: ...
    def op_backward(self, ctx, grad_output) -> Tuple[torch.Tensor, Tuple]: ...
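op_backward returns the gradient with respect to the input together with the parameter gradients. As a rough illustration of what those gradients are mathematically, the sketch below applies the standard RMSNorm derivative to a single row in pure Python. It is not the rmsnorm_bwd kernel; the function name `rmsnorm_backward_reference` is hypothetical, and zero-centered gamma and quantization are left out.

```python
import math

def rmsnorm_backward_reference(x, gamma, dy, eps=1e-5):
    """Return (dx, dgamma) for y_i = x_i * rstd * gamma_i on one row."""
    n = len(x)
    rstd = 1.0 / math.sqrt(sum(v * v for v in x) / n + eps)
    # Gradient w.r.t. the weight: upstream gradient times the normalized input.
    dgamma = [g * v * rstd for g, v in zip(dy, x)]
    # Correction term: rstd depends on every element of x.
    dot = sum(g * w * v for g, w, v in zip(dy, gamma, x)) / n
    dx = [g * w * rstd - v * rstd**3 * dot for g, w, v in zip(dy, gamma, x)]
    return dx, dgamma
```

For a batch of rows, dx is computed per row, while dgamma is summed over all rows because the weight is shared.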
Import
from transformer_engine.pytorch.ops.basic.rmsnorm import RMSNorm
I/O Contract
Inputs
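| Name | Type | Required | Description |
|---|---|---|---|
| normalized_shape | int or Iterable[int] | Yes | Inner dimensions of the input tensor over which normalization is applied |
| eps | float | No | Numerical stability constant added inside the square root (default 1e-5) |
| device | torch.device or str | No | Device for the weight parameter (default None) |
| dtype | torch.dtype | No | Data type for the weight parameter (default None) |
| zero_centered_gamma | bool | No | Store gamma as an offset from 1, so the applied scale is 1 + gamma (default False) |
| sm_margin | int | No | SM margin for the CUDA kernels, used to leave room for overlapping kernels (default 0) |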
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | RMS-normalized tensor |
Usage Examples
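import torch
from transformer_engine.pytorch.ops.basic.rmsnorm import RMSNorm

rmsnorm = RMSNorm(4096, eps=1e-5, zero_centered_gamma=True, device="cuda")
input_tensor = torch.randn(8, 4096, device="cuda")
# Call the op like any torch.nn.Module; op_forward is invoked internally and
# is not meant to be called directly with a hand-built ctx.
output = rmsnorm(input_tensor)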