Principle: NVIDIA TransformerEngine Fused LayerNorm MLP
Overview
Fusing layer normalization and the entire MLP sub-layer (two linear transforms with an activation in between) into a single optimized module.
Description
The Transformer feed-forward network (FFN) consists of a LayerNorm, a first linear projection (up-projection), an activation function, and a second linear projection (down-projection). In a naive implementation, each of these operations is a separate kernel launch with intermediate tensors written to and read from global memory.
LayerNormMLP combines all of these operations into a fused module. This eliminates multiple memory round-trips and enables FP8 quantization of the intermediate MLP tensors, which are typically the largest activations in a Transformer layer.
Key benefits of this fusion:
- Eliminates intermediate memory traffic between LayerNorm, FC1, activation, and FC2.
- Enables FP8 quantization of the large intermediate activation tensor (of size ffn_hidden_size, typically 4x the model hidden size).
- Reduces kernel launch overhead by consolidating multiple operations.
- Supports gated activations (SwiGLU, GeGLU) natively, where the first projection output is split into gate and value paths.
Theoretical Basis
The mathematical formulation of the fused MLP operation is:
MLP(x) = W2 * activation(W1 * LayerNorm(x) + b1) + b2
Step by step:
- LayerNorm: norm = (x - mean) / sqrt(var + eps) * gamma + beta
- FC1 (up-projection): h = W1 * norm + b1
- Activation: a = activation(h)
- FC2 (down-projection): y = W2 * a + b2
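The four steps can be written out as a plain NumPy reference implementation (unfused, for clarity only; function and variable names here are illustrative, not TE's API):

```python
import numpy as np

def layernorm_mlp_reference(x, gamma, beta, W1, b1, W2, b2, eps=1e-5):
    """Unfused reference for MLP(x) = W2 * activation(W1 * LayerNorm(x) + b1) + b2."""
    # LayerNorm over the hidden dimension
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    norm = (x - mean) / np.sqrt(var + eps) * gamma + beta
    # FC1 (up-projection) to ffn_hidden_size
    h = norm @ W1 + b1
    # GELU activation (tanh approximation)
    a = 0.5 * h * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (h + 0.044715 * h**3)))
    # FC2 (down-projection) back to hidden_size
    return a @ W2 + b2

rng = np.random.default_rng(0)
hidden, ffn = 8, 32  # ffn_hidden_size = 4x hidden, as in the typical Transformer FFN
x = rng.standard_normal((2, hidden))
y = layernorm_mlp_reference(
    x,
    gamma=np.ones(hidden), beta=np.zeros(hidden),
    W1=rng.standard_normal((hidden, ffn)) * 0.1, b1=np.zeros(ffn),
    W2=rng.standard_normal((ffn, hidden)) * 0.1, b2=np.zeros(hidden),
)
print(y.shape)  # (2, 8): the output keeps the model hidden size
```

In the fused module, the intermediate tensors `norm`, `h`, and `a` never round-trip through global memory; this reference only makes the dataflow explicit.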
Supported activation functions:
| Activation | Formula | Gated? |
|---|---|---|
| gelu | GELU(x) | No |
| geglu | GELU(x1) * x2 (GeGLU) | Yes |
| silu | SiLU(x) (also known as Swish) | No |
| swiglu | SiLU(x1) * x2 (SwiGLU) | Yes |
| relu | ReLU(x) | No |
| srelu | Squared ReLU: ReLU(x)^2 | No |
| qgelu | Quick GELU approximation | No |
For gated activations (SwiGLU, GeGLU), the FC1 output dimension is doubled: W1 projects to 2 * ffn_hidden_size, and the output is split into gate and value paths before the element-wise product.
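The doubled FC1 width can be sketched in NumPy for the SwiGLU case (names are illustrative, not TE's internal implementation):

```python
import numpy as np

def swiglu_fc1(norm, W1, b1):
    """Gated FC1: W1 projects to 2 * ffn_hidden_size; split, then SiLU(gate) * value."""
    h = norm @ W1 + b1                     # shape (..., 2 * ffn_hidden_size)
    gate, value = np.split(h, 2, axis=-1)  # gate path and value path
    silu = gate / (1.0 + np.exp(-gate))    # SiLU(x) = x * sigmoid(x)
    return silu * value                    # element-wise product, (..., ffn_hidden_size)

rng = np.random.default_rng(1)
hidden, ffn = 8, 32
norm = rng.standard_normal((2, hidden))
W1 = rng.standard_normal((hidden, 2 * ffn)) * 0.1  # doubled output width
a = swiglu_fc1(norm, W1, np.zeros(2 * ffn))
print(a.shape)  # (2, 32): half the FC1 output width survives the gating
```

FC2 then consumes this ffn_hidden_size-wide tensor exactly as in the non-gated case.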
Usage
Use the fused LayerNorm + MLP operation when:
- The Transformer FFN layer follows a LayerNorm -- this is the standard architecture pattern in virtually all Transformer models.
- You need the largest performance gain from fusion, since the MLP sub-layer typically accounts for the majority of FLOPs in a Transformer layer.
- You want to use gated activations (SwiGLU, GeGLU) as used in LLaMA, PaLM, and other modern architectures.
- You are enabling FP8 training and want the large intermediate activations quantized automatically.
This fusion provides the largest single-module performance improvement in the TE optimization path.
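A minimal usage sketch, assuming TransformerEngine's PyTorch API (`transformer_engine.pytorch.LayerNormMLP`) and an NVIDIA GPU with FP8 support; argument names should be checked against the installed TE version:

```python
import torch
import transformer_engine.pytorch as te

# Fused LayerNorm + FC1 + activation + FC2 in one module.
mlp = te.LayerNormMLP(
    hidden_size=1024,
    ffn_hidden_size=4096,  # typically 4x hidden_size
    activation="swiglu",   # gated activation; FC1 width is doubled internally
).cuda()

x = torch.randn(128, 2, 1024, device="cuda")  # (sequence, batch, hidden)
with te.fp8_autocast(enabled=True):           # FP8 quantization of the intermediates
    y = mlp(x)
```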