Principle:NVIDIA TransformerEngine FP8 Quantization
| Field | Value |
|---|---|
| Page Type | Principle |
| Repository | NVIDIA TransformerEngine |
| Domains | Deep_Learning, Quantization |
| Sources | TransformerEngine, FP8 Formats for Deep Learning |
| Implemented By | Implementation:NVIDIA_TransformerEngine_TE_Autocast |
Overview
Enabling low-precision FP8 training by wrapping forward passes in a quantization context manager.
Description
FP8 quantization reduces the precision of activations, weights, and gradients from FP16/BF16 to 8-bit floating point formats (E4M3 for forward, E5M2 for backward in HYBRID mode). This leverages Tensor Core acceleration on Hopper+ GPUs for approximately 2x throughput improvement.
In TransformerEngine, FP8 quantization is activated by wrapping the forward pass of TE modules inside a context manager (te.autocast). The context manager configures global FP8 state, enabling all TE layers within its scope to execute their GEMM operations in FP8 precision. On exit, it triggers amax reduction across distributed groups and updates scaling factors for the next iteration.
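A minimal sketch of this pattern, using the `te.autocast` entry point named above (older TransformerEngine releases expose the same context manager as `te.fp8_autocast`); module sizes and the recipe choice are illustrative, and an FP8-capable GPU plus the `transformer_engine` package are required:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Recipe configuring the scaling strategy (HYBRID: E4M3 forward, E5M2 backward).
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID)

layer = te.Linear(768, 768)  # TE module; a plain torch.nn.Linear would not run in FP8
inp = torch.randn(32, 768, device="cuda")

# Only the forward pass is wrapped. On exit, the context manager triggers the
# amax reduction across distributed groups and updates scaling factors.
with te.autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = layer(inp)

out.sum().backward()  # backward pass runs outside the context
```

Note that the backward call sits outside the context manager: the FP8 state captured during the forward pass determines how gradients are quantized.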
The key insight is that FP8 training does not require the entire computation graph to be in FP8. Instead, only the most compute-intensive operations (matrix multiplications) are executed in FP8, while accumulations and other operations remain in higher precision. This mixed-precision approach preserves model quality while capturing the throughput benefits of reduced-precision arithmetic.
Usage
Use when training on NVIDIA Hopper (H100) or newer GPUs to accelerate training throughput. FP8 quantization requires:
- Hardware: NVIDIA Hopper (H100) or Blackwell (B200) architecture GPUs with FP8 Tensor Core support.
- Recipe: A `Recipe` object (e.g., `DelayedScaling`, `Float8CurrentScaling`) that configures the scaling strategy.
- TE Modules: Model layers must use TransformerEngine modules (e.g., `te.TransformerLayer`, `te.Linear`) rather than standard PyTorch modules.
FP8 quantization is appropriate for:
- Large-scale transformer training where throughput is critical.
- Fine-tuning workloads where reduced precision does not degrade convergence.
- Inference serving where latency and throughput improvements are desired.
Theoretical Basis
FP8 formats trade precision for throughput by reducing the number of bits used to represent floating point values from 16 to 8. Two complementary FP8 formats are defined:
E4M3 Format
- Structure: 1 sign bit, 4 exponent bits, 3 mantissa bits.
- Range: +/-448.
- Precision: Higher mantissa precision (3 bits) makes it suitable for forward pass activations and weights, where values are typically well-bounded.
E5M2 Format
- Structure: 1 sign bit, 5 exponent bits, 2 mantissa bits.
- Range: +/-57344.
- Precision: Wider dynamic range (5 exponent bits) at the cost of mantissa precision, making it suitable for backward pass gradients, which exhibit larger variance.
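The two range limits quoted above follow directly from the bit layouts. A short plain-Python check, assuming the standard FP8 conventions (E4M3 reserves only the all-ones exponent with all-ones mantissa for NaN; E5M2 reserves the entire all-ones exponent for infinities and NaNs):

```python
def fp8_max(exp_bits, man_bits, bias, reserve_exponent):
    # Largest usable exponent field: the all-ones exponent is fully reserved
    # in E5M2, while E4M3 keeps it but drops the all-ones mantissa (NaN).
    max_exp_field = (2 ** exp_bits - 1) - (1 if reserve_exponent else 0)
    max_mantissa = (2 ** man_bits - 1) - (0 if reserve_exponent else 1)
    # Value of the largest finite normal number: (1 + m/2^M) * 2^(E - bias)
    return (1 + max_mantissa / 2 ** man_bits) * 2 ** (max_exp_field - bias)

e4m3_max = fp8_max(exp_bits=4, man_bits=3, bias=7, reserve_exponent=False)
e5m2_max = fp8_max(exp_bits=5, man_bits=2, bias=15, reserve_exponent=True)
print(e4m3_max, e5m2_max)  # 448.0 57344.0
```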
Scaling Factors
Because FP8 has a significantly reduced dynamic range compared to FP16/BF16, scaling factors are essential to map the tensor value distribution into the representable FP8 range. The scaling factor is computed as:
scale = FP8_MAX / (amax * 2^margin)
where:
- `FP8_MAX` is the maximum representable value in the FP8 format (448 for E4M3, 57344 for E5M2).
- `amax` is the absolute maximum value observed in the tensor (computed from history or from the current iteration).
- `margin` is an optional safety margin, in powers of two, to prevent overflow.
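A worked check of the formula in plain Python (no TransformerEngine dependency; the amax value is illustrative):

```python
FP8_MAX_E4M3 = 448.0

def compute_scale(amax, margin=0):
    # scale = FP8_MAX / (amax * 2^margin); margin > 0 leaves headroom below FP8_MAX
    return FP8_MAX_E4M3 / (amax * 2 ** margin)

amax = 14.0                      # absolute maximum observed in the tensor
scale = compute_scale(amax)      # 448 / 14 = 32.0
# The largest value in the tensor maps exactly onto the FP8 maximum.
assert amax * scale == FP8_MAX_E4M3

# With a margin of 1, values map at most to FP8_MAX / 2, guarding against overflow.
assert amax * compute_scale(amax, margin=1) == FP8_MAX_E4M3 / 2
```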
The choice of how amax is computed defines the scaling strategy:
- Delayed Scaling: Uses a history of amax values from previous iterations.
- Current Scaling: Computes amax from the current tensor in the current iteration.
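The difference between the two strategies can be sketched in a few lines of plain Python; the `max`-over-history reduction mirrors a common amax aggregation choice, and all numbers are illustrative:

```python
amax_history = [10.0, 96.0, 44.0]   # amax values recorded over recent iterations
current_tensor_amax = 30.0          # amax of the tensor in this iteration

# Delayed scaling: scale derives from past observations, so quantization needs
# no extra pass over the tensor, but it can lag behind sudden value spikes.
delayed_amax = max(amax_history)    # 96.0

# Current scaling: scale tracks this iteration exactly, at the cost of an
# extra reduction over the tensor before quantizing.
current_amax = current_tensor_amax  # 30.0

FP8_MAX = 448.0
print(FP8_MAX / delayed_amax, FP8_MAX / current_amax)
```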
HYBRID Mode
The HYBRID format mode (the default in TransformerEngine) uses E4M3 for the forward pass and E5M2 for the backward pass. This combination balances precision and dynamic range across both passes of training, and is the recommended configuration for most workloads.
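Selecting the format explicitly looks like the following sketch (a configuration fragment using the `DelayedScaling` recipe; the extra parameters shown are assumptions about typical knobs, not required settings):

```python
from transformer_engine.common.recipe import DelayedScaling, Format

# E4M3 for forward-pass activations/weights, E5M2 for backward-pass gradients.
recipe = DelayedScaling(
    fp8_format=Format.HYBRID,   # the default, shown here for explicitness
    amax_history_len=1024,      # iterations of amax history to retain
    amax_compute_algo="max",    # reduce the history by taking its maximum
)
```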
Related Pages
- Implementation:NVIDIA_TransformerEngine_TE_Autocast -- The context manager that enables FP8 execution.
- Principle:NVIDIA_TransformerEngine_FP8_Delayed_Scaling -- Delayed scaling strategy for computing FP8 scaling factors.
- Principle:NVIDIA_TransformerEngine_FP8_Current_Scaling -- Current scaling strategy for computing FP8 scaling factors.