Implementation:NVIDIA TransformerEngine Normalization C API

From Leeroopedia
Revision as of 15:58, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/NVIDIA_TransformerEngine_Normalization_C_API.md)


Field         Value
Sources       TransformerEngine
Domains       Deep_Learning, Optimization
Last Updated  2026-02-07 14:00 GMT

Overview

Declares the C API for LayerNorm and RMSNorm forward and backward operations, with support for zero-centered gamma, optional cuDNN backend selection, and FP8 output.

Description

normalization.h provides the public normalization API:

  • nvte_layernorm_fwd: Standard LayerNorm forward: y = (x - E[x]) / sqrt(Var[x] + epsilon) * gamma + beta
  • nvte_layernorm_bwd: LayerNorm backward computing dx, dgamma, dbeta
  • nvte_rmsnorm_fwd: RMSNorm forward: y = x / sqrt(E[x^2] + epsilon) * gamma
  • nvte_rmsnorm_bwd: RMSNorm backward computing dx, dgamma
  • nvte_rmsnorm_bwd_add: Fused RMSNorm backward with gradient addition

Configuration functions:

  • nvte_enable_cudnn_norm_fwd/bwd: Toggle cuDNN backends for forward/backward
  • NVTE_Norm_Type enum: LayerNorm vs RMSNorm

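The forward and backward functions take a workspace tensor and follow the empty-tensor-query pattern for workspace sizing: a call with an empty workspace tensor reports the required workspace size instead of running the operation, and the caller then allocates the workspace and calls again. They also support FP8 output and the zero_centered_gamma option.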
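The empty-tensor-query pattern is easiest to see in isolation. The sketch below uses a hypothetical DemoWorkspace struct and a hypothetical demo_row_sumsq operation; neither is part of TransformerEngine, and the real API expresses the same idea through NVTETensor handles rather than a raw pointer and byte count.

```c
#include <assert.h>
#include <stddef.h>
#include <stdlib.h>

/* Hypothetical stand-in for a workspace tensor; NOT the real NVTETensor. */
typedef struct {
    void  *data;   /* NULL marks an empty workspace, i.e. a size query */
    size_t bytes;  /* filled in by the op on a query, checked on a real run */
} DemoWorkspace;

/* Hypothetical op: per-row sum of squares using n floats of scratch space.
 * Returns 0 if it only answered a size query, 1 if it did the work. */
static int demo_row_sumsq(const float *x, size_t n, size_t h,
                          float *out, DemoWorkspace *ws) {
    assert(ws != NULL);
    if (ws->data == NULL) {
        ws->bytes = n * sizeof(float);  /* report the requirement and return */
        return 0;
    }
    assert(ws->bytes >= n * sizeof(float));
    float *scratch = (float *)ws->data;
    for (size_t i = 0; i < n; ++i) {
        scratch[i] = 0.0f;
        for (size_t j = 0; j < h; ++j)
            scratch[i] += x[i * h + j] * x[i * h + j];
    }
    for (size_t i = 0; i < n; ++i) out[i] = scratch[i];
    return 1;
}

/* Caller side: query the size, allocate, run for real, release. */
static int demo_run(const float *x, size_t n, size_t h, float *out) {
    DemoWorkspace ws = {NULL, 0};
    demo_row_sumsq(x, n, h, out, &ws);           /* pass 1: size query */
    ws.data = malloc(ws.bytes);
    if (ws.data == NULL) return -1;
    int ran = demo_row_sumsq(x, n, h, out, &ws); /* pass 2: actual work */
    free(ws.data);
    return ran;
}
```

The benefit of this design is that the caller owns all allocations, so a framework can serve the workspace from its own memory pool instead of the library allocating behind its back.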

Usage

Use these functions for all LayerNorm and RMSNorm operations in Transformer layers. They are the primary entry points called by the framework bindings.

Code Reference

Source Location

Repository: NVIDIA/TransformerEngine
File: transformer_engine/common/include/transformer_engine/normalization.h
Lines: 1-197

Signature

enum NVTE_Norm_Type { LayerNorm = 0, RMSNorm = 1 };

void nvte_layernorm_fwd(const NVTETensor x, const NVTETensor gamma,
                        const NVTETensor beta, const float epsilon,
                        NVTETensor z, NVTETensor mu, NVTETensor rsigma,
                        NVTETensor workspace, const int multiprocessorCount,
                        const bool zero_centered_gamma, cudaStream_t stream);

void nvte_layernorm_bwd(const NVTETensor dz, const NVTETensor x,
                        const NVTETensor mu, const NVTETensor rsigma,
                        const NVTETensor gamma, NVTETensor dx,
                        NVTETensor dgamma, NVTETensor dbeta,
                        NVTETensor workspace, const int multiprocessorCount,
                        const bool zero_centered_gamma, cudaStream_t stream);

void nvte_rmsnorm_fwd(const NVTETensor x, const NVTETensor gamma,
                      const float epsilon, NVTETensor z, NVTETensor rsigma,
                      NVTETensor workspace, const int multiprocessorCount,
                      const bool zero_centered_gamma, cudaStream_t stream);

void nvte_rmsnorm_bwd(...);
void nvte_rmsnorm_bwd_add(...);
void nvte_enable_cudnn_norm_fwd(bool enable);
void nvte_enable_cudnn_norm_bwd(bool enable);

Import

#include "transformer_engine/normalization.h"

I/O Contract

Inputs

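Name                 Type          Required              Description
x                    NVTETensor    Yes                   Input tensor [N, H]
gamma                NVTETensor    Yes                   Gamma weight [H]
beta                 NVTETensor    Yes (LayerNorm only)  Beta weight [H]
epsilon              float         Yes                   Numerical stability constant added to the variance
workspace            NVTETensor    Yes                   Workspace tensor (pass empty to query the required size)
multiprocessorCount  int           Yes                   Number of SMs on the device
zero_centered_gamma  bool          Yes                   Use gamma+1 instead of gamma
stream               cudaStream_t  Yes                   CUDA stream for the operation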
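The zero_centered_gamma flag only changes which scale is applied to the normalized values. As a minimal sketch, assuming hypothetical helper names effective_gamma and apply_gamma that do not exist in the library: with the flag set, the stored weight is treated as an offset from 1, so a weight tensor of zeros leaves the normalized values unscaled.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical helper: the scale actually applied to element j.
 * With zero_centered_gamma, the stored weight is an offset from 1. */
static float effective_gamma(const float *gamma, size_t j,
                             int zero_centered_gamma) {
    assert(gamma != NULL);
    return zero_centered_gamma ? gamma[j] + 1.0f : gamma[j];
}

/* Hypothetical helper: scale an already-normalized row x_hat of length h. */
static void apply_gamma(const float *x_hat, const float *gamma, size_t h,
                        int zero_centered_gamma, float *y) {
    for (size_t j = 0; j < h; ++j)
        y[j] = x_hat[j] * effective_gamma(gamma, j, zero_centered_gamma);
}
```

In this sketch, zeros with the flag set and ones with the flag cleared produce identical outputs.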

Outputs

Name    Type        Description
z       NVTETensor  Normalized output [N, H]
mu      NVTETensor  Mean [N] (LayerNorm only)
rsigma  NVTETensor  Inverse standard deviation [N] (inverse RMS for RMSNorm)
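The rsigma output exists so that the backward pass does not have to recompute the per-row statistics. The sketch below is a textbook CPU derivation of the RMSNorm gradients from the saved rsigma, under the assumptions of float data, row-major layout, and no zero-centered gamma; the name ref_rmsnorm_bwd is hypothetical and the code is not the library's kernel. With r = rsigma, the input gradient is dx_j = r * dz_j * gamma_j - x_j * (r^3 / H) * sum_k(dz_k * gamma_k * x_k), and dgamma_j accumulates dz_j * x_j * r over rows.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical CPU reference for RMSNorm backward on row-major [n, h] data.
 * Reads the saved rsigma [n]; writes dx [n, h] and dgamma [h]. */
static void ref_rmsnorm_bwd(const float *dz, const float *x,
                            const float *rsigma, const float *gamma,
                            size_t n, size_t h, float *dx, float *dgamma) {
    assert(h > 0);
    for (size_t j = 0; j < h; ++j) dgamma[j] = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        const float r = rsigma[i];
        /* dot = sum_k dz_k * gamma_k * x_k for this row */
        float dot = 0.0f;
        for (size_t j = 0; j < h; ++j)
            dot += dz[i * h + j] * gamma[j] * x[i * h + j];
        const float c = dot * r * r * r / (float)h;
        for (size_t j = 0; j < h; ++j) {
            dx[i * h + j] = r * dz[i * h + j] * gamma[j] - x[i * h + j] * c;
            dgamma[j] += dz[i * h + j] * x[i * h + j] * r; /* sum over rows */
        }
    }
}
```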

Usage Examples

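#include "transformer_engine/normalization.h"

// Assumes x, gamma, beta, z, mu, rsigma, and workspace are NVTETensor handles
// created by the caller, sm_count is the value passed as multiprocessorCount,
// and stream is a valid cudaStream_t.

// LayerNorm forward: writes z, plus per-row mu and rsigma for the backward pass
nvte_layernorm_fwd(x, gamma, beta, 1e-5f, z, mu, rsigma,
                   workspace, sm_count, /*zero_centered_gamma=*/false, stream);

// RMSNorm forward (preferred for LLaMA-style models): takes no beta, writes no mu
nvte_rmsnorm_fwd(x, gamma, 1e-5f, z, rsigma,
                 workspace, sm_count, /*zero_centered_gamma=*/false, stream);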
