Implementation:NVIDIA TransformerEngine Float8BlockwiseTensor
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Quantization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Implements FP8 tensors with blockwise scaling: each block of the tensor (for example a 1x128 vector segment or a 128x128 tile) has its own scale factor, which enables finer-grained quantization than per-tensor scaling.
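To make the per-block scale factor concrete, here is a minimal pure-Python sketch for a single row split into 128-element blocks. It is not the library's kernel: the helper name `block_scales`, the use of 448 as the E4M3 maximum, and the rule of rounding the scale down to a power of two are illustrative assumptions.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite magnitude in the E4M3 format


def block_scales(row, block_len=128, pow2=True):
    """Return one (scale, scale_inv) pair per block of `row`.

    Hypothetical helper for illustration only; the real quantization
    is performed by the library's fused kernels.
    """
    pairs = []
    for start in range(0, len(row), block_len):
        block = row[start:start + block_len]
        amax = max(abs(v) for v in block)
        if amax == 0.0:
            scale = 1.0  # an all-zero block needs no scaling
        else:
            scale = FP8_E4M3_MAX / amax
            if pow2:
                # Assumption: round down so amax * scale stays in range
                scale = 2.0 ** math.floor(math.log2(scale))
        pairs.append((scale, 1.0 / scale))
    return pairs


# Two blocks with very different magnitudes get very different scales
row = [0.01] * 128 + [100.0] * 128
scales = block_scales(row)
print(scales[0][0], scales[1][0])  # 32768.0 4.0
```

With a single per-tensor scale, both halves of `row` would share the scale chosen for the value 100.0, which leaves the small values with far less usable precision.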
Description
Float8BlockQuantizer is the builder class: it computes scale shapes for 1D or 2D block tiling, pads scales to multiples of 4 for GEMM compatibility, and delegates the actual quantization to tex.quantize. Float8BlockwiseQTensor is the tensor subclass; it stores separate rowwise and columnwise quantized data, each with per-block scale inverses. Both 1D (vectorwise) and 2D block scaling are supported, with power-of-2 scale enforcement. _ViewFunc and _ReshapeFunc are custom autograd functions that preserve the quantized subclass through view and reshape operations.
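The scale-shape computation can be sketched as follows for a 2D input of shape (rows, cols). This is an illustration under stated assumptions, not the library's code: the function name `scale_shape`, the choice of which axis is padded, and the transposed layout used for the 1D case are all hypothetical.

```python
def ceil_div(a, b):
    """Integer division rounded up."""
    return -(-a // b)


def round_up(x, multiple):
    """Smallest multiple of `multiple` that is >= x."""
    return ceil_div(x, multiple) * multiple


def scale_shape(rows, cols, block_scaling_dim, block_len=128, pad=4):
    """Hypothetical shape of the rowwise scale-inverse tensor.

    Assumption: the innermost axis of the scale tensor is the one
    padded to a multiple of `pad`.
    """
    if block_scaling_dim == 2:
        # One scale per block_len x block_len tile
        return (ceil_div(rows, block_len),
                round_up(ceil_div(cols, block_len), pad))
    if block_scaling_dim == 1:
        # One scale per 1 x block_len segment, stored transposed (assumed)
        return (ceil_div(cols, block_len), round_up(rows, pad))
    raise ValueError("block_scaling_dim must be 1 or 2")


print(scale_shape(256, 640, 2))   # (2, 8): 5 column tiles padded to 8
print(scale_shape(255, 1000, 1))  # (8, 256): 255 rows padded to 256
```

Padding only changes the storage shape; the extra entries carry no information about the data.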
Usage
Use this module for the block-scaled FP8 format (current scaling). By adapting scale factors to local data distributions, it offers better numerical accuracy than delayed per-tensor scaling. It is a key quantization strategy for Hopper+ GPUs.
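The accuracy benefit of local scale factors can be illustrated with a small pure-Python simulation. Everything here is an assumption made for the sake of the example: `fake_e4m3` is a simplified model of E4M3 rounding (3 mantissa bits, minimum normal exponent -6, saturation at 448), and the per-tensor baseline uses a scale computed from the current data rather than a delayed amax history.

```python
import math

E4M3_MAX = 448.0


def fake_e4m3(x):
    """Simplified E4M3 rounding model (illustrative, not bit-exact)."""
    if x == 0.0:
        return 0.0
    mag = min(abs(x), E4M3_MAX)
    exp = max(math.floor(math.log2(mag)), -6)  # clamp into subnormal range
    quantum = 2.0 ** (exp - 3)                 # spacing with 3 mantissa bits
    return math.copysign(min(round(mag / quantum) * quantum, E4M3_MAX), x)


def pow2_scale(amax):
    """Largest power-of-two scale that keeps amax within E4M3 range."""
    return 1.0 if amax == 0.0 else 2.0 ** math.floor(math.log2(E4M3_MAX / amax))


def max_abs_error(values, block_len):
    """Worst round-trip error when each block of block_len has its own scale."""
    worst = 0.0
    for start in range(0, len(values), block_len):
        block = values[start:start + block_len]
        s = pow2_scale(max(abs(v) for v in block))
        for v in block:
            worst = max(worst, abs(fake_e4m3(v * s) / s - v))
    return worst


# One block of tiny values next to one block of large values
data = [1e-4 * (1 + i % 5) for i in range(128)] + [96.0] * 128

tensor_err = max_abs_error(data, block_len=len(data))  # single shared scale
block_err = max_abs_error(data, block_len=128)         # one scale per block
print(block_err < tensor_err)  # True
```

In this toy input the shared scale is dictated by the large block, so the tiny values fall into the subnormal range and partly round to zero, while per-block scaling keeps them in the normal range.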
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
- Lines: 1-780
Signature
class Float8BlockQuantizer(Quantizer):
    def __init__(self, fp8_dtype, block_dims=None, ...): ...
    def quantize(self, tensor, ...): ...
    def set_usage(self, rowwise=False, columnwise=False): ...

class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
    def dequantize(self, dtype=None) -> torch.Tensor: ...
    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs): ...

class _ViewFunc(torch.autograd.Function): ...
class _ReshapeFunc(torch.autograd.Function): ...
Import
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
    Float8BlockQuantizer,
    Float8BlockwiseQTensor,
)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| tensor | torch.Tensor | Yes | High-precision tensor to quantize blockwise |
| fp8_dtype | torch.dtype | Yes | Target FP8 dtype (e.g., torch.float8_e4m3fn) |
| block_dims | tuple | No | Block dimensions for scaling (e.g., (1, 128) for 1D, (128, 128) for 2D) |
Outputs
| Name | Type | Description |
|---|---|---|
| quantized | Float8BlockwiseQTensor | Block-scaled FP8 tensor with per-block scale inverses |
Usage Examples
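import torch
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer

# Configure 1D (vectorwise) block scaling: one scale per 1x128 block
quantizer = Float8BlockQuantizer(
    fp8_dtype=torch.float8_e4m3fn,
    block_dims=(1, 128),
)
# Request both rowwise and columnwise quantized data
quantizer.set_usage(rowwise=True, columnwise=True)

# Example high-precision input (shape and device chosen for illustration)
input_tensor = torch.randn(256, 1024, dtype=torch.bfloat16, device="cuda")
fp8_blockwise = quantizer.quantize(input_tensor)

# Dequantize back to high precision
output = fp8_blockwise.dequantize(dtype=torch.bfloat16)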