Implementation:NVIDIA TransformerEngine JAX LayerNorm
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, JAX, Normalization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Provides a standalone differentiable layer normalization function supporting both LayerNorm and RMSNorm with optional FP8 quantization output.
Description
canonicalize_norm_type normalizes a norm-type string to either "layernorm" or "rmsnorm". The public layernorm() function delegates to _layernorm, which is wrapped in jax.custom_vjp so that the forward and backward passes use explicit rules rather than JAX's automatic differentiation. The forward rule calls tex.normalization_fwd (a C++ extension), passing the optional quantizer, and then dequantizes the output. The backward rule calls tex.normalization_bwd to compute gradients with respect to the input, gamma, and beta. The quantizer is passed through as a differentiable argument rather than a static one, which is how delayed scaling state is managed.
This is a fundamental normalization building block used by layernorm_dense and layernorm_mlp for standalone normalization, and by the Flax LayerNorm module for direct use.
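The custom-differentiation pattern described above can be sketched in plain JAX. This is not the TransformerEngine implementation: the names reference_norm, _reference_norm_fwd, and _reference_norm_bwd are hypothetical, the fused tex.normalization_fwd and tex.normalization_bwd kernels are replaced by ordinary jax.numpy math, and the quantizer argument is omitted. It only illustrates how a forward rule can save residuals that a hand-written backward rule then uses to produce gradients for x, gamma, and beta.

```python
from functools import partial

import jax
import jax.numpy as jnp


# norm_type, zero_centered_gamma and epsilon are configuration, so they are
# declared non-differentiable; x, gamma and beta receive gradients.
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
def reference_norm(x, gamma, beta, norm_type, zero_centered_gamma, epsilon):
    out, _ = _reference_norm_fwd(x, gamma, beta, norm_type, zero_centered_gamma, epsilon)
    return out


def _reference_norm_fwd(x, gamma, beta, norm_type, zero_centered_gamma, epsilon):
    # LayerNorm subtracts the mean over the last axis; RMSNorm does not.
    if norm_type == "layernorm":
        mu = jnp.mean(x, axis=-1, keepdims=True)
    else:
        mu = jnp.zeros_like(x[..., :1])
    var = jnp.mean(jnp.square(x - mu), axis=-1, keepdims=True)
    rsigma = jax.lax.rsqrt(var + epsilon)
    x_hat = (x - mu) * rsigma
    # With zero-centered gamma the effective scale is (1 + gamma).
    scale = gamma + 1.0 if zero_centered_gamma else gamma
    out = x_hat * scale
    if norm_type == "layernorm":
        out = out + beta
    # Residuals saved for the backward rule.
    return out, (x_hat, rsigma, scale, beta)


def _reference_norm_bwd(norm_type, zero_centered_gamma, epsilon, residuals, grad_out):
    x_hat, rsigma, scale, beta = residuals
    batch_axes = tuple(range(grad_out.ndim - 1))
    dgamma = jnp.sum(grad_out * x_hat, axis=batch_axes)
    if norm_type == "layernorm":
        dbeta = jnp.sum(grad_out, axis=batch_axes)
    else:
        dbeta = jnp.zeros_like(beta)  # beta is unused by RMSNorm
    dx_hat = grad_out * scale
    # Remove the component of the gradient that lies along x_hat.
    dx = dx_hat - x_hat * jnp.mean(dx_hat * x_hat, axis=-1, keepdims=True)
    if norm_type == "layernorm":
        # Mean subtraction in the forward pass also removes the gradient mean.
        dx = dx - jnp.mean(dx_hat, axis=-1, keepdims=True)
    return dx * rsigma, dgamma, dbeta


reference_norm.defvjp(_reference_norm_fwd, _reference_norm_bwd)

# Illustrative inputs: 3 rows, hidden size 4.
x = jnp.cos(jnp.arange(12.0)).reshape(3, 4)
gamma = jnp.ones((4,))
beta = jnp.zeros((4,))
weights = jnp.sin(jnp.arange(12.0)).reshape(3, 4)


def loss(x, gamma, beta):
    return jnp.sum(reference_norm(x, gamma, beta, "layernorm", False, 1e-6) * weights)


dx, dgamma, dbeta = jax.grad(loss, argnums=(0, 1, 2))(x, gamma, beta)
```

In this sketch the residuals play the role that the saved statistics play between a fused forward and backward kernel: the backward rule never recomputes the mean or variance.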
Usage
Use this function for standalone layer normalization in JAX. It supports both LayerNorm and RMSNorm and integrates with FP8 quantization for training with low-precision formats.
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/jax/layernorm.py
- Lines: 1-141
Signature
def canonicalize_norm_type(norm_type: str) -> str:
    """Normalize norm type string to 'layernorm' or 'rmsnorm'."""
    ...

def layernorm(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    norm_type: str,
    zero_centered_gamma: bool = False,
    epsilon: float = 1e-6,
    quantizer: Quantizer = None,
) -> jnp.ndarray: ...
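As a rough illustration of what string canonicalization of this kind might look like, the sketch below lowercases the input and strips separators before validating it. The function name canonicalize_norm_type_sketch is hypothetical, and the set of spellings it accepts is an assumption; the spellings actually accepted by canonicalize_norm_type should be checked against the source file.

```python
def canonicalize_norm_type_sketch(norm_type: str) -> str:
    """Hypothetical stand-in: map a user-supplied spelling to a canonical name."""
    # Assumption: case, surrounding whitespace, hyphens and underscores are ignored.
    canonical = norm_type.strip().lower().replace("-", "").replace("_", "")
    if canonical not in ("layernorm", "rmsnorm"):
        raise ValueError(f"Unsupported norm_type: {norm_type!r}")
    return canonical


for spelling in ("LayerNorm", "layer_norm", "RMSNorm", "rms-norm"):
    print(spelling, "->", canonicalize_norm_type_sketch(spelling))
```

Canonicalizing once at the API boundary means downstream code only has to compare against two fixed strings.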
Import
from transformer_engine.jax.layernorm import layernorm, canonicalize_norm_type
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| x | jnp.ndarray | Yes | Input tensor to normalize |
| gamma | jnp.ndarray | Yes | Scale parameter |
| beta | jnp.ndarray | Yes | Shift parameter (ignored for RMSNorm) |
| norm_type | str | Yes | Normalization type: "layernorm" or "rmsnorm" |
| zero_centered_gamma | bool | No | Whether gamma is zero-centered (default False) |
| epsilon | float | No | Numerical stability constant (default 1e-6) |
| quantizer | Quantizer | No | Optional FP8 quantizer for fused quantization |
Outputs
| Name | Type | Description |
|---|---|---|
| output | jnp.ndarray | Normalized tensor (dequantized if quantizer was used) |
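To give a feel for what "dequantized" means for the returned tensor, the following sketch simulates an FP8 round trip in plain JAX. It does not use the TransformerEngine Quantizer class: fake_fp8_round_trip is a hypothetical helper, and the per-tensor scale derived from the current absolute maximum is an assumption that may differ from the scaling recipe a real quantizer applies (for example delayed scaling). The point is only that the result comes back in a higher-precision dtype while carrying FP8 rounding error.

```python
import jax.numpy as jnp

E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn


def fake_fp8_round_trip(y):
    """Quantize to FP8 with a per-tensor scale, then dequantize to float32."""
    amax = jnp.max(jnp.abs(y))
    # Assumption: scale so that the largest magnitude maps to the FP8 maximum.
    scale = E4M3_MAX / jnp.maximum(amax, 1e-12)
    q = jnp.clip(y * scale, -E4M3_MAX, E4M3_MAX).astype(jnp.float8_e4m3fn)
    return q.astype(jnp.float32) / scale


y = jnp.linspace(-3.0, 3.0, 16, dtype=jnp.float32)
y_deq = fake_fp8_round_trip(y)
max_err = jnp.max(jnp.abs(y_deq - y))
print(y_deq.dtype, float(max_err))
```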
Usage Examples
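import jax.numpy as jnp
from transformer_engine.jax.layernorm import layernorm

# Example inputs: 8 rows with hidden size 1024 (illustrative shapes)
x = jnp.ones((8, 1024), dtype=jnp.bfloat16)
gamma = jnp.ones((1024,), dtype=jnp.bfloat16)
beta = jnp.zeros((1024,), dtype=jnp.bfloat16)

# Apply LayerNorm
output = layernorm(x, gamma, beta, norm_type="layernorm", epsilon=1e-5)

# Apply RMSNorm (beta is ignored)
output = layernorm(x, gamma, beta, norm_type="rmsnorm", epsilon=1e-5)

# LayerNorm with FP8 quantization; fp8_quantizer is assumed to be a
# Quantizer instance created elsewhere
output = layernorm(x, gamma, beta, norm_type="layernorm",
                   quantizer=fp8_quantizer)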