

Implementation:NVIDIA TransformerEngine Ops RMSNorm

From Leeroopedia


Field Value
Sources TransformerEngine
Domains Deep_Learning, PyTorch, Optimization
Last Updated 2026-02-07 14:00 GMT

Overview

Fusible Root Mean Square Layer Normalization operation with CUDA-accelerated kernels, configurable SM margins, and ONNX export support.

Description

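RMSNorm is a BasicOperation implementing RMS normalization over the inner dimensions given by normalized_shape: y = x / sqrt(mean(x^2) + eps) * gamma, where the mean of squares is taken over those dimensions. With zero_centered_gamma=True, gamma is initialized to zero and the output is scaled by (1 + gamma) instead. The computation runs in the CUDA kernels rmsnorm_fwd and rmsnorm_bwd from transformer_engine_torch. Unlike LayerNorm, RMSNorm has no bias parameter and does not subtract the mean. Features include the zero_centered_gamma mode, configurable SM margins that leave SMs free so the normalization kernels can overlap with other concurrently running kernels, quantizer integration for the output, ONNX export support, and CPU activation offloading.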
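To make the formula concrete, here is a minimal plain-Python reference for a single row. It mirrors the math described above, not the CUDA kernels. The names init_gamma and rmsnorm_row are hypothetical illustrations, not Transformer Engine APIs, and the zero-initialization of gamma in zero-centered mode is taken from the description.

```python
import math
from typing import List


def init_gamma(n: int, zero_centered_gamma: bool = False) -> List[float]:
    """Hypothetical helper: gamma starts at zeros in zero-centered mode, else at ones."""
    return [0.0] * n if zero_centered_gamma else [1.0] * n


def rmsnorm_row(
    x: List[float],
    gamma: List[float],
    eps: float = 1e-5,
    zero_centered_gamma: bool = False,
) -> List[float]:
    """Reference RMSNorm of one row: y = x / sqrt(mean(x^2) + eps) * scale.

    scale is gamma, or (1 + gamma) in zero-centered mode.
    """
    if len(x) != len(gamma):
        raise ValueError("x and gamma must have the same length")
    mean_sq = sum(v * v for v in x) / len(x)  # no mean subtraction, unlike LayerNorm
    rstd = 1.0 / math.sqrt(mean_sq + eps)     # reciprocal of the RMS
    scale = [(1.0 + g) if zero_centered_gamma else g for g in gamma]
    return [v * rstd * s for v, s in zip(x, scale)]


# With freshly initialized gamma, both modes apply an effective scale of 1.
x = [1.0, -2.0, 3.0, -4.0]
y_plain = rmsnorm_row(x, init_gamma(4), eps=0.0)
y_zc = rmsnorm_row(x, init_gamma(4, zero_centered_gamma=True), eps=0.0,
                   zero_centered_gamma=True)
```

The two modes only differ in how the stored parameter maps to the applied scale; at initialization they produce identical outputs.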

Usage

Used as the normalization layer in LLaMA-style and other modern transformer architectures that use RMSNorm instead of LayerNorm.

Code Reference

Source Location

Repository
NVIDIA/TransformerEngine
File
transformer_engine/pytorch/ops/basic/rmsnorm.py
Lines
1-254

Signature

class RMSNorm(BasicOperation):
    def __init__(self, normalized_shape, *, eps=1e-5, device=None, dtype=None, zero_centered_gamma=False, sm_margin=0) -> None: ...
    def reset_parameters(self) -> None: ...
    def op_forward(self, ctx, input_, prev_op_grad_output_quantizer, next_op_input_quantizer) -> torch.Tensor: ...
    def op_backward(self, ctx, grad_output) -> Tuple[torch.Tensor, Tuple]: ...
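The op_backward signature suggests it returns the input gradient together with a tuple of parameter gradients. As an illustration of what those gradients are mathematically (not of how rmsnorm_bwd computes them), the sketch below derives them for one row and checks them against central finite differences. All function names here are hypothetical.

```python
import math
from typing import List, Tuple


def rmsnorm_row_backward(
    x: List[float],
    gamma: List[float],
    grad_y: List[float],
    eps: float = 1e-5,
    zero_centered_gamma: bool = False,
) -> Tuple[List[float], List[float]]:
    """Reference gradients for one row of y_i = w_i * x_i * r,
    where r = 1 / sqrt(mean(x^2) + eps).

    w_i is gamma_i, or 1 + gamma_i in zero-centered mode; in both cases
    dw_i/dgamma_i = 1, so the gamma gradient has the same form.
    """
    n = len(x)
    w = [(1.0 + g) if zero_centered_gamma else g for g in gamma]
    r = 1.0 / math.sqrt(sum(v * v for v in x) / n + eps)
    # dL/dgamma_i = grad_y_i * x_i * r (a batched version sums this over rows)
    grad_gamma = [gy * v * r for gy, v in zip(grad_y, x)]
    # dL/dx_j = r * grad_y_j * w_j - (r^3 / n) * x_j * sum_i(grad_y_i * w_i * x_i)
    dot = sum(gy * wi * v for gy, wi, v in zip(grad_y, w, x))
    grad_x = [r * gy * wi - (r ** 3 / n) * v * dot for gy, wi, v in zip(grad_y, w, x)]
    return grad_x, grad_gamma


def surrogate_loss(x, gamma, grad_y, eps, zero_centered_gamma):
    """L = sum_i grad_y_i * y_i, so dL/dy equals grad_y."""
    n = len(x)
    r = 1.0 / math.sqrt(sum(v * v for v in x) / n + eps)
    w = [(1.0 + g) if zero_centered_gamma else g for g in gamma]
    return sum(gy * v * r * wi for gy, v, wi in zip(grad_y, x, w))


def central_difference(f, values, j, h=1e-6):
    """Numerical derivative of f with respect to values[j]."""
    plus, minus = list(values), list(values)
    plus[j] += h
    minus[j] -= h
    return (f(plus) - f(minus)) / (2 * h)


x = [0.5, -1.0, 2.0]
gamma = [0.1, -0.2, 0.3]
grad_y = [1.0, 0.5, -0.25]
grad_x, grad_gamma = rmsnorm_row_backward(x, gamma, grad_y, zero_centered_gamma=True)

# Numerical gradients of the surrogate loss, for comparison with the analytic ones.
fd_x = [central_difference(lambda xs: surrogate_loss(xs, gamma, grad_y, 1e-5, True), x, j)
        for j in range(len(x))]
fd_gamma = [central_difference(lambda gs: surrogate_loss(x, gs, grad_y, 1e-5, True), gamma, j)
            for j in range(len(gamma))]
```

The second term of grad_x comes from the dependence of r on every element of the row, which is why the backward pass needs a row-wide reduction as well as the forward pass.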

Import

from transformer_engine.pytorch.ops.basic.rmsnorm import RMSNorm

I/O Contract

Inputs

Name Type Required Description
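normalized_shape int or Iterable[int] Yes Inner (trailing) dimensions of the input tensor over which the RMS is computed
eps float No Constant added to the mean square for numerical stability (default 1e-5)
device torch.device or str No Device on which the weight is allocated (default None)
dtype torch.dtype No Data type of the weight (default None)
zero_centered_gamma bool No If True, gamma is initialized to zero and the output is scaled by (1 + gamma) (default False)
sm_margin int No Number of SMs to leave free when launching the normalization kernels, to allow overlap with other kernels (default 0)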
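Because normalized_shape names the inner dimensions, every leading dimension of the input is treated as an independent row. Assuming the same convention as torch.nn.LayerNorm (the trailing input dimensions must equal normalized_shape), the shape bookkeeping can be sketched as follows. Both helpers are hypothetical and only illustrate the convention.

```python
from math import prod
from typing import Iterable, Sequence, Tuple, Union


def canonicalize_normalized_shape(
    normalized_shape: Union[int, Iterable[int]],
) -> Tuple[int, ...]:
    """Hypothetical helper: accept an int or an iterable of ints."""
    if isinstance(normalized_shape, int):
        return (normalized_shape,)
    shape = tuple(int(d) for d in normalized_shape)
    if not shape:
        raise ValueError("normalized_shape must not be empty")
    return shape


def rows_and_inner_size(
    input_shape: Sequence[int],
    normalized_shape: Union[int, Iterable[int]],
) -> Tuple[int, int]:
    """Return (num_rows, inner_size) when normalizing over the trailing dims."""
    inner = canonicalize_normalized_shape(normalized_shape)
    if tuple(input_shape[-len(inner):]) != inner:
        raise ValueError(f"input shape {tuple(input_shape)} does not end with {inner}")
    inner_size = prod(inner)                      # elements reduced per row
    num_rows = prod(input_shape[:-len(inner)])    # all leading dims flattened
    return num_rows, inner_size
```

For a (batch, sequence, hidden) activation such as (2, 16, 4096) with normalized_shape=4096, this yields 32 rows of 4096 elements, each normalized independently.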

Outputs

Name Type Description
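output torch.Tensor RMS-normalized tensor with the same shape as the input (possibly quantized when the next operation supplies an input quantizer)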

Usage Examples

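import torch
from transformer_engine.pytorch.ops.basic.rmsnorm import RMSNorm

# Requires a CUDA-capable GPU
rmsnorm = RMSNorm(4096, eps=1e-5, zero_centered_gamma=True)
input_tensor = torch.randn(8, 4096, device="cuda")
output = rmsnorm(input_tensor)  # call the op like a module instead of invoking op_forward directly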
