Implementation:NVIDIA TransformerEngine Custom NVFP4

From Leeroopedia


Field          Value
Sources        TransformerEngine
Domains        Deep_Learning, PyTorch, Quantization
Last Updated   2026-02-07 14:00 GMT

Overview

Reference implementation of the NVFP4 (4-bit floating point) quantization recipe with Random Hadamard Transform (RHT) and 2D block quantization for weights.
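RHT rotates each block with a randomized orthogonal matrix before quantization, so a single outlier is spread over the whole block instead of dominating its scale. The following is a minimal pure-Python sketch of that idea, not TransformerEngine's code. It assumes a 16-element block (matching the 1x16 tile), a Sylvester-constructed Hadamard matrix, and random ±1 sign flips; the function names are hypothetical.

```python
import math
import random


def hadamard(n):
    """Sylvester construction: H_{2m} = [[H_m, H_m], [H_m, -H_m]]."""
    if n < 1 or n & (n - 1):
        raise ValueError("n must be a power of two")
    h = [[1]]
    while len(h) < n:
        h = [row + row for row in h] + [row + [-x for x in row] for row in h]
    return h


def random_hadamard_transform(block, signs):
    """y = (1/sqrt(n)) * H @ (signs * x): an orthogonal rotation of one block."""
    n = len(block)
    h = hadamard(n)
    flipped = [s * x for s, x in zip(signs, block)]
    scale = 1.0 / math.sqrt(n)
    return [scale * sum(h[i][j] * flipped[j] for j in range(n)) for i in range(n)]


def inverse_random_hadamard_transform(rotated, signs):
    """H is symmetric with H @ H = n * I, so x = signs * ((1/sqrt(n)) * H @ y)."""
    n = len(rotated)
    h = hadamard(n)
    scale = 1.0 / math.sqrt(n)
    y = [scale * sum(h[i][j] * rotated[j] for j in range(n)) for i in range(n)]
    return [s * v for s, v in zip(signs, y)]


# One 16-element block with a single outlier (illustrative values).
rng = random.Random(0)
signs = [rng.choice((-1, 1)) for _ in range(16)]
block = [0.1] * 15 + [8.0]
rotated = random_hadamard_transform(block, signs)
restored = inverse_random_hadamard_transform(rotated, signs)
print(max(abs(v) for v in block), max(abs(v) for v in rotated))
```

Because every Hadamard entry is ±1 and the normalization is 1/4 for n = 16, each rotated value is bounded by (8.0 + 15 * 0.1) / 4 = 2.375. The block amax therefore drops from 8.0, and the transform is exactly invertible.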

Description

Provides pure-PyTorch FP4 E2M1 casting (cast_to_fp4x2 / cast_from_fp4x2) that packs two 4-bit values into one byte. NVFP4QuantizerRef implements quantization with configurable tile shapes (1x16 for activations, 16x16 for weights), two-level scaling (an FP8 E4M3 scale per block plus a per-tensor FP32 global scale derived from the tensor amax), optional RHT pre-processing, and qgemm(), which dequantizes its operands and then performs the matmul in high precision. NVFP4TensorRef stores the quantized data, block scales, global scale, and their transposed variants. The helper cast_to_e8 rounds a scale to a power of two (a value representable in FP8 E8M0), and cast_to_e4m3 converts a block scale to FP8 E4M3 relative to the global amax. high_precision_gemm_ref provides a high-precision reference GEMM.

Usage

A reference implementation of NVFP4 quantization, the 4-bit format targeted by NVIDIA Blackwell GPUs. It demonstrates the complete NVFP4 flow, including 2D tiling, the two-level scale hierarchy, and RHT, and serves as a correctness baseline for the optimized CUDA implementation.

Code Reference

Source Location

Repository   NVIDIA/TransformerEngine
File         transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py
Lines        1-887

Signature

def nvfp4_ref_rht_2d_quantizer_factory(role): ...

def cast_to_fp4x2(x) -> torch.Tensor: ...
def cast_from_fp4x2(x, dq_dtype) -> torch.Tensor: ...
def cast_to_e8(decode_scale) -> torch.Tensor: ...
def cast_to_e4m3(decode_scale, global_amax) -> torch.Tensor: ...
def high_precision_gemm_ref(a, b, ...): ...

class NVFP4TensorRef(QuantizedTensorStorage):
    def __init__(self, data, block_scales, global_scale, ...): ...
    def prepare_for_saving(self) -> list: ...

class NVFP4QuantizerRef(Quantizer):
    def __init__(self, tile_shape, fp8_dtype, with_rht=False, ...): ...
    def quantize(self, tensor, ...): ...
    def qgemm(self, a, b, ...): ...
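qgemm() follows a dequantize-then-matmul pattern. The sketch below illustrates that pattern in plain Python. It assumes B is stored N x K like a Linear weight, so the result is C = A_hat @ B_hat^T, and each operand is passed as a hypothetical (fp4_values, block_scales, global_scale, tile_shape) tuple. This is not the signature of NVFP4QuantizerRef.qgemm.

```python
def dequantize(values, scales, global_scale, tile_shape):
    """x_hat[r][c] = values[r][c] * scales[r // tr][c // tc] / global_scale."""
    tr, tc = tile_shape
    return [
        [v * scales[r // tr][c // tc] / global_scale for c, v in enumerate(row)]
        for r, row in enumerate(values)
    ]


def qgemm_ref(a, b):
    """C = A_hat @ B_hat^T, with both operands dequantized first (B is N x K)."""
    a_hat = dequantize(*a)
    b_hat = dequantize(*b)
    return [[sum(x * y for x, y in zip(a_row, b_row)) for b_row in b_hat] for a_row in a_hat]


# Illustrative operands: activations with 1x16 tiles, weights with 16x16 tiles.
a = ([[1.0] * 16, [1.0] * 16], [[2.0], [0.5]], 1.0, (1, 16))
b = ([[0.5] * 16 for _ in range(3)], [[4.0]], 2.0, (16, 16))
c = qgemm_ref(a, b)
print(c)  # [[32.0, 32.0, 32.0], [8.0, 8.0, 8.0]]
```

Dequantizing first keeps the reference simple and exact up to high-precision rounding. An optimized kernel would instead apply the scales inside the GEMM.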

Import

from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import (
    nvfp4_ref_rht_2d_quantizer_factory,
    NVFP4QuantizerRef,
    NVFP4TensorRef,
    cast_to_fp4x2,
    cast_from_fp4x2,
)

I/O Contract

Inputs

Name         Type           Required   Description
role         str            Yes        Role identifier that determines tile shape and dtype selection
tensor       torch.Tensor   Yes        High-precision tensor to quantize to FP4
tile_shape   tuple          No         Block tile dimensions, e.g. (1, 16) for activations, (16, 16) for weights
with_rht     bool           No         Whether to apply the Random Hadamard Transform
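The tile_shape parameter determines how many block scales a tensor carries. The sketch below computes one absolute maximum per tile under both tile shapes. It assumes one scale per tile and row-major tiling, and the helper names are hypothetical. With square 16x16 tiles, quantizing the transposed tensor only transposes the scale grid. 1x16 tiles lack this property because they run along rows only.

```python
def block_amax_grid(tensor, tile_shape):
    """Per-tile absolute maximum: one entry per block scale the quantizer would store."""
    rows, cols = len(tensor), len(tensor[0])
    tr, tc = tile_shape
    if rows % tr or cols % tc:
        raise ValueError("tensor shape must be a multiple of the tile shape")
    return [
        [
            max(abs(tensor[r][c])
                for r in range(br * tr, (br + 1) * tr)
                for c in range(bc * tc, (bc + 1) * tc))
            for bc in range(cols // tc)
        ]
        for br in range(rows // tr)
    ]


def transpose(m):
    return [list(col) for col in zip(*m)]


# Illustrative 32 x 64 tensor with deterministic values.
t = [[((r * 7 + c * 3) % 11) - 5 for c in range(64)] for r in range(32)]
act_grid = block_amax_grid(t, (1, 16))   # 32 x 4 block scales
wgt_grid = block_amax_grid(t, (16, 16))  # 2 x 4 block scales
print(len(act_grid), len(act_grid[0]), len(wgt_grid), len(wgt_grid[0]))  # 32 4 2 4

# Square tiles: the transposed tensor's grid is the transpose of the original grid.
print(block_amax_grid(transpose(t), (16, 16)) == transpose(wgt_grid))  # True
```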

Outputs

Name                Type                Description
quantizer           NVFP4QuantizerRef   Reference FP4 quantizer instance
quantized_storage   NVFP4TensorRef      Storage holding packed FP4 data with block and global scales
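Since NVFP4TensorRef holds packed data plus block and global scales, a short calculation shows roughly what the logical NVFP4 format costs per element. This sketch assumes one byte per 16-element block scale and a 4-byte FP32 global scale. A reference implementation may keep scales in wider dtypes and also stores transposed variants, so its actual memory use can be larger. The function name is hypothetical.

```python
def nvfp4_footprint_bytes(rows, cols, block_size=16):
    """Logical NVFP4 size: packed E2M1 data + one E4M3 byte per block + one FP32 global scale."""
    if cols % block_size:
        raise ValueError("cols must be a multiple of block_size")
    elems = rows * cols
    data_bytes = elems // 2              # two 4-bit values per byte
    scale_bytes = elems // block_size    # one FP8 scale per block
    global_bytes = 4                     # one FP32 per-tensor scale
    return data_bytes + scale_bytes + global_bytes


rows, cols = 4096, 4096
nvfp4 = nvfp4_footprint_bytes(rows, cols)
bf16 = rows * cols * 2
print(nvfp4, bf16, round(8 * nvfp4 / (rows * cols), 4))  # 9437188 33554432 4.5
```

At about 4.5 bits per element (4 data bits plus 8 scale bits shared by 16 values), one copy of the tensor is roughly 3.6x smaller than BF16.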

Usage Examples

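import torch

from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import (
    nvfp4_ref_rht_2d_quantizer_factory,
    cast_to_fp4x2,
    cast_from_fp4x2,
)

# Create a reference NVFP4 quantizer with RHT for the "input" role
# (the accepted role strings may depend on the TransformerEngine version)
quantizer = nvfp4_ref_rht_2d_quantizer_factory("input")
input_tensor = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")
fp4_ref = quantizer.quantize(input_tensor)  # NVFP4TensorRef storage

# Low-level FP4 packing (2 values per byte): use an even-sized last dimension
float_tensor = torch.randn(16, 32)
packed = cast_to_fp4x2(float_tensor)                # two FP4 codes per byte
unpacked = cast_from_fp4x2(packed, torch.bfloat16)  # decoded back to BF16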
