Principle: NVIDIA TransformerEngine Fused LayerNorm Linear
Overview
Fusing layer normalization and linear transformation into a single GPU kernel for reduced memory traffic and latency.
Description
Instead of executing LayerNorm and Linear as separate operations (requiring an intermediate write to global memory), LayerNormLinear fuses both into one kernel launch. This eliminates the memory round-trip for the normalized tensor, improving both throughput and memory efficiency.
In a standard Transformer architecture, QKV projections are typically preceded by a LayerNorm. When these two operations are run separately, the normalized output must be written to GPU global memory and then read back for the subsequent linear projection. This round-trip is a bottleneck, particularly for memory-bandwidth-bound workloads.
By fusing LayerNorm and Linear into a single kernel:
- The intermediate normalized tensor stays in GPU registers or shared memory, never hitting global memory.
- One kernel launch replaces two, reducing kernel launch overhead.
- The fused operation enables better instruction-level parallelism on the GPU.
Theoretical Basis
The mathematical formulation of the fused operation is:
y = Linear(LayerNorm(x))
Separately executed:
- LayerNorm: norm = (x - mean) / sqrt(var + eps) * gamma + beta
- Linear: y = norm * W^T + bias
Fused execution:
Both computations happen in a single kernel pass over x, writing only the final output y to global memory. The intermediate norm tensor is never materialized in global memory.
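The equivalence of the two execution strategies can be checked numerically. The sketch below (a pure-numpy illustration; variable names and shapes are chosen for the example, not taken from TE) computes the two-step version with an explicitly materialized `norm` tensor, then the same math as a single expression in which the intermediate never gets a name, mirroring how the fused kernel keeps it in registers:

```python
import numpy as np

rng = np.random.default_rng(0)
tokens, hidden, out = 8, 64, 128
x = rng.standard_normal((tokens, hidden))
gamma, beta = np.ones(hidden), np.zeros(hidden)
W = rng.standard_normal((out, hidden)) / np.sqrt(hidden)
bias = np.zeros(out)
eps = 1e-5

# Separate execution: the intermediate `norm` tensor is materialized.
mean = x.mean(axis=-1, keepdims=True)
var = x.var(axis=-1, keepdims=True)
norm = (x - mean) / np.sqrt(var + eps) * gamma + beta
y_separate = norm @ W.T + bias

# "Fused" execution: same math in one pass over x, with no named
# intermediate (the real kernel keeps it in registers/shared memory).
y_fused = ((x - x.mean(axis=-1, keepdims=True))
           / np.sqrt(x.var(axis=-1, keepdims=True) + eps)
           * gamma + beta) @ W.T + bias

assert np.allclose(y_separate, y_fused)
```

Fusion changes where the intermediate lives, not what is computed, so the outputs match to floating-point tolerance.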
| Aspect | Separate Execution | Fused Execution |
|---|---|---|
| Kernel launches | 2 (LayerNorm + Linear) | 1 (LayerNormLinear) |
| Global memory writes | 2 (norm output + linear output) | 1 (linear output only) |
| Global memory reads | 3 (LN input + norm for Linear + weights) | 2 (input + weights) |
| Intermediate tensor | Materialized in global memory | Stays in registers/shared memory |
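The table's traffic counts can be turned into a rough byte estimate. The numbers below are illustrative (4096 tokens, hidden size 4096, fp16 activations, a QKV-sized output of 3x hidden); weight reads appear in both cases and cancel out of the comparison:

```python
# Rough activation-traffic estimate for a fused QKV projection.
tokens, hidden = 4096, 4096        # e.g. batch * sequence = 4096 tokens
bytes_per_el = 2                   # fp16
act = tokens * hidden * bytes_per_el          # one activation tensor

# Separate: read x, write norm, read norm back, write y (y is 3x hidden).
separate = act + act + act + 3 * act
# Fused: read x, write y.
fused = act + 3 * act

saved = separate - fused           # the norm write + read-back that fusion removes
print(f"traffic removed by fusion: {saved / 2**20:.0f} MiB per forward")
# -> traffic removed by fusion: 64 MiB per forward
```

The saving is exactly one write plus one read of the intermediate tensor (2x its size), which is why the benefit grows with sequence length and hidden size on bandwidth-bound workloads.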
Usage
Use the fused LayerNorm + Linear operation when:
- A Linear layer immediately follows a LayerNorm -- the most common case is QKV projections in Transformer self-attention.
- You are optimizing for memory bandwidth on GPU workloads.
- You want to leverage FP8 quantization of the linear output (available through the TE implementation).
- You are following TE's progressive optimization path, replacing separate torch.nn.LayerNorm + torch.nn.Linear with a single fused module.
This fusion is part of the first optimization step in TransformerEngine's getting-started tutorial, where users replace PyTorch primitives with TE fused modules before enabling FP8 training.