Implementation:NVIDIA TransformerEngine QuantizedTensor
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Quantization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Defines the pure-Python abstract base classes for the quantized tensor type system: QuantizedTensorStorage, Quantizer, and QuantizedTensor with autograd support.
Description
QuantizedTensorStorage is the base for all tensor storage classes, providing the interface for rowwise/columnwise usage management, save/restore for autograd, and in-place quantization. Quantizer is an abstract base class defining the quantize(), update_quantized(), and set_usage() methods. QuantizedTensor inherits from both QuantizedTensorStorage and torch.Tensor and intercepts operations through __torch_dispatch__, so it behaves as a regular PyTorch tensor while holding quantized data internally. The module-level functions prepare_for_saving and restore_from_saved handle serialization of mixed tensor/storage objects through autograd's save_for_backward.
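The split between a quantizer (which builds or refreshes quantized data) and a storage object (which holds it) can be illustrated with a toy, dependency-free sketch. Everything here is hypothetical: `ToyStorage`, `ToyQuantizer`, and `ToyInt8Quantizer` are stand-ins that operate on Python lists, not TransformerEngine classes, and the per-tensor absmax int8 scheme is chosen only for brevity.

```python
import abc
from dataclasses import dataclass
from typing import List


@dataclass
class ToyStorage:
    """Hypothetical stand-in for a quantized storage object."""
    data: List[int]   # quantized integer codes
    scale: float      # per-tensor scale used to dequantize

    def dequantize(self) -> List[float]:
        return [q * self.scale for q in self.data]


class ToyQuantizer(abc.ABC):
    """Hypothetical abstract builder, loosely mirroring the Quantizer role."""

    def __init__(self) -> None:
        self.rowwise_usage = True
        self.columnwise_usage = False

    def set_usage(self, rowwise: bool = False, columnwise: bool = False) -> None:
        # Record which layouts downstream consumers will need.
        self.rowwise_usage = rowwise
        self.columnwise_usage = columnwise

    @abc.abstractmethod
    def quantize(self, values: List[float]) -> ToyStorage:
        """Create a new storage object from high-precision values."""

    def update_quantized(self, values: List[float], dst: ToyStorage) -> ToyStorage:
        # In-place variant: overwrite the contents of an existing storage object.
        fresh = self.quantize(values)
        dst.data, dst.scale = fresh.data, fresh.scale
        return dst


class ToyInt8Quantizer(ToyQuantizer):
    """Concrete toy subclass: symmetric per-tensor absmax scaling to int8 range."""

    def quantize(self, values: List[float]) -> ToyStorage:
        amax = max((abs(v) for v in values), default=0.0)
        scale = amax / 127.0 if amax > 0 else 1.0
        return ToyStorage([round(v / scale) for v in values], scale)


# Usage: quantize, then reconstruct an approximation of the input.
quantizer = ToyInt8Quantizer()
quantizer.set_usage(rowwise=True, columnwise=False)
values = [0.5, -1.0, 0.25]
storage = quantizer.quantize(values)
approx = storage.dequantize()
```

Only `quantize()` is abstract in this sketch, matching the signature listing on this page; the base class cannot be instantiated directly, and each concrete subclass supplies its own quantization scheme.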
Usage
These classes are the foundation of TransformerEngine's (TE's) quantized tensor type system. All concrete quantized tensor types (Float8Tensor, MXFP8Tensor, Float8BlockwiseQTensor, NVFP4Tensor) inherit from them.
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/quantized_tensor.py
- Lines: 1-600
Signature
class QuantizedTensorStorage:
    def get_metadata(self) -> dict: ...
    def prepare_for_saving(self) -> list: ...
    def restore_from_saved(self, tensors) -> None: ...
    def clear(self) -> None: ...

class Quantizer(abc.ABC):
    @abc.abstractmethod
    def quantize(self, tensor, ...): ...
    def update_quantized(self, tensor, ...): ...
    def set_usage(self, rowwise=False, columnwise=False): ...

class QuantizedTensor(QuantizedTensorStorage, torch.Tensor):
    def dequantize(self, dtype=None) -> torch.Tensor: ...
    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs): ...

def prepare_for_saving(*tensors): ...
def restore_from_saved(*saved): ...
Import
from transformer_engine.pytorch.quantized_tensor import (
    QuantizedTensorStorage,
    Quantizer,
    QuantizedTensor,
    prepare_for_saving,
    restore_from_saved,
)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| tensor | torch.Tensor | Yes | High-precision tensor to quantize |
| dtype | torch.dtype | No | Target dtype for dequantization |
Outputs
| Name | Type | Description |
|---|---|---|
| quantized_tensor | QuantizedTensor | Tensor holding quantized data with dequantize() method |
| dequantized | torch.Tensor | High-precision tensor reconstructed from quantized data |
Usage Examples
import torch
from transformer_engine.pytorch.quantized_tensor import (
    QuantizedTensor,
    prepare_for_saving,
    restore_from_saved,
)

# Assumes `quantizer` is an instance of a concrete Quantizer subclass, and that
# `input_tensor` and `regular_tensor` are ordinary torch.Tensor objects defined elsewhere.

# Quantizing yields a QuantizedTensor, which behaves like a regular PyTorch tensor
fp8_tensor = quantizer.quantize(input_tensor)
assert isinstance(fp8_tensor, QuantizedTensor)

# Dequantize back to high precision
output = fp8_tensor.dequantize(dtype=torch.float32)

# Works with autograd save_for_backward: flatten mixed quantized/regular tensors, then rebuild
saved = prepare_for_saving(fp8_tensor, regular_tensor)
restored = restore_from_saved(*saved)
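The idea behind the save/restore helpers is that a composite storage object must be split into its raw buffers before being handed to autograd's save_for_backward, and reassembled afterwards. The following is a toy, dependency-free sketch of that flatten/rebuild pattern. All names (`ToyStorage`, `toy_prepare_for_saving`, `toy_restore_from_saved`) are hypothetical, Python lists stand in for tensors, and the sketch makes no claim about the return structure or argument order of the real functions.

```python
from typing import Any, List, Optional, Tuple


class ToyStorage:
    """Hypothetical composite object holding two raw buffers."""

    def __init__(self, data: Optional[list], scale_inv: Optional[list]) -> None:
        self.data = data
        self.scale_inv = scale_inv

    def prepare_for_saving(self) -> Tuple[List[Optional[list]], "ToyStorage"]:
        # Hand the raw buffers to the caller and drop our own references,
        # leaving an empty shell that remembers only the structure.
        buffers = [self.data, self.scale_inv]
        self.data = None
        self.scale_inv = None
        return buffers, self

    def restore_from_saved(self, buffers: List[Any]) -> List[Any]:
        # Reclaim our two buffers and return whatever is left over.
        self.data, self.scale_inv = buffers[0], buffers[1]
        return buffers[2:]


def toy_prepare_for_saving(*objs: Any) -> Tuple[List[Any], List[Optional[ToyStorage]]]:
    """Flatten a mix of plain buffers and ToyStorage objects."""
    flat: List[Any] = []
    shells: List[Optional[ToyStorage]] = []
    for obj in objs:
        if isinstance(obj, ToyStorage):
            buffers, shell = obj.prepare_for_saving()
            flat.extend(buffers)
            shells.append(shell)
        else:
            # Plain buffers (or None) are saved as-is; None marks "no shell".
            flat.append(obj)
            shells.append(None)
    return flat, shells


def toy_restore_from_saved(shells: List[Optional[ToyStorage]], flat: List[Any]) -> List[Any]:
    """Rebuild the original objects, consuming the flat list in order."""
    remaining = list(flat)
    restored: List[Any] = []
    for shell in shells:
        if shell is None:
            restored.append(remaining[0])
            remaining = remaining[1:]
        else:
            remaining = shell.restore_from_saved(remaining)
            restored.append(shell)
    return restored


# Usage: one composite object and one plain buffer make a round trip.
storage = ToyStorage(data=[1, 2, 3], scale_inv=[0.5])
plain = [9.0, 8.0]
flat, shells = toy_prepare_for_saving(storage, plain)
restored = toy_restore_from_saved(shells, flat)
```

Keeping the emptied shell separate from the flat buffer list means the buffers alone go through the save mechanism, while the shell carries the metadata needed to reattach them in the right order.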