Implementation: NVIDIA TransformerEngine te.LayerNorm
| Field | Value |
|---|---|
| Sources | TransformerEngine, Layer Normalization |
| Domains | Deep_Learning, Normalization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
te.LayerNorm is a concrete tool for hardware-accelerated layer normalization provided by NVIDIA's TransformerEngine library. It replaces torch.nn.LayerNorm with a fused CUDA kernel implementation that supports zero-centered gamma initialization and optional FP8 output casting.
Description
te.LayerNorm applies layer normalization over a mini-batch of inputs as described in the Layer Normalization paper. The computation normalizes across the inner-most dimensions specified by normalized_shape:
y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta
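The formula above can be checked against PyTorch's built-in implementation with a minimal reference sketch in plain PyTorch (the helper name and test shapes below are illustrative, not part of TransformerEngine):

```python
import torch

def layer_norm_reference(x, gamma, beta, eps=1e-5):
    # Normalize over the last dimension, i.e. normalized_shape = (hidden,).
    # LayerNorm uses the biased variance estimator, hence unbiased=False.
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    return (x - mean) / torch.sqrt(var + eps) * gamma + beta

x = torch.randn(4, 16)
gamma = torch.ones(16)
beta = torch.zeros(16)

ref = layer_norm_reference(x, gamma, beta)
builtin = torch.nn.functional.layer_norm(x, (16,), gamma, beta, eps=1e-5)
assert torch.allclose(ref, builtin, atol=1e-5)
```

The fused te.LayerNorm kernel computes the same quantity; the difference is in dispatch and memory traffic, not the math.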
The class inherits from the internal _LayerNormOp (located in transformer_engine.pytorch.ops), which implements the fused CUDA kernel dispatch. The public LayerNorm class adds:
- Legacy parameter handling: supports the deprecated `hidden_size` argument (renamed to `normalized_shape`) and the deprecated `params_dtype` argument (renamed to `dtype`) for backward compatibility with older Megatron-LM integration code.
- Sequence parallelism flag: sets a `sequence_parallel` attribute on the weight and bias parameters for custom Megatron-LM integration logic.
- SM margin control: configurable `sm_margin` parameter to reserve streaming multiprocessors for concurrent operations such as communication kernels.
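The legacy-argument handling described above amounts to remapping deprecated names onto their replacements before construction. A hypothetical stand-alone sketch of that remapping (this helper is illustrative, not TransformerEngine's actual code):

```python
import warnings

def resolve_legacy_args(normalized_shape=None, hidden_size=None,
                        dtype=None, params_dtype=None):
    """Map deprecated argument names to their current equivalents."""
    if hidden_size is not None:
        warnings.warn("hidden_size is deprecated; use normalized_shape")
        if normalized_shape is None:
            normalized_shape = hidden_size
    if params_dtype is not None:
        warnings.warn("params_dtype is deprecated; use dtype")
        if dtype is None:
            dtype = params_dtype
    return normalized_shape, dtype

# Old Megatron-LM-style call sites keep working:
shape, dtype = resolve_legacy_args(hidden_size=768)
assert shape == 768
```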
Zero-Centered Gamma
When zero_centered_gamma=True, gamma is initialized to zero and the computation becomes:
y = (x - E[x]) / sqrt(Var[x] + eps) * (1 + gamma) + beta
This means the initial forward pass performs pure normalization (identity scaling), which can improve training stability for deep models.
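The identity-scaling property at initialization can be verified with a plain-PyTorch sketch of the zero-centered variant (the helper below mimics the `(1 + gamma)` formula; it is not TransformerEngine's kernel):

```python
import torch

def zero_centered_layer_norm(x, gamma, beta, eps=1e-5):
    # Zero-centered variant: the effective scale is (1 + gamma).
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    return (x - mean) / torch.sqrt(var + eps) * (1.0 + gamma) + beta

x = torch.randn(4, 16)
gamma = torch.zeros(16)  # zero-initialized, as with zero_centered_gamma=True
beta = torch.zeros(16)

out = zero_centered_layer_norm(x, gamma, beta)

# With gamma = 0 and beta = 0 the result is pure normalization.
pure = torch.nn.functional.layer_norm(x, (16,), eps=1e-5)
assert torch.allclose(out, pure, atol=1e-5)
```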
Usage
Import te.LayerNorm when replacing torch.nn.LayerNorm in models that will be trained with TransformerEngine's FP8 autocast or when fused normalization kernels are desired for performance. It is a direct drop-in replacement with compatible constructor arguments.
Code Reference
Source Location
- Repository: `NVIDIA/TransformerEngine`
- File: `transformer_engine/pytorch/module/layernorm.py`
- Class: `LayerNorm`
- Lines: `__init__` at L59--68
Signature
```python
class LayerNorm(torch.nn.Module):
    def __init__(
        self,
        normalized_shape: Union[Iterable[int], int, None] = None,
        eps: float = 1e-5,
        sequence_parallel: Optional[bool] = None,  # legacy
        params_dtype: Optional[torch.dtype] = None,  # deprecated
        zero_centered_gamma: bool = False,
        hidden_size: Optional[int] = None,  # deprecated
        **kwargs,
    ) -> None:
```
Import
```python
from transformer_engine.pytorch import LayerNorm

# or equivalently:
import transformer_engine.pytorch as te
te.LayerNorm
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| `input` | torch.Tensor | Yes | Input tensor of any shape; normalization is applied over the last D dimensions matching `normalized_shape` |
Outputs
| Name | Type | Description |
|---|---|---|
| `output` | torch.Tensor | Normalized tensor of the same shape as the input |
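The shape contract (output shape equals input shape, with normalization over the trailing dimensions only) can be illustrated with PyTorch's functional layer norm as a CPU stand-in for te.LayerNorm (shapes below are arbitrary):

```python
import torch

# Any number of leading dimensions may precede the normalized ones.
x = torch.randn(2, 8, 32)
y = torch.nn.functional.layer_norm(x, (32,))  # stand-in for te.LayerNorm(32)
assert y.shape == x.shape
```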
Key Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| `normalized_shape` | int or iterable of int | required | Inner dimensions of the input tensor over which to normalize |
| `eps` | float | 1e-5 | Small constant added to the denominator for numerical stability |
| `zero_centered_gamma` | bool | False | If True, gamma is initialized to zero and the formula uses (1 + gamma) scaling |
| `device` | torch.device | default CUDA device | Device on which to allocate the learnable parameters (passed via `**kwargs`) |
| `dtype` | torch.dtype | default dtype | Data type of the learnable parameters (passed via `**kwargs`) |
| `sm_margin` | int or dict | 0 | Number of SMs to exclude from kernel launches; accepts a dict with keys "forward", "backward", "inference" for fine-grained control (passed via `**kwargs`) |
Usage Examples
Basic Drop-in Replacement
```python
import torch
import transformer_engine.pytorch as te

# Before: standard PyTorch
# layer_norm = torch.nn.LayerNorm(768)

# After: TransformerEngine drop-in replacement
layer_norm = te.LayerNorm(768)

# Usage is identical (TE kernels run on CUDA devices)
input_tensor = torch.randn(32, 128, 768, device="cuda")
output = layer_norm(input_tensor)
```
With Zero-Centered Gamma
```python
import torch
import transformer_engine.pytorch as te

# Zero-centered gamma for improved training stability
layer_norm = te.LayerNorm(768, zero_centered_gamma=True)

# gamma starts at zero, so the initial output is pure normalization
input_tensor = torch.randn(32, 128, 768, device="cuda")
output = layer_norm(input_tensor)
```
Inside FP8 Training
```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

layer_norm = te.LayerNorm(768)
linear = te.Linear(768, 3072)

input_tensor = torch.randn(32, 128, 768, device="cuda")

# LayerNorm output can be directly consumed by the FP8 linear layer
with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
    normed = layer_norm(input_tensor)
    output = linear(normed)
```