Implementation:Microsoft Onnxruntime CUDA LayerNormGrad
| Knowledge Sources | |
|---|---|
| Domains | Training, CUDA_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
Concrete tool for computing layer normalization gradients in the ONNX Runtime CUDA training framework.
Description
Implements three CUDA gradient kernels: LayerNormalizationGrad, InvertibleLayerNormalizationGrad, and SimplifiedLayerNormalizationGrad. The LayerNormGrad template class is parameterized by the input type T, the statistics type U (mean and inverse standard deviation), the scale/bias type V, and a boolean simplified flag. The constructor reads the axis attribute, which gives the first dimension to normalize over.
ComputeInternal takes the upstream gradient dY, the forward input X, the scale, the mean, and the inverse standard deviation, then calls HostApplyLayerNormGrad to compute the gradients for the input (d_X), scale (d_scale), and bias (d_bias). The other two variants differ as follows:
- The simplified variant (RMSNorm) subtracts no mean in the forward pass, so it does not consume a mean and does not compute a bias gradient.
- The invertible variant takes the forward output Y instead of X and recovers the normalized input from Y, scale, and bias (x_hat = (Y - bias) / scale), so X does not have to be kept from the forward pass.
Registered type combinations cover float, double, MLFloat16, and BFloat16 inputs; MLFloat16 and BFloat16 use float statistics.
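The per-row math behind LayerNormalizationGrad and SimplifiedLayerNormalizationGrad can be written as a small CPU reference. This is a sketch, not the ONNX Runtime code: it assumes double precision, a single normalized row, and hypothetical helper names (`row_stats`, `layer_norm_row`, `layer_norm_grad_row`). With g = dY * scale and x_hat the normalized input, dX = inv_std * (g - mean(g) - x_hat * mean(g * x_hat)); the RMSNorm form drops the mean(g) term because the forward pass subtracts no mean.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Forward statistics for one normalized row (hypothetical helper).
// For the simplified (RMSNorm) case no mean is subtracted, so mean is 0.
struct RowStats {
  double mean;
  double inv_std;
};

RowStats row_stats(const std::vector<double>& x, bool simplified, double eps) {
  const double n = static_cast<double>(x.size());
  double mean = 0.0;
  if (!simplified) {
    for (double v : x) mean += v;
    mean /= n;
  }
  double sq = 0.0;
  for (double v : x) sq += (v - mean) * (v - mean);
  return {mean, 1.0 / std::sqrt(sq / n + eps)};
}

// Forward pass for one row; bias is omitted because it does not affect dX.
std::vector<double> layer_norm_row(const std::vector<double>& x,
                                   const std::vector<double>& gamma,
                                   bool simplified, double eps) {
  const RowStats s = row_stats(x, simplified, eps);
  std::vector<double> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    y[i] = gamma[i] * (x[i] - s.mean) * s.inv_std;
  }
  return y;
}

// dL/dx for one row given dL/dy:
//   LayerNorm: dx = r * (g - mean(g) - xhat * mean(g * xhat))
//   RMSNorm:   dx = r * (g - xhat * mean(g * xhat))
// with g = dy * gamma, xhat = (x - mean) * r, r = inv_std.
std::vector<double> layer_norm_grad_row(const std::vector<double>& dy,
                                        const std::vector<double>& x,
                                        const std::vector<double>& gamma,
                                        RowStats s, bool simplified) {
  const std::size_t n = x.size();
  assert(dy.size() == n && gamma.size() == n);
  double mean_g = 0.0, mean_gx = 0.0;
  for (std::size_t i = 0; i < n; ++i) {
    const double g = dy[i] * gamma[i];
    const double xhat = (x[i] - s.mean) * s.inv_std;
    mean_g += g;
    mean_gx += g * xhat;
  }
  mean_g /= static_cast<double>(n);
  mean_gx /= static_cast<double>(n);
  std::vector<double> dx(n);
  for (std::size_t i = 0; i < n; ++i) {
    const double g = dy[i] * gamma[i];
    const double xhat = (x[i] - s.mean) * s.inv_std;
    const double centered = simplified ? g : g - mean_g;
    dx[i] = s.inv_std * (centered - xhat * mean_gx);
  }
  return dx;
}
```

A central finite difference of sum(dY * Y) through `layer_norm_row` is a cheap way to confirm the formula before comparing a GPU kernel's output against it.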
Usage
Invoked during the backward pass of training when the model uses LayerNorm, InvertibleLayerNorm, or SimplifiedLayerNorm (RMSNorm) layers.
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cuda/nn/layer_norm.cc
- Lines: 1-147
Signature
```cpp
template <typename T, typename U, typename V, bool simplified>
class LayerNormGrad : public CudaKernel {
 public:
  LayerNormGrad(const OpKernelInfo& op_kernel_info);
  Status ComputeInternal(OpKernelContext* p_op_kernel_context) const override;
};

template <typename T, typename U, typename V>
class InvertibleLayerNormGrad : public CudaKernel {
 public:
  InvertibleLayerNormGrad(const OpKernelInfo& op_kernel_info);
  Status ComputeInternal(OpKernelContext* p_op_kernel_context) const override;
};
```
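The axis attribute read by the constructor decides how the kernel views its input: dimensions before axis are flattened into n1 independent rows, and dimensions from axis onward form the n2 elements normalized together, following the ONNX LayerNormalization convention that a negative axis counts from the back. The mean and inverse standard deviation therefore hold one value per row. The helper below, `split_at_axis`, is a hypothetical illustration of that bookkeeping, not an ONNX Runtime function.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical helper: splits `shape` at `axis` into (n1, n2), where
// n1 = product of dims before axis (rows) and
// n2 = product of dims from axis to the end (normalized elements per row).
std::pair<int64_t, int64_t> split_at_axis(const std::vector<int64_t>& shape,
                                          int64_t axis) {
  const int64_t rank = static_cast<int64_t>(shape.size());
  if (axis < 0) axis += rank;  // negative axis counts from the back
  assert(axis >= 0 && axis < rank);
  int64_t n1 = 1, n2 = 1;
  for (int64_t i = 0; i < axis; ++i) n1 *= shape[i];
  for (int64_t i = axis; i < rank; ++i) n2 *= shape[i];
  return {n1, n2};
}
```

For an input of shape [2, 3, 4], axis = -1 normalizes each of 6 rows of 4 elements, while axis = 1 normalizes 2 rows of 12 elements.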
Import
```cpp
#include "orttraining/training_ops/cuda/nn/layer_norm.h"
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| Y_grad (dY) | Tensor(T) | Yes | Upstream gradient |
| X (or Y) | Tensor(T) | Yes | Original input (or output for invertible variant) |
| scale | Tensor(V) | Yes | Layer normalization scale parameter |
| mean | Tensor(U) | Yes | Mean from forward pass (not an input of the simplified variant) |
| inv_std_dev | Tensor(U) | Yes | Inverse standard deviation from forward pass |
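The invertible path works because y = scale * x_hat + bias can be inverted per element, after which dX needs only x_hat and the inverse standard deviation. The sketch below assumes bias is available alongside scale (inversion is impossible without it) and that every scale entry is nonzero; `recover_xhat` and `invertible_grad_row` are hypothetical names, and the code is a double-precision CPU illustration rather than the CUDA kernel.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Recovers x_hat from y = scale * x_hat + bias. Assumes every scale entry is
// nonzero; where scale is zero, that position of x_hat cannot be recovered.
std::vector<double> recover_xhat(const std::vector<double>& y,
                                 const std::vector<double>& scale,
                                 const std::vector<double>& bias) {
  assert(scale.size() == y.size() && bias.size() == y.size());
  std::vector<double> xhat(y.size());
  for (std::size_t i = 0; i < y.size(); ++i) {
    assert(scale[i] != 0.0);
    xhat[i] = (y[i] - bias[i]) / scale[i];
  }
  return xhat;
}

// Input gradient for one LayerNorm row computed from the forward output y
// instead of the forward input x; neither x nor the mean is needed.
std::vector<double> invertible_grad_row(const std::vector<double>& dy,
                                        const std::vector<double>& y,
                                        const std::vector<double>& scale,
                                        const std::vector<double>& bias,
                                        double inv_std) {
  const std::vector<double> xhat = recover_xhat(y, scale, bias);
  const std::size_t n = y.size();
  double mean_g = 0.0, mean_gx = 0.0;
  for (std::size_t i = 0; i < n; ++i) {
    const double g = dy[i] * scale[i];
    mean_g += g;
    mean_gx += g * xhat[i];
  }
  mean_g /= static_cast<double>(n);
  mean_gx /= static_cast<double>(n);
  std::vector<double> dx(n);
  for (std::size_t i = 0; i < n; ++i) {
    dx[i] = inv_std * (dy[i] * scale[i] - mean_g - xhat[i] * mean_gx);
  }
  return dx;
}
```

The trade is memory for a division per element: X can be freed after the forward pass, but a zero scale entry makes the corresponding x_hat unrecoverable.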
Outputs
| Name | Type | Description |
|---|---|---|
| X_grad (dX) | Tensor(T) | Gradient with respect to input |
| scale_grad | Tensor(V) | Gradient with respect to scale |
| bias_grad | Tensor(V) | Gradient with respect to bias (not computed for simplified variant) |
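Unlike dX, which is computed row by row, scale_grad and bias_grad are reductions across rows: d_scale[j] = sum over rows of dY[r][j] * x_hat[r][j], and d_bias[j] = sum over rows of dY[r][j]. The serial sketch below assumes a row-major [n1, n2] layout (n1 rows, n2 normalized elements per row) and uses a hypothetical `param_grads` helper; for the simplified variant it skips the mean subtraction and returns an empty bias gradient.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct ParamGrads {
  std::vector<double> scale_grad;  // length n2
  std::vector<double> bias_grad;   // length n2; empty for the simplified variant
};

// dy and x are row-major [n1, n2]; mean and inv_std hold one value per row.
// For the simplified (RMSNorm) variant, mean is ignored and may be empty.
ParamGrads param_grads(const std::vector<double>& dy, const std::vector<double>& x,
                       const std::vector<double>& mean,
                       const std::vector<double>& inv_std, std::size_t n1,
                       std::size_t n2, bool simplified) {
  assert(dy.size() == n1 * n2 && x.size() == n1 * n2 && inv_std.size() == n1);
  assert(simplified || mean.size() == n1);
  ParamGrads out;
  out.scale_grad.assign(n2, 0.0);
  if (!simplified) out.bias_grad.assign(n2, 0.0);
  for (std::size_t r = 0; r < n1; ++r) {
    const double mu = simplified ? 0.0 : mean[r];
    for (std::size_t j = 0; j < n2; ++j) {
      const std::size_t k = r * n2 + j;
      const double xhat = (x[k] - mu) * inv_std[r];
      out.scale_grad[j] += dy[k] * xhat;             // sum_r dY[r][j] * x_hat[r][j]
      if (!simplified) out.bias_grad[j] += dy[k];    // sum_r dY[r][j]
    }
  }
  return out;
}
```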
Usage Examples
```cpp
// Registration for MLFloat16 input, float statistics, MLFloat16 scale
REGISTER_GRADIENT_KERNEL_TYPED(MLFloat16, float, MLFloat16)
// This registers LayerNormalizationGrad, InvertibleLayerNormalizationGrad,
// and SimplifiedLayerNormalizationGrad for the given type combination
```