

Implementation:NVIDIA TransformerEngine Float8Tensor

From Leeroopedia
Revision as of 15:57, 16 February 2026 by Admin (auto-imported from implementations/NVIDIA_TransformerEngine_Float8Tensor.md)


Field | Value
Sources | TransformerEngine
Domains | Deep_Learning, PyTorch, Quantization
Last Updated | 2026-02-07 14:00 GMT

Overview

Implements the primary FP8 tensor class with per-tensor scaling, supporting both delayed scaling (using amax history) and current scaling strategies for FP8 training.
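Per-tensor scaling uses one scale factor for the whole tensor: values are multiplied by scale before the cast to FP8, and the stored scale_inv (1 / scale) undoes that multiplication on dequantization. The following is a minimal pure-Python sketch of the round trip for the E4M3 format, not TransformerEngine code. It assumes a scale of 448 / amax with no margin, emulates E4M3 rounding in software, and saturates out-of-range values to ±448; the function names are hypothetical.

```python
import math

E4M3_MAX = 448.0  # largest finite FP8 E4M3 value


def round_to_e4m3(x: float) -> float:
    """Round to the nearest E4M3 value (ties to even), saturating at +/-448.

    Software emulation: 3 mantissa bits, exponent bias 7, subnormal step 2**-9.
    """
    if x == 0.0 or math.isnan(x):
        return x
    sign = -1.0 if x < 0 else 1.0
    mag = min(abs(x), E4M3_MAX)         # saturate instead of overflowing
    _, exp = math.frexp(mag)            # mag = m * 2**exp with 0.5 <= m < 1
    step = 2.0 ** max(exp - 4, -9)      # spacing of E4M3 values around mag
    return sign * round(mag / step) * step  # Python round() ties to even


def quantize_per_tensor(values, amax):
    """One scale for the whole tensor maps amax onto E4M3_MAX."""
    scale = E4M3_MAX / amax
    data = [round_to_e4m3(v * scale) for v in values]
    return data, 1.0 / scale            # FP8 payload and its scale_inv


def dequantize(data, scale_inv):
    """Undo the scaling with the stored inverse scale."""
    return [q * scale_inv for q in data]


values = [0.5, -3.0, 12.0, 0.001]
data, scale_inv = quantize_per_tensor(values, amax=max(abs(v) for v in values))
print(dequantize(data, scale_inv))  # close to, but not exactly, the inputs
```

With a single scale per tensor, one large outlier raises amax and leaves less resolution for the small values in the same tensor.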

Description

Float8Quantizer uses a pre-computed scale and records amax for delayed-scaling updates, which are managed externally by FP8GlobalStateManager. Float8CurrentScalingQuantizer instead computes the scale on the fly from the data being quantized.

Float8Tensor is a PyTorch tensor subclass that stores FP8 data as uint8, maintains the inverse scale (scale_inv) used for dequantization, and optionally caches a transpose for column-wise GEMM. It overrides __torch_dispatch__ so that the operations FSDP2 issues (slice, copy, view, etc.) preserve the FP8 subclass. It is the foundational FP8 tensor type in TransformerEngine, and its FSDP2 dispatch handling is critical for compatibility with distributed training.
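The two quantizers differ in where the scale comes from. In the plain-Python sketch below (all names hypothetical), DelayedScalingState hands out a scale computed from amax values recorded at earlier steps, while current_scaling_scale derives it from the tensor being quantized. The formula (FP8_MAX / amax) / 2**margin and the max-over-history reduction mirror the margin, amax_history_len and amax_compute_algo options of TransformerEngine's DelayedScaling recipe; in TransformerEngine the state lives in GPU tensors and is updated through FP8GlobalStateManager, which this sketch does not model. Keeping the previous scale when amax is not positive is this sketch's own choice.

```python
from collections import deque
from typing import Deque, Iterable

FP8_MAX = 448.0  # E4M3 maximum finite value


def compute_scale(amax: float, prev_scale: float, margin: int = 0) -> float:
    """scale = (FP8_MAX / amax) / 2**margin; keep prev_scale if amax is not positive."""
    if not amax > 0.0:  # covers 0 and NaN
        return prev_scale
    return (FP8_MAX / amax) / (2.0 ** margin)


class DelayedScalingState:
    """Hypothetical per-tensor state: the scale for step t uses amaxes from steps < t."""

    def __init__(self, history_len: int = 16, margin: int = 0) -> None:
        self.amax_history: Deque[float] = deque(maxlen=history_len)
        self.margin = margin
        self.scale = 1.0  # used until the first update

    def step(self, values: Iterable[float]) -> float:
        """Return the pre-computed scale for this cast and record this tensor's amax."""
        scale_used = self.scale  # not derived from `values`
        self.amax_history.append(max(abs(v) for v in values))
        return scale_used

    def update_scale(self) -> None:
        # Reduce the window with max (the analogue of amax_compute_algo="max")
        self.scale = compute_scale(max(self.amax_history), self.scale, self.margin)


def current_scaling_scale(values: Iterable[float], margin: int = 0) -> float:
    """Current scaling: the scale comes from the tensor being quantized."""
    return compute_scale(max(abs(v) for v in values), 1.0, margin)


state = DelayedScalingState(history_len=4)
s1 = state.step([1.0, -2.0])  # 1.0: nothing observed yet
state.update_scale()          # history [2.0] -> scale 224.0
s2 = state.step([8.0])        # 224.0, although this tensor's amax is 8.0
print(s1, s2, current_scaling_scale([8.0]))  # prints: 1.0 224.0 56.0
```

The delayed scale lags the data: s2 was sized for an amax of 2.0, so 8.0 * 224 would saturate. Current scaling removes that lag at the cost of an amax reduction over the tensor before each cast.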

Usage

Float8Tensor is the primary FP8 tensor type. Per-tensor delayed scaling was the original FP8 training approach and remains widely used. The class is used throughout the framework for quantized weights, activations, and gradients.

Code Reference

Source Location

Repository: NVIDIA/TransformerEngine
File: transformer_engine/pytorch/tensor/float8_tensor.py
Lines: 1–1098

Signature

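class Float8Quantizer(Quantizer):
    # Delayed scaling: uses a pre-computed scale and records amax for later updates
    def __init__(self, scale, amax, fp8_dtype, ...): ...
    def quantize(self, tensor, ...): ...

class Float8CurrentScalingQuantizer(Quantizer):
    # Current scaling: derives the scale from the tensor being quantized
    def __init__(self, fp8_dtype, ...): ...
    def quantize(self, tensor, ...): ...

class Float8Tensor(Float8TensorStorage, QuantizedTensor):
    # FP8 payload stored as uint8 plus its inverse scale; keyword-only constructor
    def __init__(self, *, data, fp8_dtype, fp8_scale_inv, dtype, ...): ...
    def dequantize(self, dtype=None) -> torch.Tensor: ...
    # Intercepts aten ops (e.g. the slice/copy/view calls FSDP2 makes) to keep the FP8 subclass
    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs): ...

# Autograd Functions for view and reshape of Float8Tensor
class _ViewFunc(torch.autograd.Function): ...
class _ReshapeFunc(torch.autograd.Function): ...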

Import

from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8Quantizer,
    Float8CurrentScalingQuantizer,
    Float8Tensor,
)

I/O Contract

Inputs

Name | Type | Required | Description
tensor | torch.Tensor | Yes | High-precision tensor to quantize to FP8
scale | torch.Tensor | Delayed scaling only | Pre-computed scale factor passed to Float8Quantizer
amax | torch.Tensor | Delayed scaling only | Buffer that records the absolute maximum (amax) for later scale updates
fp8_dtype | transformer_engine_torch.DType | Yes | Target FP8 format (e.g., DType.kFloat8E4M3, DType.kFloat8E5M2)

Outputs

Name | Type | Description
fp8_tensor | Float8Tensor | Per-tensor scaled FP8 tensor carrying its scale_inv
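The Float8Tensor output bundles FP8 data stored as uint8 codes with a scale_inv, and it may also hold a cached transpose. The pure-Python sketch below models that layout under explicit assumptions: only the E4M3 format is handled (exponent bias 7, 3 mantissa bits, no infinities, 0x7F/0xFF as NaN), the helper functions and the Float8Sketch class are hypothetical, and Python lists stand in for the torch tensors the real class holds.

```python
import math
from dataclasses import dataclass, field
from typing import List, Optional


def e4m3_to_bits(v: float) -> int:
    """Encode an exactly representable, finite E4M3 value as its uint8 code."""
    sign = 0x80 if math.copysign(1.0, v) < 0 else 0
    mag = abs(v)
    if mag == 0.0:
        return sign
    _, e = math.frexp(mag)
    exp = e - 1                              # mag = 1.f * 2**exp when normal
    if exp >= -6:                            # normal: biased exponent 1..15
        mant = int(mag / 2.0 ** exp * 8) - 8
        return sign | ((exp + 7) << 3) | mant
    return sign | int(mag / 2.0 ** -9)       # subnormal: exponent field 0


def bits_to_e4m3(b: int) -> float:
    """Decode a uint8 code; E4M3 has no infinities and uses 0x7F/0xFF for NaN."""
    if (b & 0x7F) == 0x7F:
        return math.nan
    sign = -1.0 if b & 0x80 else 1.0
    exp_field, mant = (b >> 3) & 0xF, b & 0x7
    if exp_field == 0:
        return sign * mant * 2.0 ** -9
    return sign * (1 + mant / 8) * 2.0 ** (exp_field - 7)


@dataclass
class Float8Sketch:
    """Hypothetical toy, not TransformerEngine's class: uint8 codes plus scale_inv."""
    data: List[List[int]]                    # row-major uint8 codes
    scale_inv: float
    _transpose: Optional[List[List[int]]] = field(default=None, repr=False)

    def transpose_data(self) -> List[List[int]]:
        # Built on first use and cached, e.g. for a column-wise GEMM operand
        if self._transpose is None:
            self._transpose = [list(col) for col in zip(*self.data)]
        return self._transpose

    def dequantize(self) -> List[List[float]]:
        return [[bits_to_e4m3(b) * self.scale_inv for b in row] for row in self.data]


t = Float8Sketch(data=[[0x38, 0x40], [0xC0, 0x00]], scale_inv=0.5)
print(t.dequantize())      # [[0.5, 1.0], [-1.0, 0.0]]
print(t.transpose_data())  # [[56, 192], [64, 0]]
```

Caching the transpose trades extra memory for not re-transposing the FP8 data each time a column-wise operand is needed.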

Usage Examples

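import torch
import transformer_engine_torch as tex  # TE extension module providing the DType enum
from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8Quantizer,
    Float8CurrentScalingQuantizer,
)

# FP8 casts run on a CUDA GPU; keep inputs and quantizer buffers on the same device
input_tensor = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")

# Delayed scaling quantizer: scale comes from earlier amax history,
# amax is a buffer that receives the observed absolute maximum
scale_tensor = torch.ones(1, dtype=torch.float32, device="cuda")
amax_tensor = torch.zeros(1, dtype=torch.float32, device="cuda")
quantizer = Float8Quantizer(
    scale=scale_tensor,
    amax=amax_tensor,
    fp8_dtype=tex.DType.kFloat8E4M3,
)
fp8_tensor = quantizer.quantize(input_tensor)

# Current scaling quantizer (computes scale from this tensor's data)
cs_quantizer = Float8CurrentScalingQuantizer(
    fp8_dtype=tex.DType.kFloat8E4M3,
    device=input_tensor.device,
)
fp8_tensor_cs = cs_quantizer.quantize(input_tensor)

# Dequantize back to a high-precision tensor
output = fp8_tensor.dequantize(dtype=torch.bfloat16)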
