Implementation:NVIDIA TransformerEngine MXFP8Tensor
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Quantization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Implements MXFP8 (Microscaling FP8) tensors using the OCP MX block scaling format, where groups of 32 elements share a single E8M0 scale factor.
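To make the shared-scale idea concrete, here is a minimal pure-Python sketch of how one E8M0 scale could be chosen for a 32-element block and decoded again. The helper names are hypothetical. The rounding rule follows the reference conversion suggested in the OCP MX specification, and Transformer Engine's kernels may round differently.

```python
import math

BLOCK_SIZE = 32   # elements that share one scale, as stated above
E8M0_BIAS = 127   # E8M0 stores only a biased power-of-two exponent
E4M3_EMAX = 8     # exponent of the largest E4M3 value, 448 = 1.75 * 2**8


def encode_block_scale(block):
    """Pick the E8M0 byte for one block (hypothetical helper, not TE's API)."""
    assert len(block) == BLOCK_SIZE
    amax = max(abs(x) for x in block)
    if amax == 0.0:
        return E8M0_BIAS  # all-zero block: any scale works, this sketch uses 2**0
    # frexp gives amax = m * 2**e with 0.5 <= m < 1, so floor(log2(amax)) = e - 1
    _, e = math.frexp(amax)
    shared_exp = (e - 1) - E4M3_EMAX
    # Clamp to the representable range; byte 255 is reserved for NaN
    return max(0, min(254, shared_exp + E8M0_BIAS))


def decode_block_scale(byte):
    """Turn an E8M0 byte back into its power-of-two scale."""
    return 2.0 ** (byte - E8M0_BIAS)
```

For example, a block whose largest magnitude is 1.0 gets byte 119, a scale of 2**-8, so that element is stored as 256, which is within the E4M3 range.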
Description
MXFP8Quantizer validates that tensor dimensions are divisible by the block size (32), allocates rowwise and columnwise data together with uint8 scale tensors whose dimensions are padded (rounded up to multiples of 128 and 4), and delegates the quantization itself to tex.quantize. MXFP8Tensor stores separate rowwise and columnwise quantized data with swizzled scales optimized for GEMM operations. These dimension alignment constraints are strictly enforced to keep the data compatible with the MX format. _ViewFunc and _ReshapeFunc are autograd functions that preserve the quantized subclass through view and reshape operations.
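The divisibility check and the scale padding described above can be sketched in plain Python. This is an illustration under stated assumptions, not the library's code: the helper names are hypothetical, and which axis is padded to 128 and which to 4 is an assumption about the layout.

```python
import math

MXFP8_BLOCK_SIZE = 32  # block size from the description above


def round_up(value, multiple):
    """Round value up to the nearest multiple."""
    return ((value + multiple - 1) // multiple) * multiple


def is_mxfp8_compatible(shape):
    """Check that the flattened leading dims and the last dim divide by 32."""
    rows, cols = math.prod(shape[:-1]), shape[-1]
    return rows % MXFP8_BLOCK_SIZE == 0 and cols % MXFP8_BLOCK_SIZE == 0


def padded_scale_shapes(shape):
    """Return assumed (rowwise, columnwise) uint8 scale shapes for a tensor."""
    if not is_mxfp8_compatible(shape):
        raise ValueError(f"shape {tuple(shape)} is not divisible by {MXFP8_BLOCK_SIZE}")
    rows, cols = math.prod(shape[:-1]), shape[-1]
    # Rowwise: one scale per 32 elements along the last dimension (assumed layout)
    rowwise = (round_up(rows, 128), round_up(cols // MXFP8_BLOCK_SIZE, 4))
    # Columnwise: one scale per 32 elements along the flattened leading dims
    columnwise = (round_up(rows // MXFP8_BLOCK_SIZE, 4), round_up(cols, 128))
    return rowwise, columnwise
```

Under these assumptions, a (256, 1024) input needs no padding and yields scale shapes (256, 32) and (8, 1024), while a (32, 32) input is padded up to (128, 4) and (4, 128).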
Usage
Use this implementation when a workload needs the OCP Microscaling (MX) FP8 format, an industry-standard low-precision format designed for hardware interoperability. The fixed 32-element block size with E8M0 scales targets specific hardware acceleration paths, so inputs must meet the alignment constraints described above.
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/tensor/mxfp8_tensor.py
- Lines: 1-997
Signature
class MXFP8Quantizer(Quantizer):
    def __init__(self, fp8_dtype, ...): ...
    def quantize(self, tensor, ...): ...
    def set_usage(self, rowwise=False, columnwise=False): ...

class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
    def __init__(self, *, rowwise_data, columnwise_data, ...): ...
    def dequantize(self, dtype=None) -> torch.Tensor: ...
    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs): ...

class _ViewFunc(torch.autograd.Function): ...
class _ReshapeFunc(torch.autograd.Function): ...
Import
from transformer_engine.pytorch.tensor.mxfp8_tensor import (
    MXFP8Quantizer,
    MXFP8Tensor,
)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| tensor | torch.Tensor | Yes | High-precision tensor to quantize (dims must be divisible by 32) |
| fp8_dtype | Transformer Engine DType enum | Yes | Target FP8 dtype (e.g., tex.DType.kFloat8E4M3 from transformer_engine_torch) |
Outputs
| Name | Type | Description |
|---|---|---|
| mxfp8_tensor | MXFP8Tensor | MXFP8-quantized tensor with E8M0 block scales |
Usage Examples
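import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer

# The quantizer takes a Transformer Engine dtype enum rather than a torch.dtype
quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
# Produce both the rowwise and the columnwise quantized copies
quantizer.set_usage(rowwise=True, columnwise=True)

# Input dimensions must be divisible by 32 for MXFP8
input_tensor = torch.randn(256, 1024, device="cuda")
mxfp8_tensor = quantizer.quantize(input_tensor)

# Dequantize back to a high-precision tensor
output = mxfp8_tensor.dequantize(dtype=torch.bfloat16)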