Implementation:NVIDIA TransformerEngine Debug Fake Quant

From Leeroopedia


Field | Value
Sources | TransformerEngine
Domains | Deep_Learning, PyTorch, Debug, Quantization
Last Updated | 2026-02-07 14:00 GMT

Overview

FakeQuant is a debug feature that performs fake quantization: it casts selected tensors to FP8 and back, then runs the GEMM in high precision. This emulates quantization effects without any actual low-precision computation.

Description

FakeQuant disables the real FP8 GEMM and instead applies a quantize-then-dequantize operation (fake quantization) to the selected tensors. This emulates the numerical effects of FP8 casting while keeping all GEMM operations in high precision (BF16/FP32). It supports per-tensor current scaling only (delayed scaling is not supported) and four FP8 formats: FP8E4M3, FP8E5M2, MXFP8E4M3, and MXFP8E5M2. The fake_quantize() helper function computes the amax-based scale factor and performs the round-trip cast.
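The round trip described above can be illustrated with a minimal, standard-library-only sketch. It is not the TransformerEngine implementation: it works on a plain list of floats rather than a torch.Tensor, the names FP8_FORMATS, round_to_fp8_grid, and fake_quantize_list are hypothetical, and it assumes the scale is computed as fp8_max / amax. The MXFP8 formats, which use block-wise scaling, are not covered.

```python
import math

# Assumed format parameters: (mantissa bits, minimum normal exponent, max finite value).
FP8_FORMATS = {
    "FP8E4M3": (3, -6, 448.0),
    "FP8E5M2": (2, -14, 57344.0),
}


def round_to_fp8_grid(value, mantissa_bits, min_exp, max_value):
    """Round one float to the nearest value on the given FP8 grid (hypothetical helper)."""
    if value == 0.0:
        return 0.0
    # frexp gives value = m * 2**e with 0.5 <= |m| < 1, so floor(log2(|value|)) = e - 1.
    # Clamping to min_exp makes small values fall on the subnormal grid.
    exponent = max(math.frexp(value)[1] - 1, min_exp)
    step = 2.0 ** (exponent - mantissa_bits)
    rounded = round(value / step) * step  # Python's round() rounds halves to even
    # Saturate at the largest finite value of the format.
    return max(-max_value, min(max_value, rounded))


def fake_quantize_list(values, quant_format="FP8E4M3"):
    """Quantize-then-dequantize with one scale for the whole input (per-tensor current scaling)."""
    mantissa_bits, min_exp, max_value = FP8_FORMATS[quant_format]
    amax = max((abs(v) for v in values), default=0.0)
    # Assumption: the scale maps amax onto the largest representable FP8 value.
    scale = max_value / amax if amax > 0.0 else 1.0
    return [
        round_to_fp8_grid(v * scale, mantissa_bits, min_exp, max_value) / scale
        for v in values
    ]


exact = fake_quantize_list([1.0, 0.5, -0.25])  # all values land on the FP8 grid
noisy = fake_quantize_list([1.0, 0.3])         # 0.3 picks up rounding error
```

In this sketch the output keeps the input's number type and range; only the rounding error introduced by the FP8 grid remains, which is the "quantization noise" the feature is meant to expose.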

Usage

Enable the feature through a YAML config that specifies which GEMMs and tensors to fake-quantize, along with the desired quant_format. It is useful for analyzing the impact of quantization noise on training without the complexity of actual FP8 GEMM execution.

Code Reference

Source Location

Repository
NVIDIA/TransformerEngine
File
transformer_engine/debug/features/fake_quant.py
Lines
1-180

Signature

def fake_quantize(tensor: torch.Tensor, fp8_format: tex.DType, out=None) -> torch.Tensor: ...

@Registry.register_feature(namespace="transformer_engine")
class FakeQuant(TEConfigAPIMapper):
    def fp8_gemm_enabled(self, config, layer_name, gemm, iteration) -> Tuple[bool, None]: ...
    def modify_tensor_enabled(self, config, layer_name, tensor_name, gemm, iteration) -> Tuple[bool, int]: ...
    def modify_tensor(self, config, layer_name, gemm, tensor_name, tensor, iteration, default_quantizer, out=None, dtype=None) -> Optional[torch.Tensor]: ...

Import

from transformer_engine.debug.features.fake_quant import FakeQuant, fake_quantize

I/O Contract

Inputs

Name | Type | Required | Description
config | Dict | Yes | Must contain quant_format and GEMM/tensor specification
tensor | torch.Tensor | Yes | High-precision GPU tensor (float32, float16, or bfloat16)
quant_format | str | Yes | One of FP8E4M3, FP8E5M2, MXFP8E4M3, MXFP8E5M2

Outputs

Name | Type | Description
result | torch.Tensor | Fake-quantized tensor (same dtype as input, with quantization noise)

Usage Examples

# YAML configuration:
# example_fake_quant_fp8:
#   enabled: True
#   layers:
#     layer_types: [transformer_layer.layernorm_mlp.fc1]
#   transformer_engine:
#     FakeQuant:
#       enabled: True
#       quant_format: FP8E5M2
#       gemms_struct:
#         - gemm: fprop
#           tensors: [activation, weight]
#         - gemm: dgrad
#           tensors: [gradient]
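
To make the selection semantics of gemms_struct concrete, the sketch below mirrors the FakeQuant block of the YAML above as a plain Python dict and expands it into (gemm, tensor) pairs. The helper selected_pairs and the constant SUPPORTED_FORMATS are hypothetical and only illustrate how the config reads; they are not TransformerEngine's config parser.

```python
# Hypothetical illustration, not part of TransformerEngine.
SUPPORTED_FORMATS = {"FP8E4M3", "FP8E5M2", "MXFP8E4M3", "MXFP8E5M2"}

# Same content as the FakeQuant section of the YAML example.
fake_quant_cfg = {
    "enabled": True,
    "quant_format": "FP8E5M2",
    "gemms_struct": [
        {"gemm": "fprop", "tensors": ["activation", "weight"]},
        {"gemm": "dgrad", "tensors": ["gradient"]},
    ],
}


def selected_pairs(cfg):
    """Return the set of (gemm, tensor) pairs the config asks to fake-quantize."""
    if cfg["quant_format"] not in SUPPORTED_FORMATS:
        raise ValueError(f"unsupported quant_format: {cfg['quant_format']}")
    if not cfg.get("enabled", False):
        return set()
    return {
        (entry["gemm"], tensor)
        for entry in cfg["gemms_struct"]
        for tensor in entry["tensors"]
    }


pairs = selected_pairs(fake_quant_cfg)
print(sorted(pairs))
```

Read this way, the example config fake-quantizes the activation and weight in the forward GEMM and the gradient in dgrad, and leaves every other GEMM/tensor combination (such as wgrad) untouched.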
