Implementation:NVIDIA TransformerEngine Debug Fake Quant
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Debug, Quantization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Debug feature that performs fake quantization by casting tensors to FP8 and back, running the GEMM in high precision to emulate quantization effects without actual low-precision computation.
Description
FakeQuant disables real FP8 GEMM and instead applies a quantize-then-dequantize operation (fake quantization) to selected tensors. This emulates the numerical effects of FP8 casting while keeping all GEMM operations in high precision (BF16/FP32). It supports per-tensor current scaling (not delayed scaling) and four FP8 formats: FP8E4M3, FP8E5M2, MXFP8E4M3, MXFP8E5M2. The fake_quantize() helper function computes the amax-based scale factor and performs the round-trip cast.
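The round trip described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the code in fake_quant.py: it assumes the per-tensor scale is `fp8_max / amax`, that rounding is round-to-nearest-even with saturation, and it works on Python lists instead of GPU tensors. It covers only the two non-MX formats. The names `FP8_SPECS`, `round_to_fp8_grid`, and `fake_quantize_sketch` are hypothetical and exist only for this example.

```python
import math

# Assumed format parameters: mantissa bits, smallest normal exponent, largest finite value.
FP8_SPECS = {
    "FP8E4M3": {"mantissa_bits": 3, "min_exp": -6, "max_value": 448.0},
    "FP8E5M2": {"mantissa_bits": 2, "min_exp": -14, "max_value": 57344.0},
}


def round_to_fp8_grid(x, spec):
    """Round x to the nearest value representable in the given FP8 format (saturating)."""
    if x == 0.0:
        return 0.0
    max_value = spec["max_value"]
    mag = min(abs(x), max_value)  # saturate instead of overflowing
    # frexp gives mag = m * 2**e with 0.5 <= m < 1, so floor(log2(mag)) == e - 1.
    _, e = math.frexp(mag)
    # Subnormals share the spacing of the smallest normal binade.
    exp = max(e - 1, spec["min_exp"])
    step = 2.0 ** (exp - spec["mantissa_bits"])
    rounded = round(mag / step) * step  # Python's round() is half-to-even
    return math.copysign(min(rounded, max_value), x)


def fake_quantize_sketch(values, fmt="FP8E4M3"):
    """Quantize-then-dequantize a list of floats with per-tensor current scaling."""
    spec = FP8_SPECS[fmt]
    amax = max((abs(v) for v in values), default=0.0)
    if amax == 0.0:
        return list(values)  # nothing to scale
    scale = spec["max_value"] / amax  # assumed scale rule: map amax onto the FP8 maximum
    return [round_to_fp8_grid(v * scale, spec) / scale for v in values]
```

With the input `[1.0, 0.5, 0.3, -0.1]` and `FP8E4M3`, the values 1.0 and 0.5 survive unchanged, while 0.3 lands on 128/448 (about 0.2857) and -0.1 on -44/448 (about -0.0982). The output stays in high precision but carries the rounding error that an FP8 cast would introduce.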
Usage
Enable via YAML config specifying which GEMMs and tensors to fake-quantize, along with the desired quant_format. Useful for analyzing the impact of quantization noise on training without the complexity of actual FP8 GEMM execution.
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/debug/features/fake_quant.py
- Lines: 1-180
Signature
def fake_quantize(tensor: torch.Tensor, fp8_format: tex.DType, out=None) -> torch.Tensor: ...

@Registry.register_feature(namespace="transformer_engine")
class FakeQuant(TEConfigAPIMapper):
    def fp8_gemm_enabled(self, config, layer_name, gemm, iteration) -> Tuple[bool, None]: ...
    def modify_tensor_enabled(self, config, layer_name, tensor_name, gemm, iteration) -> Tuple[bool, int]: ...
    def modify_tensor(self, config, layer_name, gemm, tensor_name, tensor, iteration, default_quantizer, out=None, dtype=None) -> Optional[torch.Tensor]: ...
Import
from transformer_engine.debug.features.fake_quant import FakeQuant, fake_quantize
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| config | Dict | Yes | Must contain quant_format and GEMM/tensor specification |
| tensor | torch.Tensor | Yes | High-precision GPU tensor (float32, float16, or bfloat16) |
| quant_format | str | Yes | One of FP8E4M3, FP8E5M2, MXFP8E4M3, MXFP8E5M2 |
Outputs
| Name | Type | Description |
|---|---|---|
| result | torch.Tensor | Fake-quantized tensor (same dtype as input, with quantization noise) |
Usage Examples
# YAML configuration:
# example_fake_quant_fp8:
#   enabled: True
#   layers:
#     layer_types: [transformer_layer.layernorm_mlp.fc1]
#   transformer_engine:
#     FakeQuant:
#       enabled: True
#       quant_format: FP8E5M2
#       gemms_struct:
#         - gemm: fprop
#           tensors: [activation, weight]
#         - gemm: dgrad
#           tensors: [gradient]
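
To make the selection semantics of `gemms_struct` concrete, the sketch below mirrors the FakeQuant portion of the configuration above as a Python dict and lists which (GEMM, tensor) pairs it selects. This is a hypothetical illustration, not TransformerEngine's config parser: `CONFIG`, `selected_pairs`, and `should_fake_quantize` are invented names, and the real feature resolves the config through its own API.

```python
# Hypothetical mirror of the FakeQuant section of the YAML above.
CONFIG = {
    "enabled": True,
    "quant_format": "FP8E5M2",
    "gemms_struct": [
        {"gemm": "fprop", "tensors": ["activation", "weight"]},
        {"gemm": "dgrad", "tensors": ["gradient"]},
    ],
}


def selected_pairs(feature_config):
    """Return the set of (gemm, tensor_name) pairs the config asks to fake-quantize."""
    if not feature_config.get("enabled", False):
        return set()
    pairs = set()
    for entry in feature_config.get("gemms_struct", []):
        for tensor_name in entry["tensors"]:
            pairs.add((entry["gemm"], tensor_name))
    return pairs


def should_fake_quantize(feature_config, gemm, tensor_name):
    """True if this tensor of this GEMM is selected by the config."""
    return (gemm, tensor_name) in selected_pairs(feature_config)
```

Under this reading, the example config selects three pairs: the activation and weight of `fprop`, and the gradient of `dgrad`. Tensors of any GEMM not listed, such as `wgrad`, are left untouched.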