Implementation:NVIDIA TransformerEngine JAX Cpp Normalization
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, JAX, Normalization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Implements JAX custom primitives for fused layer normalization and RMS normalization operations with optional FP8 quantization, supporting both TE and cuDNN backends.
Description
NormFwdPrimitive computes the normalization forward pass (LayerNorm or RMSNorm, selected by NVTE_Norm_Type) and can optionally fuse quantization of its output to FP8 or MXFP8. NormBwdPrimitive computes the gradients with respect to the input, gamma, and beta. The module supports cuDNN-accelerated normalization (enabled through the NVTE_NORM_FWD_USE_CUDNN environment variable), zero-centered gamma, and configurable SM margins. The high-level functions normalization_fwd and normalization_bwd manage scales and amax values for each scaling mode and delegate the computation to these primitives.
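As a point of reference for what the forward primitive computes, the sketch below is an unfused NumPy version of LayerNorm and RMSNorm over the last axis. It assumes that zero_centered_gamma means the stored parameter is offset by one, so the effective scale is (1 + gamma), and that mu and rsigma are per-row statistics. The function name norm_fwd_reference is hypothetical and not part of TransformerEngine; the real primitive runs as a fused CUDA or cuDNN kernel and may return statistics with different shapes or dtypes.

```python
import numpy as np

def norm_fwd_reference(x, gamma, beta=None, norm_type="layernorm",
                       zero_centered_gamma=False, epsilon=1e-6):
    """Hypothetical unfused reference for LayerNorm/RMSNorm over the last axis.

    Returns (y, mu, rsigma); mu is None for RMSNorm. mu and rsigma hold one
    value per row, which is what the backward pass reuses.
    """
    x = np.asarray(x, dtype=np.float64)
    # Assumed zero_centered_gamma semantics: the stored parameter is (gamma - 1),
    # so the effective scale applied to the normalized input is (1 + gamma).
    scale = 1.0 + gamma if zero_centered_gamma else gamma
    if norm_type == "layernorm":
        mu = x.mean(axis=-1, keepdims=True)
        var = ((x - mu) ** 2).mean(axis=-1, keepdims=True)
        rsigma = 1.0 / np.sqrt(var + epsilon)
        y = (x - mu) * rsigma * scale + beta
        return y, mu[..., 0], rsigma[..., 0]
    if norm_type == "rmsnorm":
        # RMSNorm skips mean subtraction and beta; rsigma is 1 / RMS(x).
        rsigma = 1.0 / np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + epsilon)
        return x * rsigma * scale, None, rsigma[..., 0]
    raise ValueError(f"unknown norm_type: {norm_type!r}")

# Usage: rows of y have zero mean and (almost exactly) unit variance.
rng = np.random.default_rng(0)
x = rng.standard_normal((4, 16))
gamma, beta = np.ones(16), np.zeros(16)
y, mu, rsigma = norm_fwd_reference(x, gamma, beta)
# Zero-centered storage: gamma = 0 reproduces the gamma = 1 result above.
y_zc, _, _ = norm_fwd_reference(x, np.zeros(16), beta, zero_centered_gamma=True)
```

Storing gamma zero-centered lets the parameter be initialized (and weight-decayed) around zero while still producing an identity scale at initialization.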
Fusing normalization with quantization removes a separate pass that would otherwise write the normalized output to memory and read it back to quantize it. This matters most in FP8 training, where the normalization output feeds directly into quantized GEMM inputs.
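The sketch below illustrates why the scaling mode matters for fusion. It is a NumPy simulation under stated assumptions, not TransformerEngine code: the function names are hypothetical, FP8 rounding is not simulated (values are only scaled and clipped to the E4M3 range of ±448), and the MXFP8 path follows the OCP Microscaling convention of one power-of-two scale per 32 consecutive values. With per-tensor delayed scaling the scale is already known from earlier amax history, so each element can be normalized and quantized in one pass while the new amax is merely recorded; with MXFP8 block scaling each scale depends only on its own block, so no global amax is needed before quantizing.

```python
import numpy as np

E4M3_MAX = 448.0  # largest finite value of FP8 E4M3

def rmsnorm_quantize_delayed(x, gamma, scale, epsilon=1e-6):
    """Hypothetical per-tensor delayed scaling: `scale` comes from earlier amax
    history, so normalization and quantization can share a single pass."""
    rsigma = 1.0 / np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + epsilon)
    y = x * rsigma * gamma                       # a fused kernel keeps y on-chip
    q = np.clip(y * scale, -E4M3_MAX, E4M3_MAX)  # FP8 rounding not simulated
    amax = np.abs(y).max()                       # recorded for future scales
    return q, 1.0 / scale, amax

def mxfp8_block_quantize(y, block=32):
    """Hypothetical MXFP8-style scaling: each run of `block` values along the
    last axis shares a power-of-two scale computed from that block alone."""
    rows, cols = y.shape
    assert cols % block == 0, "last dimension must be a multiple of the block size"
    blocks = y.reshape(rows, cols // block, block)
    amax = np.abs(blocks).max(axis=-1, keepdims=True)
    amax = np.maximum(amax, np.finfo(y.dtype).tiny)  # avoid log2(0)
    # 448 = 1.75 * 2**8, so the largest E4M3 exponent is 8. Scaled magnitudes
    # land in [256, 512); anything above 448 is clipped to the E4M3 maximum.
    scale_inv = np.exp2(np.floor(np.log2(amax)) - 8)
    q = np.clip(blocks / scale_inv, -E4M3_MAX, E4M3_MAX)
    return q.reshape(rows, cols), scale_inv[..., 0]

# Usage
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 64))
q, scale_inv, amax = rmsnorm_quantize_delayed(x, np.ones(64), scale=64.0)
qb, block_scale_inv = mxfp8_block_quantize(x)
```

Per-tensor current scaling (computing amax from the same tensor before quantizing) would need the whole normalized output first, which is the extra pass fusion is trying to avoid.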
Usage
Most code reaches this module indirectly through layernorm(), layernorm_dense(), or layernorm_mlp(). Call it directly only when a custom implementation needs explicit control over the normalization forward and backward passes and their fused FP8 quantization.
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/jax/cpp_extensions/normalization.py
- Lines: 1-1584
Signature
```python
class NormFwdPrimitive(BasePrimitive): ...
class NormBwdPrimitive(BasePrimitive): ...

def layernorm_fwd(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    zero_centered_gamma: bool = False,
    epsilon: float = 1e-6,
    quantizer: Quantizer = None,
) -> Tuple: ...

def layernorm_bwd(
    dz: jnp.ndarray, x: jnp.ndarray, mu: jnp.ndarray, rsigma: jnp.ndarray,
    gamma: jnp.ndarray, zero_centered_gamma: bool = False, epsilon: float = 1e-6,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: ...

def rmsnorm_fwd(
    x: jnp.ndarray, gamma: jnp.ndarray,
    zero_centered_gamma: bool = False, epsilon: float = 1e-6,
    quantizer: Quantizer = None,
) -> Tuple: ...

def rmsnorm_bwd(
    dz: jnp.ndarray, x: jnp.ndarray, rsigma: jnp.ndarray,
    gamma: jnp.ndarray, zero_centered_gamma: bool = False, epsilon: float = 1e-6,
) -> Tuple[jnp.ndarray, jnp.ndarray]: ...

def normalization_fwd(
    x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray,
    norm_type: str, zero_centered_gamma: bool, epsilon: float,
    quantizer: Quantizer = None,
) -> Tuple: ...

def normalization_bwd(
    norm_type: str, ...,
) -> Tuple: ...
```
Import
```python
from transformer_engine.jax.cpp_extensions.normalization import normalization_fwd, normalization_bwd
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| x | jnp.ndarray | Yes | Input tensor to normalize |
| gamma | jnp.ndarray | Yes | Scale parameter for normalization |
| beta | jnp.ndarray | Yes (LayerNorm) | Shift parameter (LayerNorm only; RMSNorm has no shift) |
| norm_type | str | Yes | Normalization type: "layernorm" or "rmsnorm" |
| zero_centered_gamma | bool | No | If True, gamma is stored zero-centered and the effective scale is (1 + gamma) |
| epsilon | float | No | Small constant added inside the square root for numerical stability (default 1e-6 in layernorm_fwd and rmsnorm_fwd) |
| quantizer | Quantizer | No | Optional quantizer; when given, the output is quantized (FP8 or MXFP8) in the same kernel |
Outputs
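| Name | Type | Description |
|---|---|---|
| output | Union[jnp.ndarray, ScaledTensor] | Normalized output; a quantized ScaledTensor when a quantizer is given |
| mu | jnp.ndarray | Mean of the input over the normalized dimension (LayerNorm only), saved for the backward pass |
| rsigma | jnp.ndarray | Reciprocal standard deviation (LayerNorm) or reciprocal root mean square (RMSNorm), saved for the backward pass |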
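Saving mu and rsigma lets the backward pass rebuild the normalized input without recomputing the statistics. The NumPy sketch below shows the standard LayerNorm gradient formulas in terms of those saved values and returns (dx, dgamma, dbeta) in the same order as layernorm_bwd. The function name layernorm_bwd_reference is hypothetical; it assumes per-row mu and rsigma, a (hidden,)-shaped gamma, and the (1 + gamma) convention for zero_centered_gamma, and it is not how the TransformerEngine kernel is written.

```python
import numpy as np

def layernorm_bwd_reference(dz, x, mu, rsigma, gamma, zero_centered_gamma=False):
    """Hypothetical LayerNorm backward using the saved per-row mu and rsigma.

    Shapes: dz, x are (rows, hidden); mu, rsigma are (rows,); gamma is (hidden,).
    Returns (dx, dgamma, dbeta).
    """
    mu, rsigma = mu[:, None], rsigma[:, None]
    xhat = (x - mu) * rsigma                  # rebuilt from the saved statistics
    scale = 1.0 + gamma if zero_centered_gamma else gamma
    dxhat = dz * scale
    # The mean term and the variance term each subtract a per-row correction.
    dx = rsigma * (dxhat - dxhat.mean(axis=-1, keepdims=True)
                   - xhat * (dxhat * xhat).mean(axis=-1, keepdims=True))
    dgamma = (dz * xhat).sum(axis=0)  # unchanged for zero-centered storage
    dbeta = dz.sum(axis=0)
    return dx, dgamma, dbeta

# Usage: statistics computed the way a forward pass would save them.
rng = np.random.default_rng(1)
x, dz = rng.standard_normal((3, 8)), rng.standard_normal((3, 8))
gamma, beta, eps = rng.standard_normal(8), rng.standard_normal(8), 1e-6
mu = x.mean(axis=-1)
rsigma = 1.0 / np.sqrt(x.var(axis=-1) + eps)
dx, dgamma, dbeta = layernorm_bwd_reference(dz, x, mu, rsigma, gamma)
```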
Usage Examples
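```python
# Requires TransformerEngine with JAX support on a GPU that TransformerEngine supports.
from transformer_engine.jax.cpp_extensions.normalization import normalization_fwd, normalization_bwd

# x, gamma, beta, fp8_quantizer and grad_output are assumed to be defined already;
# fp8_quantizer is a TransformerEngine Quantizer (or None to skip quantization).

# Forward pass with LayerNorm + fused FP8 quantization
output, mu, rsigma = normalization_fwd(
    x, gamma, beta, norm_type="layernorm",
    zero_centered_gamma=False, epsilon=1e-5,
    quantizer=fp8_quantizer,
)

# Backward pass: reuse mu and rsigma from the forward pass and the same epsilon
dx, dgamma, dbeta = normalization_bwd(
    norm_type="layernorm", dz=grad_output, x=x, mu=mu, rsigma=rsigma,
    gamma=gamma, zero_centered_gamma=False, epsilon=1e-5,
)
```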
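For the RMSNorm path, rmsnorm_bwd returns only (dx, dgamma) because there is no mean subtraction and no beta. The NumPy sketch below writes out those gradients from the saved rsigma. The name rmsnorm_bwd_reference is hypothetical, and the sketch assumes per-row rsigma and the (1 + gamma) convention for zero_centered_gamma; it illustrates the math rather than the TransformerEngine kernel.

```python
import numpy as np

def rmsnorm_bwd_reference(dz, x, rsigma, gamma, zero_centered_gamma=False):
    """Hypothetical RMSNorm backward: no mean term and no beta, so only
    (dx, dgamma) are returned. rsigma is (rows,), gamma is (hidden,)."""
    rsigma = rsigma[:, None]
    xhat = x * rsigma
    scale = 1.0 + gamma if zero_centered_gamma else gamma
    dxhat = dz * scale
    # Only the variance-style correction remains without mean subtraction.
    dx = rsigma * (dxhat - xhat * (dxhat * xhat).mean(axis=-1, keepdims=True))
    dgamma = (dz * xhat).sum(axis=0)
    return dx, dgamma

# Usage
rng = np.random.default_rng(2)
x, dz = rng.standard_normal((3, 8)), rng.standard_normal((3, 8))
gamma, eps = rng.standard_normal(8), 1e-6
rsigma = 1.0 / np.sqrt((x ** 2).mean(axis=-1) + eps)
dx, dgamma = rmsnorm_bwd_reference(dz, x, rsigma, gamma)
```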