Principle:NVIDIA TransformerEngine Baseline Transformer Layer
Overview
Establishing a pure PyTorch Transformer implementation as a performance baseline before TransformerEngine optimization.
Description
A standard Transformer layer built entirely from PyTorch primitives (torch.nn.Linear, torch.nn.LayerNorm, etc.) serves as the reference point for measuring TransformerEngine's performance improvements. This baseline uses no FP8 quantization, no fused kernels, and no custom CUDA code.
The baseline implementation is important because it:
- Establishes a performance floor that TE optimizations must beat to justify their complexity.
- Provides a correctness reference -- TE modules should produce numerically equivalent (or near-equivalent) outputs to the pure PyTorch baseline.
- Demonstrates the progressive optimization path from simple, readable PyTorch code to fully optimized TE modules, allowing developers to understand each optimization step in isolation.
- Serves as a starting point for TE adoption -- the getting-started tutorial begins with this baseline and progressively replaces components with TE equivalents.
In the baseline Transformer layer, each operation (LayerNorm, QKV projection, attention, output projection, MLP FC1, GELU, MLP FC2) runs as a separate PyTorch module or function call. Each call launches at least one CUDA kernel of its own and writes its intermediate result to global memory before the next operation reads it back.
Theoretical Basis
The baseline follows the standard pre-norm Transformer architecture:
Self-Attention Sub-Layer:
norm1 = LayerNorm(x)
Q, K, V = split(W_qkv * norm1, 3)
attn = softmax(Q * K^T / sqrt(d_k)) * V
proj = W_out * attn
h = x + Dropout(proj)
Feed-Forward Sub-Layer:
norm2 = LayerNorm(h)
fc1 = GELU(W1 * norm2 + b1)
fc2 = W2 * fc1 + b2
output = h + fc2
Each of these steps maps to a separate PyTorch module or operation:
| Step | PyTorch Module | Kernel Launches |
|---|---|---|
| LayerNorm1 | torch.nn.LayerNorm | 1 |
| QKV Projection | torch.nn.Linear | 1 |
| Attention | Manual dot-product + softmax | 2-3 |
| Output Projection | torch.nn.Linear | 1 |
| Dropout | torch.nn.Dropout | 1 |
| Residual Add | Element-wise add | 1 |
| LayerNorm2 | torch.nn.LayerNorm | 1 |
| MLP FC1 | torch.nn.Linear | 1 |
| GELU | torch.nn.functional.gelu | 1 |
| MLP FC2 | torch.nn.Linear | 1 |
| Residual Add | Element-wise add | 1 |
| Total | | ~12+ |
By contrast, a fully fused TE layer can reduce this to as few as 4-5 kernel launches with FP8 support.
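The steps in the table can be sketched as a single PyTorch module. This is a minimal illustration, not the tutorial's actual code: the class name `BaselineTransformerLayer`, the `[seq, batch, hidden]` input layout, and the default dropout rate are assumptions made for the example.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class BaselineTransformerLayer(nn.Module):
    """Pre-norm Transformer layer built only from stock PyTorch pieces.

    Hypothetical reference class for illustration; names are assumptions.
    """

    def __init__(self, hidden_size: int, ffn_hidden_size: int,
                 num_heads: int, dropout: float = 0.1):
        super().__init__()
        if hidden_size % num_heads != 0:
            raise ValueError("hidden_size must be divisible by num_heads")
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.ln1 = nn.LayerNorm(hidden_size)
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size)
        self.proj = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.ln2 = nn.LayerNorm(hidden_size)
        self.fc1 = nn.Linear(hidden_size, ffn_hidden_size)
        self.fc2 = nn.Linear(ffn_hidden_size, hidden_size)

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        # [seq, batch, hidden] -> [batch, heads, seq, head_dim]
        s, b, _ = t.shape
        return t.reshape(s, b, self.num_heads, self.head_dim).permute(1, 2, 0, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        s, b, h = x.shape  # assumed layout: [seq, batch, hidden]

        # Self-attention sub-layer
        norm1 = self.ln1(x)
        q, k, v = self.qkv(norm1).chunk(3, dim=-1)
        q, k, v = self._split_heads(q), self._split_heads(k), self._split_heads(v)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn = torch.matmul(F.softmax(scores, dim=-1), v)
        attn = attn.permute(2, 0, 1, 3).reshape(s, b, h)  # merge heads
        hid = x + self.dropout(self.proj(attn))

        # Feed-forward sub-layer
        norm2 = self.ln2(hid)
        return hid + self.fc2(F.gelu(self.fc1(norm2)))
```

Every line of `forward` is a separate module or function call, which is what makes the layer easy to read and also what leaves room for fusion.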
Usage
Use the baseline Transformer implementation when:
- Starting the TE adoption journey -- understand what your model looks like in pure PyTorch before optimizing.
- Benchmarking TE improvements -- compare latency and throughput against the baseline to quantify gains.
- Verifying numerical correctness -- ensure TE modules produce equivalent outputs to the PyTorch reference.
- Teaching or documentation -- the baseline serves as a clear, readable reference implementation of the Transformer architecture.
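The benchmarking and correctness checks listed above can be sketched with two small helpers. This is an illustrative sketch under stated assumptions: `benchmark_ms` and `max_abs_diff` are hypothetical helper names, and the `candidate` here is only a deep copy standing in for a TE module with the same weights.

```python
import copy
import time

import torch
import torch.nn as nn


def benchmark_ms(module: nn.Module, x: torch.Tensor,
                 warmup: int = 3, iters: int = 10) -> float:
    """Mean wall-clock milliseconds per forward pass (hypothetical helper)."""
    with torch.no_grad():
        for _ in range(warmup):
            module(x)
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # wait for queued GPU work before timing
        start = time.perf_counter()
        for _ in range(iters):
            module(x)
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # make sure all timed kernels finished
    return (time.perf_counter() - start) * 1000.0 / iters


def max_abs_diff(reference: nn.Module, candidate: nn.Module,
                 x: torch.Tensor) -> float:
    """Largest element-wise output difference between two modules."""
    with torch.no_grad():
        return (reference(x) - candidate(x)).abs().max().item()


torch.manual_seed(0)
# Stand-in reference: LayerNorm followed by Linear, in eval mode.
reference = nn.Sequential(nn.LayerNorm(64), nn.Linear(64, 256)).eval()
# Assumption: in real use this would be a TE module loaded with the same weights.
candidate = copy.deepcopy(reference)
x = torch.randn(8, 2, 64)

print("baseline ms/iter:", benchmark_ms(reference, x))
print("max abs diff:", max_abs_diff(reference, candidate, x))
```

When the candidate runs in reduced precision, an exact match is not expected, so the difference would typically be compared against a tolerance chosen for that precision.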
The progressive optimization path is:
- Step 0: Pure PyTorch baseline (this principle).
- Step 1: Replace PyTorch modules with TE fused modules (te.LayerNormLinear, te.LayerNormMLP, etc.).
- Step 2: Enable FP8 training with te.fp8_autocast().
- Step 3: Use the complete te.TransformerLayer for maximum optimization.
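The final steps of this path might look roughly like the sketch below, which covers only Steps 2 and 3. It assumes TransformerEngine is installed and a CUDA GPU is present, and it returns `None` otherwise. The function name `run_te_layer`, the tensor sizes, the bfloat16 dtype, and the recipe settings are illustrative assumptions; FP8 is off by default because it needs hardware support.

```python
import torch

try:
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import DelayedScaling, Format
    HAVE_TE = True
except (ImportError, OSError):
    HAVE_TE = False  # TransformerEngine not installed or not loadable


def run_te_layer(hidden_size: int = 1024, ffn_hidden_size: int = 4096,
                 num_heads: int = 16, seq_len: int = 128, batch: int = 4,
                 use_fp8: bool = False):
    """Run one forward pass through te.TransformerLayer.

    Returns the output shape as a tuple, or None when TE or CUDA is missing.
    Hypothetical helper; all sizes are illustrative.
    """
    if not (HAVE_TE and torch.cuda.is_available()):
        return None

    # Step 3: the complete fused layer replaces the hand-built baseline.
    layer = te.TransformerLayer(hidden_size, ffn_hidden_size, num_heads)
    layer.to(dtype=torch.bfloat16).cuda()

    # Default TE layout is [seq, batch, hidden].
    x = torch.randn(seq_len, batch, hidden_size,
                    dtype=torch.bfloat16, device="cuda")

    # Step 2: FP8 is controlled by the autocast context and a scaling recipe.
    recipe = DelayedScaling(fp8_format=Format.HYBRID,
                            amax_history_len=16, amax_compute_algo="max")
    with te.fp8_autocast(enabled=use_fp8, fp8_recipe=recipe):
        y = layer(x, attention_mask=None)
    return tuple(y.shape)


print(run_te_layer())
```

Passing `use_fp8=True` is only expected to work on GPUs with FP8 support; on other hardware, leaving it disabled still exercises the fused layer in bfloat16.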