Implementation:NVIDIA TransformerEngine Float8Tensor
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Quantization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Implements the primary FP8 tensor class with per-tensor scaling, supporting both delayed scaling (using amax history) and current scaling strategies for FP8 training.
Description
Float8Quantizer quantizes with a pre-computed scale and records amax so that delayed scaling updates, which are managed externally by FP8GlobalStateManager, can derive later scales. Float8CurrentScalingQuantizer instead computes the scale on the fly from the current tensor data. Float8Tensor is a PyTorch tensor subclass that stores FP8 data as uint8, maintains a scale inverse, optionally caches a transpose for columnwise GEMM, and overrides __torch_dispatch__ to handle FSDP2 operations (slice, copy, view, etc.) while preserving the FP8 subclass. It is the foundational FP8 tensor type in TransformerEngine, and its FSDP2 dispatch handling is critical for distributed training compatibility.
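The difference between the two scaling strategies can be illustrated with a small pure-Python sketch. It is not TransformerEngine code: `DelayedScalingState`, `current_scale`, the history length, the `margin` parameter, and the choice of taking the maximum over the amax history are all assumptions made for illustration. The only fixed facts used are the largest finite FP8 values (448 for E4M3, 57344 for E5M2).

```python
from collections import deque

# Largest finite magnitude of each FP8 format.
FP8_MAX = {"e4m3": 448.0, "e5m2": 57344.0}


class DelayedScalingState:
    """Hypothetical helper: keeps an amax history and derives the next scale."""

    def __init__(self, fmt="e4m3", history_len=4, margin=0):
        self.fp8_max = FP8_MAX[fmt]
        self.history = deque(maxlen=history_len)  # oldest entries fall off
        self.margin = margin
        self.scale = 1.0  # scale used until the first update

    def record(self, values):
        """Record the amax of a tensor that was quantized with self.scale."""
        self.history.append(max(abs(v) for v in values))

    def update_scale(self):
        """Derive the scale for later steps from previously observed amax values."""
        amax = max(self.history, default=0.0)
        if amax > 0.0:
            self.scale = self.fp8_max / amax / (2.0 ** self.margin)
        return self.scale


def current_scale(values, fmt="e4m3"):
    """Current scaling: the scale comes from the tensor being quantized now."""
    amax = max(abs(v) for v in values)
    return FP8_MAX[fmt] / amax if amax > 0.0 else 1.0


if __name__ == "__main__":
    state = DelayedScalingState(history_len=2)
    for step_values in ([0.5, -2.0], [7.0], [1.0]):
        state.record(step_values)
        print("delayed:", state.update_scale(), "current:", current_scale(step_values))
```

In this sketch the delayed scale lags behind the data: a step whose values are larger than anything in the history is quantized with a scale that is too large and saturates, while current scaling pays for an extra amax pass over the tensor before every quantization.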
Usage
Float8Tensor is the primary FP8 tensor type and is used throughout the framework for quantized weights, activations, and gradients. Per-tensor delayed scaling was the original FP8 training approach and remains widely used.
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/tensor/float8_tensor.py
- Lines: 1-1098
Signature
    class Float8Quantizer(Quantizer):
        def __init__(self, scale, amax, fp8_dtype, ...): ...
        def quantize(self, tensor, ...): ...

    class Float8CurrentScalingQuantizer(Quantizer):
        def __init__(self, fp8_dtype, ...): ...
        def quantize(self, tensor, ...): ...

    class Float8Tensor(Float8TensorStorage, QuantizedTensor):
        def __init__(self, *, data, fp8_dtype, fp8_scale_inv, dtype, ...): ...
        def dequantize(self, dtype=None) -> torch.Tensor: ...
        @classmethod
        def __torch_dispatch__(cls, func, types, args, kwargs): ...

    class _ViewFunc(torch.autograd.Function): ...
    class _ReshapeFunc(torch.autograd.Function): ...
Import
    from transformer_engine.pytorch.tensor.float8_tensor import (
        Float8Quantizer,
        Float8CurrentScalingQuantizer,
        Float8Tensor,
    )
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| tensor | torch.Tensor | Yes | High-precision tensor to quantize to FP8 |
| scale | torch.Tensor | No | Pre-computed scale factor (delayed scaling) |
| amax | torch.Tensor | No | Absolute maximum value tracking tensor |
| fp8_dtype | transformer_engine_torch.DType | Yes | Target FP8 dtype from the TransformerEngine DType enum (e.g., kFloat8E4M3, kFloat8E5M2) |
Outputs
| Name | Type | Description |
|---|---|---|
| fp8_tensor | Float8Tensor | Per-tensor scaled FP8 tensor with scale_inv |
Usage Examples
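    import torch
    from transformer_engine.pytorch.tensor.float8_tensor import (
        Float8Quantizer,
        Float8CurrentScalingQuantizer,
    )
    import transformer_engine_torch as tex  # provides the TE DType enum

    # Illustrative inputs (shapes and values are placeholders); FP8 kernels need a CUDA device
    input_tensor = torch.randn(128, 128, dtype=torch.bfloat16, device="cuda")
    scale_tensor = torch.ones(1, dtype=torch.float32, device="cuda")
    amax_tensor = torch.zeros(1, dtype=torch.float32, device="cuda")

    # Delayed scaling quantizer: applies the given scale and records amax
    quantizer = Float8Quantizer(
        scale=scale_tensor,
        amax=amax_tensor,
        fp8_dtype=tex.DType.kFloat8E4M3,
    )
    fp8_tensor = quantizer.quantize(input_tensor)

    # Current scaling quantizer (computes scale from data).
    # Constructor arguments may differ between TransformerEngine versions.
    cs_quantizer = Float8CurrentScalingQuantizer(
        fp8_dtype=tex.DType.kFloat8E4M3,
        device=input_tensor.device,
    )
    fp8_tensor_cs = cs_quantizer.quantize(input_tensor)

    # Dequantize back to a high-precision torch.Tensor
    output = fp8_tensor.dequantize(dtype=torch.bfloat16)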