Implementation:NVIDIA TransformerEngine JAX LayerNorm

From Leeroopedia


Field Value
Sources TransformerEngine
Domains Deep_Learning, JAX, Normalization
Last Updated 2026-02-07 14:00 GMT

Overview

Provides a standalone differentiable layer normalization function supporting both LayerNorm and RMSNorm with optional FP8 quantization output.

Description

canonicalize_norm_type maps a norm-type string to one of the two canonical values, "layernorm" or "rmsnorm". The public layernorm() delegates to _layernorm, which is wrapped in jax.custom_vjp so that the forward and backward passes are defined explicitly. The forward rule calls tex.normalization_fwd (backed by the C++ extension), passing the quantizer if one is given, and then dequantizes the result so that the caller receives a regular array. The backward rule calls tex.normalization_bwd to compute the gradients with respect to x, gamma, and beta. The quantizer is passed through as a differentiable argument so that its delayed-scaling state can be managed.

layernorm() is a basic normalization building block: layernorm_dense and layernorm_mlp use it for standalone normalization, and the Flax LayerNorm module calls it directly.
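The jax.custom_vjp structure described above can be illustrated with a minimal pure-JAX sketch. It is not the TransformerEngine implementation: the calls to tex.normalization_fwd and tex.normalization_bwd are replaced by plain jax.numpy math, the quantizer is left out, and the names layernorm_sketch, _fwd, and _bwd are hypothetical. It also assumes that normalization runs over the last axis.

```python
from functools import partial

import jax
import jax.numpy as jnp


# epsilon (argument 3) is a static hyperparameter, so it is non-differentiable.
@partial(jax.custom_vjp, nondiff_argnums=(3,))
def layernorm_sketch(x, gamma, beta, epsilon):
    # Primal function: run the forward rule and drop the residuals.
    out, _ = _fwd(x, gamma, beta, epsilon)
    return out


def _fwd(x, gamma, beta, epsilon):
    # Stand-in for the fused forward kernel (assumption: last-axis statistics).
    mu = jnp.mean(x, axis=-1, keepdims=True)
    rsigma = jax.lax.rsqrt(jnp.var(x, axis=-1, keepdims=True) + epsilon)
    x_hat = (x - mu) * rsigma
    # Save what the backward rule needs as residuals.
    return x_hat * gamma + beta, (x_hat, rsigma, gamma)


def _bwd(epsilon, residuals, dz):
    # Stand-in for the fused backward kernel: gradients for x, gamma, beta.
    x_hat, rsigma, gamma = residuals
    dx_hat = dz * gamma
    dx = rsigma * (
        dx_hat
        - jnp.mean(dx_hat, axis=-1, keepdims=True)
        - x_hat * jnp.mean(dx_hat * x_hat, axis=-1, keepdims=True)
    )
    batch_axes = tuple(range(dz.ndim - 1))
    dgamma = jnp.sum(dz * x_hat, axis=batch_axes)
    dbeta = jnp.sum(dz, axis=batch_axes)
    return dx, dgamma, dbeta


layernorm_sketch.defvjp(_fwd, _bwd)

# Example: differentiate a scalar loss through the custom rule.
x = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
gamma = jnp.linspace(0.5, 1.5, 8)
beta = jnp.linspace(-1.0, 1.0, 8)
loss = lambda x, g, b: jnp.sum(layernorm_sketch(x, g, b, 1e-6) ** 2)
dx, dgamma, dbeta = jax.grad(loss, argnums=(0, 1, 2))(x, gamma, beta)
```

Writing the backward rule by hand lets the forward pass save only the residuals the backward pass needs (here x_hat, rsigma, and gamma), which is the same reason a fused kernel pair would be wired in through custom_vjp.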

Usage

Use this function for standalone layer normalization in JAX. It supports both LayerNorm and RMSNorm and integrates with FP8 quantization for training with low-precision formats.

Code Reference

Source Location

Repository
NVIDIA/TransformerEngine
File
transformer_engine/jax/layernorm.py
Lines
1--141

Signature

def canonicalize_norm_type(norm_type: str) -> str:
    """Normalize norm type string to 'layernorm' or 'rmsnorm'."""
    ...

def layernorm(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    norm_type: str,
    zero_centered_gamma: bool = False,
    epsilon: float = 1e-6,
    quantizer: Optional[Quantizer] = None,
) -> jnp.ndarray: ...
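As an illustration of what canonicalizing a norm-type string can involve, here is a small standalone sketch. The function name canonicalize_norm_type_sketch and the specific rules (trimming whitespace, lowercasing, dropping hyphens and underscores) are assumptions made for this example; the library function may accept a different set of spellings.

```python
def canonicalize_norm_type_sketch(norm_type: str) -> str:
    """Hypothetical stand-in; not the library's implementation."""
    # Assumed rules: ignore case, surrounding whitespace, '-' and '_'.
    canonical = norm_type.strip().lower().replace("-", "").replace("_", "")
    if canonical not in ("layernorm", "rmsnorm"):
        raise ValueError(f"Unsupported norm_type: {norm_type!r}")
    return canonical


# Different spellings collapse to the same canonical value.
print(canonicalize_norm_type_sketch("LayerNorm"))
print(canonicalize_norm_type_sketch("rms_norm"))
```

Canonicalizing once at the boundary means the rest of the code only ever has to compare against two exact strings.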

Import

from transformer_engine.jax.layernorm import layernorm, canonicalize_norm_type

I/O Contract

Inputs

Name Type Required Description
x jnp.ndarray Yes Input tensor to normalize
gamma jnp.ndarray Yes Scale parameter
beta jnp.ndarray Yes Shift parameter (ignored for RMSNorm)
norm_type str Yes Normalization type: "layernorm" or "rmsnorm"
zero_centered_gamma bool No Whether gamma is zero-centered (default False)
epsilon float No Numerical stability constant (default 1e-6)
quantizer Quantizer No Optional FP8 quantizer for fused quantization
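The roles of norm_type, zero_centered_gamma, and epsilon can be made concrete with an unfused pure-JAX reference. This is a sketch of the standard formulas, not the library's kernel: reference_norm is a hypothetical name, normalization over the last axis is an assumption, and zero_centered_gamma is applied to both norm types here for simplicity (the source does not say whether the library allows it with RMSNorm).

```python
import jax
import jax.numpy as jnp


def reference_norm(x, gamma, beta, norm_type,
                   zero_centered_gamma=False, epsilon=1e-6):
    """Unfused reference math; statistics are taken over the last axis."""
    x32 = x.astype(jnp.float32)
    scale = gamma.astype(jnp.float32)
    if zero_centered_gamma:
        # gamma is stored around zero, so the effective scale is 1 + gamma.
        scale = scale + 1.0
    if norm_type == "layernorm":
        mean = jnp.mean(x32, axis=-1, keepdims=True)
        var = jnp.mean(jnp.square(x32 - mean), axis=-1, keepdims=True)
        out = (x32 - mean) * jax.lax.rsqrt(var + epsilon) * scale
        out = out + beta.astype(jnp.float32)
    elif norm_type == "rmsnorm":
        # RMSNorm skips mean subtraction and ignores beta.
        mean_sq = jnp.mean(jnp.square(x32), axis=-1, keepdims=True)
        out = x32 * jax.lax.rsqrt(mean_sq + epsilon) * scale
    else:
        raise ValueError(f"Unsupported norm_type: {norm_type!r}")
    return out.astype(x.dtype)
```

A reference like this is mainly useful for checking a fused implementation numerically, within a tolerance suited to the dtype.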

Outputs

Name Type Description
output jnp.ndarray Normalized tensor (dequantized if quantizer was used)
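Because the returned tensor is dequantized, a caller that supplies a quantizer still receives an ordinary array, but its values carry FP8 rounding error. The sketch below imitates that round trip in plain JAX. It does not use the library's Quantizer: fake_quantize_fp8 is a hypothetical helper, the E4M3 format is an assumption, and the scale is derived from the current tensor's maximum rather than from delayed-scaling history.

```python
import jax.numpy as jnp

E4M3_MAX = 448.0  # largest finite magnitude representable in float8_e4m3fn


def fake_quantize_fp8(x):
    """Quantize to FP8 (E4M3) with a per-tensor scale, then dequantize."""
    x32 = x.astype(jnp.float32)
    amax = jnp.max(jnp.abs(x32))
    # Map the largest magnitude onto the FP8 range; guard against amax == 0.
    scale = E4M3_MAX / jnp.maximum(amax, 1e-12)
    scaled = jnp.clip(x32 * scale, -E4M3_MAX, E4M3_MAX)
    q = scaled.astype(jnp.float8_e4m3fn)  # lossy cast to 8 bits
    return (q.astype(jnp.float32) / scale).astype(x.dtype)


x = jnp.linspace(-3.0, 3.0, 64)
y = fake_quantize_fp8(x)
max_abs_error = jnp.max(jnp.abs(y - x))
```

The output has the same shape and dtype as the input; only the values are snapped to what FP8 can represent at the chosen scale.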

Usage Examples

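import jax.numpy as jnp
from transformer_engine.jax.layernorm import layernorm

# Example inputs: a (batch, hidden) activation and per-feature parameters.
hidden = 1024
x = jnp.linspace(-1.0, 1.0, 8 * hidden).reshape(8, hidden)
gamma = jnp.ones((hidden,))
beta = jnp.zeros((hidden,))

# Apply LayerNorm
output = layernorm(x, gamma, beta, norm_type="layernorm", epsilon=1e-5)

# Apply RMSNorm (beta is ignored)
output = layernorm(x, gamma, beta, norm_type="rmsnorm", epsilon=1e-5)

# LayerNorm with FP8 quantization.
# fp8_quantizer must be a Quantizer instance created elsewhere; None
# (the default) is only a placeholder here and disables quantization.
fp8_quantizer = None
output = layernorm(x, gamma, beta, norm_type="layernorm",
                   quantizer=fp8_quantizer)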
