

Implementation:NVIDIA TransformerEngine MXFP8Tensor

From Leeroopedia


Field          Value
Sources        TransformerEngine
Domains        Deep_Learning, PyTorch, Quantization
Last Updated   2026-02-07 14:00 GMT

Overview

Implements MXFP8 (Microscaling FP8) tensors using the OCP MX block scaling format, where groups of 32 elements share a single E8M0 scale factor.
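
To make the block-scaling idea concrete, here is a minimal pure-Python simulation of one 32-element MX block with FP8 E4M3 elements. It is a sketch, not Transformer Engine's kernel. It assumes the shared-exponent rule from the OCP MX v1.0 specification (floor(log2(amax)) minus the element format's largest normal exponent, which is 8 for E4M3), round-to-nearest-even element casting, and saturation at ±448. The helper names (floor_log2, round_to_e4m3, quantize_block, dequantize_block) are hypothetical, and the CUDA code behind tex.quantize may choose the scale differently.

```python
import math
from typing import List, Tuple

BLOCK_SIZE = 32        # elements sharing one scale in the MX format
E4M3_MAX = 448.0       # largest finite float8_e4m3fn value
E4M3_EMAX = 8          # exponent of the largest normal: 448 = 1.75 * 2**8
E4M3_MIN_EXP = -6      # exponent of the smallest normal E4M3 value
E4M3_MANT_BITS = 3


def floor_log2(x: float) -> int:
    """Exact floor(log2(|x|)) for a nonzero finite x, computed via frexp."""
    return math.frexp(abs(x))[1] - 1


def round_to_e4m3(x: float) -> float:
    """Round x to the nearest E4M3 value (ties to even), saturating at +/-448."""
    if x == 0.0:
        return 0.0
    # Below the normal range the spacing stays fixed (subnormals).
    exp = max(floor_log2(x), E4M3_MIN_EXP)
    step = 2.0 ** (exp - E4M3_MANT_BITS)  # gap between adjacent E4M3 values here
    q = round(x / step) * step             # Python's round() is ties-to-even
    return max(-E4M3_MAX, min(E4M3_MAX, q))


def quantize_block(block: List[float]) -> Tuple[int, List[float]]:
    """Return (shared unbiased exponent, E4M3 element values) for one block."""
    if len(block) != BLOCK_SIZE:
        raise ValueError(f"an MX block has exactly {BLOCK_SIZE} elements")
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 0, [0.0] * BLOCK_SIZE
    # Assumed OCP MX v1.0 rule: shared exponent = floor(log2(amax)) - emax(element)
    shared_exp = floor_log2(amax) - E4M3_EMAX
    scale = 2.0 ** shared_exp
    return shared_exp, [round_to_e4m3(v / scale) for v in block]


def dequantize_block(shared_exp: int, elems: List[float]) -> List[float]:
    """Multiply every element by the shared power-of-two scale."""
    scale = 2.0 ** shared_exp
    return [e * scale for e in elems]


block = [1.0] * 30 + [0.1, 3.0]
shared_exp, elems = quantize_block(block)
restored = dequantize_block(shared_exp, elems)
print(shared_exp, restored[-2:])
```

Because every element in the block is divided by the same power of two, the largest magnitude sets the scale for all 32 values; keeping blocks small limits how far one outlier's influence spreads.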

Description

MXFP8Quantizer checks that the tensor's dimensions are divisible by the 32-element block size, allocates rowwise and columnwise FP8 data tensors together with their uint8 (E8M0) scale tensors, padding the scale dimensions up to multiples of 128 and 4, and delegates the quantization itself to tex.quantize in the transformer_engine_torch C++ extension. MXFP8Tensor stores the rowwise and columnwise quantized data separately, with scales swizzled into a layout optimized for GEMM operations, and enforces the strict dimension alignment that the MX format requires. _ViewFunc and _ReshapeFunc are autograd functions that keep the results of view and reshape operations as MXFP8Tensor instead of falling back to a plain torch.Tensor.
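
The shape validation and scale padding described above can be sketched as follows. This is an illustrative reconstruction, not code taken from mxfp8_tensor.py. It assumes that leading dimensions are flattened into rows, and that the 128-multiple padding applies to the dimension with one scale per row (rowwise) or per column (columnwise), while the 4-multiple padding applies to the per-block dimension. The names round_up, check_mx_shape and mx_scale_shape are hypothetical.

```python
import math
from typing import Sequence, Tuple

MX_BLOCK = 32  # elements that share one E8M0 scale


def round_up(value: int, multiple: int) -> int:
    """Round value up to the nearest multiple of `multiple`."""
    return (value + multiple - 1) // multiple * multiple


def check_mx_shape(shape: Sequence[int]) -> Tuple[int, int]:
    """Flatten to 2-D (rows, cols) and require both to be multiples of the block size."""
    rows = math.prod(shape[:-1])
    cols = shape[-1]
    if rows % MX_BLOCK or cols % MX_BLOCK:
        raise ValueError(
            f"shape {tuple(shape)} flattens to ({rows}, {cols}); "
            f"both must be divisible by {MX_BLOCK}"
        )
    return rows, cols


def mx_scale_shape(shape: Sequence[int], columnwise: bool) -> Tuple[int, int]:
    """Padded shape of the uint8 scale tensor under the assumed padding rule."""
    rows, cols = check_mx_shape(shape)
    if columnwise:
        # One scale per 32 consecutive rows in every column.
        return round_up(rows // MX_BLOCK, 4), round_up(cols, 128)
    # One scale per 32 consecutive columns in every row.
    return round_up(rows, 128), round_up(cols // MX_BLOCK, 4)


print(mx_scale_shape((256, 1024), columnwise=False))
print(mx_scale_shape((256, 1024), columnwise=True))
print(mx_scale_shape((2, 16, 96), columnwise=False))
```

Under this rule the scale buffer can be larger than the number of blocks: a (2, 16, 96) input flattens to 32 × 96 and has 32 × 3 rowwise blocks, yet gets a 128 × 4 rowwise scale tensor.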

Usage

MXFP8Tensor implements the OCP Microscaling (MX) FP8 format, an industry-standard low-precision format designed for interoperability across hardware vendors. Its fixed 32-element block size and power-of-two E8M0 scales match what MX-capable hardware accelerates natively, such as the Tensor Cores of NVIDIA's Blackwell generation.
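
An E8M0 scale is just an 8-bit biased exponent with no sign or mantissa bits, so applying it amounts to an exponent shift. The sketch below encodes and decodes such scales in pure Python, assuming the OCP MX v1.0 conventions (bias 127, representable range 2^-127 to 2^127, code 0xFF reserved for NaN). The function names are hypothetical and are not part of Transformer Engine's API.

```python
import math

E8M0_BIAS = 127
E8M0_NAN = 0xFF  # all-ones code is reserved for NaN


def e8m0_encode(scale: float) -> int:
    """Encode a power-of-two scale as an 8-bit biased exponent (no sign, no mantissa)."""
    if math.isnan(scale):
        return E8M0_NAN
    mantissa, exp = math.frexp(scale)  # scale == mantissa * 2**exp, 0.5 <= mantissa < 1
    if mantissa != 0.5:
        raise ValueError("E8M0 can only represent positive, exact powers of two")
    unbiased = exp - 1                 # scale == 2**(exp - 1)
    if not -127 <= unbiased <= 127:
        raise ValueError("exponent outside the E8M0 range [-127, 127]")
    return unbiased + E8M0_BIAS


def e8m0_decode(code: int) -> float:
    """Decode an 8-bit E8M0 code back to its power-of-two scale."""
    if not 0 <= code <= 0xFF:
        raise ValueError("E8M0 codes are 8-bit values")
    if code == E8M0_NAN:
        return math.nan
    return math.ldexp(1.0, code - E8M0_BIAS)


# Encode the scale 2**-7 and decode it again
code = e8m0_encode(2.0 ** -7)
print(code, e8m0_decode(code))
```

Storing only an exponent is what lets these scales live in the uint8 tensors mentioned in the Description.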

Code Reference

Source Location

Repository: NVIDIA/TransformerEngine
File: transformer_engine/pytorch/tensor/mxfp8_tensor.py
Lines: 1-997

Signature

class MXFP8Quantizer(Quantizer):
    def __init__(self, fp8_dtype, ...): ...
    def quantize(self, tensor, ...): ...
    def set_usage(self, rowwise=False, columnwise=False): ...

class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
    def __init__(self, *, rowwise_data, columnwise_data, ...): ...
    def dequantize(self, dtype=None) -> torch.Tensor: ...
    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs): ...

class _ViewFunc(torch.autograd.Function): ...
class _ReshapeFunc(torch.autograd.Function): ...

Import

from transformer_engine.pytorch.tensor.mxfp8_tensor import (
    MXFP8Quantizer,
    MXFP8Tensor,
)

I/O Contract

Inputs

Name        Type                                        Required  Description
tensor      torch.Tensor                                Yes       High-precision tensor to quantize; the last dimension and the product of the leading dimensions must each be divisible by 32
fp8_dtype   TE_DType (transformer_engine_torch.DType)   Yes       Target FP8 format (e.g., tex.DType.kFloat8E4M3)

Outputs

Name           Type          Description
mxfp8_tensor   MXFP8Tensor   MXFP8-quantized tensor with E8M0 block scales

Usage Examples

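import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer

# fp8_dtype is a Transformer Engine DType enum value, not a torch.dtype
quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
# Request both the rowwise and the columnwise quantized copies
quantizer.set_usage(rowwise=True, columnwise=True)

# MXFP8 needs the last dim and the product of the leading dims to be
# divisible by 32; 256 and 1024 both are. Quantization runs on the GPU.
input_tensor = torch.randn(256, 1024, device="cuda")
mxfp8_tensor = quantizer.quantize(input_tensor)

# Dequantize back to a high-precision tensor
output = mxfp8_tensor.dequantize(dtype=torch.bfloat16)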
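
The example requests both usages with set_usage(rowwise=True, columnwise=True). One way to see why MXFP8Tensor keeps two separately quantized copies instead of transposing one: the 32-element blocks run along rows in the rowwise copy and down columns in the columnwise copy, so the same element generally shares its scale with different neighbours. The pure-Python sketch below (block_amaxes is a hypothetical helper, and it assumes both dimensions are multiples of 32) computes the per-block absolute maxima that would set each shared scale in both layouts. An operand that a GEMM consumes in transposed form, for example during the backward pass, would typically use the columnwise blocks.

```python
from typing import List

BLOCK = 32  # MX block size


def block_amaxes(matrix: List[List[float]], columnwise: bool) -> List[List[float]]:
    """Absolute maximum of every 32-element block; each entry drives one shared scale."""
    rows, cols = len(matrix), len(matrix[0])
    if columnwise:
        # Blocks are 32 consecutive rows within a single column.
        return [
            [max(abs(matrix[r][c]) for r in range(rb, rb + BLOCK)) for c in range(cols)]
            for rb in range(0, rows, BLOCK)
        ]
    # Blocks are 32 consecutive columns within a single row.
    return [
        [max(abs(matrix[r][c]) for c in range(cb, cb + BLOCK)) for cb in range(0, cols, BLOCK)]
        for r in range(rows)
    ]


# A 32 x 64 matrix of ones with a single outlier in the top-left corner
m = [[1.0] * 64 for _ in range(32)]
m[0][0] = 100.0
rowwise = block_amaxes(m, columnwise=False)     # 32 rows x 2 blocks
columnwise = block_amaxes(m, columnwise=True)   # 1 block-row x 64 columns
print(rowwise[0], rowwise[1])   # the outlier only affects row 0's first block
print(columnwise[0][:2])        # the outlier only affects column 0's block
```

In the rowwise layout the outlier sets the scale for elements 0 to 31 of row 0; in the columnwise layout it sets the scale for rows 0 to 31 of column 0, so the two quantized copies differ.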
