Implementation:NVIDIA TransformerEngine Custom NVFP4
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Quantization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Reference implementation of the NVFP4 (4-bit floating point) quantization recipe with Random Hadamard Transform (RHT) and 2D block quantization for weights.
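As a rough illustration of why a Hadamard rotation helps before 4-bit quantization, the following pure-Python sketch applies a sign-randomized 16-point Hadamard transform to a block containing a single outlier. The block size of 16, the sign mask, and all function names are illustrative assumptions and are not taken from the TransformerEngine source.

```python
import math
import random


def hadamard(n):
    """Sylvester construction of an n x n Hadamard matrix (n must be a power of two)."""
    h = [[1]]
    while len(h) < n:
        h = [row + row for row in h] + [row + [-v for v in row] for row in h]
    return h


def random_hadamard_transform(block, signs):
    """Flip signs, then rotate one length-n block by H / sqrt(n)."""
    n = len(block)
    h = hadamard(n)
    flipped = [v * s for v, s in zip(block, signs)]
    scale = 1.0 / math.sqrt(n)
    return [scale * sum(flipped[i] * h[i][j] for i in range(n)) for j in range(n)]


def inverse_rht(block, signs):
    """Undo the rotation: H is symmetric and H @ H = n * I, so H / sqrt(n) is its own inverse."""
    n = len(block)
    h = hadamard(n)
    scale = 1.0 / math.sqrt(n)
    unrotated = [scale * sum(block[i] * h[i][j] for i in range(n)) for j in range(n)]
    return [v * s for v, s in zip(unrotated, signs)]


rng = random.Random(0)
signs = [rng.choice((-1.0, 1.0)) for _ in range(16)]

x = [0.0] * 16
x[3] = 8.0  # a single outlier dominates the block

y = random_hadamard_transform(x, signs)  # energy is spread over all 16 entries
x_back = inverse_rht(y, signs)           # the transform is exactly invertible
```

After the rotation every entry has magnitude 2.0 instead of one entry holding 8.0, so a shared block scale wastes far less of the 4-bit range on the outlier.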
Description
Provides pure-Python FP4 E2M1 casting helpers (cast_to_fp4x2 and cast_from_fp4x2) that pack two 4-bit values into one byte and unpack them again. NVFP4QuantizerRef implements quantization with configurable tile shapes (1x16 for activations, 16x16 for weights), two-level scaling (a per-block FP8 scale, E4M3 or power-of-two E8M0, combined with a global per-tensor FP32 scale), optional RHT pre-processing, and qgemm(), which dequantizes its operands and then runs the matmul. NVFP4TensorRef stores the quantized data, block scales, global scale, and their transposed variants. The helper functions cast_to_e8 and cast_to_e4m3 convert block scale factors to the E8M0 and E4M3 formats, respectively. high_precision_gemm_ref provides a reference GEMM in high precision.
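The packing of two E2M1 values into one byte can be sketched without any dependencies. This is a minimal illustration, not the module's code: the nibble order (even-indexed element in the low nibble), the saturation of out-of-range values to 6.0, and the helper names are assumptions made for the example.

```python
# The eight non-negative magnitudes representable in FP4 E2M1, indexed by
# their 3-bit code (2 exponent bits, 1 mantissa bit).
E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def encode_e2m1(value):
    """Return the 4-bit code (sign bit + 3-bit magnitude) nearest to value."""
    sign = 0x8 if value < 0 else 0x0
    mag = abs(value)
    # Nearest grid point; ties go to the code with an even mantissa bit.
    code = min(range(8), key=lambda i: (abs(mag - E2M1_GRID[i]), i % 2))
    return sign | code


def decode_e2m1(nibble):
    """Map a 4-bit code back to its floating-point value."""
    mag = E2M1_GRID[nibble & 0x7]
    return -mag if nibble & 0x8 else mag


def pack_fp4x2(values):
    """Pack pairs of values into bytes (assumed order: first value in the low nibble)."""
    if len(values) % 2:
        raise ValueError("need an even number of values to pack two per byte")
    codes = [encode_e2m1(v) for v in values]
    return bytes(lo | (hi << 4) for lo, hi in zip(codes[0::2], codes[1::2]))


def unpack_fp4x2(packed):
    """Unpack each byte into two values, low nibble first."""
    out = []
    for byte in packed:
        out.append(decode_e2m1(byte & 0xF))
        out.append(decode_e2m1(byte >> 4))
    return out


values = [0.4, -1.2, 2.5, 7.0]
packed = pack_fp4x2(values)      # 2 bytes for 4 values
restored = unpack_fp4x2(packed)  # [0.5, -1.0, 2.0, 6.0]
```

The round trip is lossy by design: each input snaps to the nearest of the 15 distinct E2M1 values, which is why the quantizer scales each block before casting.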
Usage
Use this module as a reference for NVFP4 quantization, the 4-bit format targeted at Blackwell GPUs. It demonstrates the complete NVFP4 flow, including 2D tiling, the two-level scale hierarchy, and RHT, and serves as a correctness baseline for the optimized CUDA implementation.
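The interaction between tile shape and the two-level scale hierarchy can be sketched in plain Python. The sketch assumes the common NVFP4 convention in which the global scale maps the tensor's absolute maximum onto the product of the E4M3 and E2M1 maxima (448 * 6), and it keeps block scales as Python floats rather than casting them to FP8. All names are hypothetical and the code is not the module's implementation.

```python
FP4_MAX = 6.0     # largest E2M1 magnitude
E4M3_MAX = 448.0  # largest FP8 E4M3 magnitude
E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def round_to_e2m1(value):
    """Snap to the nearest E2M1 value (ties resolve to the smaller magnitude here)."""
    mag = min(E2M1_GRID, key=lambda g: abs(abs(value) - g))
    return -mag if value < 0 else mag


def quantize_two_level(matrix, tile_shape):
    """Quantize with one scale per tile plus one global scale (assumed convention)."""
    rows, cols = len(matrix), len(matrix[0])
    tr, tc = tile_shape
    global_amax = max(abs(v) for row in matrix for v in row)
    global_scale = (E4M3_MAX * FP4_MAX) / global_amax if global_amax else 1.0
    q = [[0.0] * cols for _ in range(rows)]
    block_scales = {}
    for r0 in range(0, rows, tr):
        for c0 in range(0, cols, tc):
            tile = [(r, c)
                    for r in range(r0, min(r0 + tr, rows))
                    for c in range(c0, min(c0 + tc, cols))]
            amax = max(abs(matrix[r][c]) for r, c in tile)
            decode = amax / FP4_MAX  # maps the tile's amax onto 6.0
            # Stored block scale lies in [0, 448]; a real kernel would cast it to FP8.
            block_scales[(r0 // tr, c0 // tc)] = decode * global_scale
            for r, c in tile:
                q[r][c] = round_to_e2m1(matrix[r][c] / decode) if decode else 0.0
    return q, block_scales, global_scale


def dequantize_two_level(q, block_scales, global_scale, tile_shape):
    """Invert the scaling: value = q * block_scale / global_scale."""
    tr, tc = tile_shape
    return [[v * block_scales[(r // tr, c // tc)] / global_scale
             for c, v in enumerate(row)] for r, row in enumerate(q)]


def max_abs_error(a, b):
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))


# Deterministic 16 x 32 test matrix whose row magnitudes grow with the row index.
matrix = [[((r * 32 + c) % 7 - 3) * 0.25 * (1 + r) for c in range(32)]
          for r in range(16)]

act_q, act_scales, act_global = quantize_two_level(matrix, (1, 16))   # 1D tiles
w_q, w_scales, w_global = quantize_two_level(matrix, (16, 16))        # 2D tiles

act_err = max_abs_error(matrix, dequantize_two_level(act_q, act_scales, act_global, (1, 16)))
w_err = max_abs_error(matrix, dequantize_two_level(w_q, w_scales, w_global, (16, 16)))
```

With 1x16 tiles the 16 x 32 matrix carries 32 block scales, while 16x16 tiles need only 2. The 2D layout has the property that a tile and its transpose share the same scale, which is one plausible reason to prefer it for weights that are consumed in both orientations.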
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py
- Lines: 1--887
Signature
def nvfp4_ref_rht_2d_quantizer_factory(role): ...
def cast_to_fp4x2(x) -> torch.Tensor: ...
def cast_from_fp4x2(x, dq_dtype) -> torch.Tensor: ...
def cast_to_e8(decode_scale) -> torch.Tensor: ...
def cast_to_e4m3(decode_scale, global_amax) -> torch.Tensor: ...
def high_precision_gemm_ref(a, b, ...): ...
class NVFP4TensorRef(QuantizedTensorStorage):
    def __init__(self, data, block_scales, global_scale, ...): ...
    def prepare_for_saving(self) -> list: ...
class NVFP4QuantizerRef(Quantizer):
    def __init__(self, tile_shape, fp8_dtype, with_rht=False, ...): ...
    def quantize(self, tensor, ...): ...
    def qgemm(self, a, b, ...): ...
Import
from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import (
    nvfp4_ref_rht_2d_quantizer_factory,
    NVFP4QuantizerRef,
    NVFP4TensorRef,
    cast_to_fp4x2,
    cast_from_fp4x2,
)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| role | str | Yes | Role identifier determining tile shape and dtype selection |
| tensor | torch.Tensor | Yes | High-precision tensor to quantize to FP4 |
| tile_shape | tuple | No | Block tile dimensions (e.g., (1, 16) for activations, (16, 16) for weights) |
| with_rht | bool | No | Whether to apply Random Hadamard Transform |
Outputs
| Name | Type | Description |
|---|---|---|
| quantizer | NVFP4QuantizerRef | Reference FP4 quantizer instance |
| quantized_storage | NVFP4TensorRef | Storage holding packed FP4 data with block and global scales |
Usage Examples
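import torch

from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import (
    nvfp4_ref_rht_2d_quantizer_factory,
    cast_to_fp4x2,
    cast_from_fp4x2,
)

# Example inputs; the shapes, dtype, and device are illustrative assumptions
input_tensor = torch.randn(128, 128, dtype=torch.bfloat16, device="cuda")
float_tensor = torch.randn(16, 32, device="cuda")  # even last dim: two values per byte

# Create a reference NVFP4 quantizer with RHT
quantizer = nvfp4_ref_rht_2d_quantizer_factory("input")
fp4_ref = quantizer.quantize(input_tensor)

# Low-level FP4 packing (2 values per byte)
packed = cast_to_fp4x2(float_tensor)
unpacked = cast_from_fp4x2(packed, torch.bfloat16)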
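The dequantize-then-matmul execution that qgemm() is described as using can be sketched as follows. This is a simplified stand-in under stated assumptions: operands are lists of E2M1-grid values with one scale per 16-element block along K, the product is computed as A @ B^T, and the function names are hypothetical rather than the module's API.

```python
def dequantize(q_values, block_scales, block=16):
    """Multiply each quantized value by the scale of the block it belongs to."""
    return [[v * block_scales[r][c // block] for c, v in enumerate(row)]
            for r, row in enumerate(q_values)]


def matmul_nt(a, b):
    """High-precision a @ b.T, with both operands laid out as (rows, K)."""
    return [[sum(x * y for x, y in zip(row_a, row_b)) for row_b in b]
            for row_a in a]


def qgemm_sketch(qa, sa, qb, sb, block=16):
    """Dequantize both operands first, then run an ordinary matmul."""
    return matmul_nt(dequantize(qa, sa, block), dequantize(qb, sb, block))


# Two activation rows and three weight rows, K = 16 (a single block per row).
qa = [[1.0] * 16, [2.0] * 16]
sa = [[0.5], [0.5]]                      # dequantized rows: all 0.5, all 1.0
qb = [[1.0] * 16, [6.0] * 16, [1.0, -1.0] * 8]
sb = [[1.0], [0.5], [2.0]]               # dequantized rows: all 1.0, all 3.0, alternating +/-2.0

out = qgemm_sketch(qa, sa, qb, sb)       # shape 2 x 3
```

Because the arithmetic happens on dequantized values, a path like this is slow but easy to reason about, which suits a correctness baseline for a fused low-precision kernel.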