Implementation:NVIDIA TransformerEngine PyTorch Ext Cast
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Quantization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Implements quantize, dequantize, multi-tensor quantize, and stochastic quantization operations that convert between high-precision and low-precision (FP8/MXFP8/NVFP4) tensor formats.
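The round trip from high precision to FP8 and back can be emulated in pure Python. This is a sketch, not TE code. It assumes the FP8 E4M3 format (maximum magnitude 448, 3 mantissa bits), round-to-nearest-even, and a per-tensor current-scaling recipe with scale = 448 / amax. The helpers round_to_e4m3, quantize_current_scaling and dequantize are hypothetical names. MXFP8 and NVFP4 use block scales and are not modeled here.

```python
import math

E4M3_MAX = 448.0                 # largest finite FP8 E4M3 magnitude
E4M3_MIN_NORMAL = 2.0 ** -6      # smallest normal E4M3 magnitude
E4M3_MANTISSA_BITS = 3


def round_to_e4m3(x: float) -> float:
    """Round x to the nearest E4M3 value (ties to even), saturating at +/-448."""
    if x == 0.0 or math.isnan(x):
        return x
    if math.isinf(x):
        return math.copysign(E4M3_MAX, x)
    mag = abs(x)
    if mag >= E4M3_MIN_NORMAL:
        _, e = math.frexp(mag)   # mag = m * 2**e with 0.5 <= m < 1
        step = 2.0 ** (e - 1 - E4M3_MANTISSA_BITS)
    else:
        step = 2.0 ** -9         # subnormal spacing
    q = round(mag / step) * step # Python's round() breaks ties to even
    return math.copysign(min(q, E4M3_MAX), x)


def quantize_current_scaling(values):
    """Per-tensor current scaling: map the tensor's amax onto E4M3_MAX."""
    amax = max(abs(v) for v in values)
    scale = E4M3_MAX / amax if amax > 0 else 1.0
    data = [round_to_e4m3(v * scale) for v in values]
    return data, 1.0 / scale     # FP8 payload plus the inverse scale


def dequantize(data, scale_inv):
    """Reconstruct high-precision values from the FP8 payload."""
    return [d * scale_inv for d in data]


data, scale_inv = quantize_current_scaling([1.75, -0.875, 0.3])
restored = dequantize(data, scale_inv)
# data == [448.0, -224.0, 80.0], restored == [1.75, -0.875, 0.3125]
```

The inverse scale is kept alongside the payload because dequantization only needs a multiply. Values that fall between E4M3 grid points, such as 0.3 here, come back slightly perturbed.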
Description
- quantize converts a PyTorch tensor using the provided quantizer, handling an existing amax for current scaling, reuse of a pre-allocated output tensor, and no-op flags.
- dequantize converts quantized tensors back to high precision via nvte_dequantize.
- multi_tensor_quantize_impl batches several quantizations into one fused kernel (nvte_multi_quantize) when FP8 delayed scaling is used, and falls back to quantizing each tensor individually otherwise.
- stochastic_quantize performs FP8 quantization with stochastic rounding and manages the RNG state it needs.
- quantize_with_amax_update computes amax and quantizes in a single step.
- StochasticRngStateResources manages CUDA generator states so that stochastic rounding is graph-safe.
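Stochastic rounding, which stochastic_quantize applies, can be illustrated independently of TE's kernels. The sketch below rounds onto the FP8 E4M3 grid and picks the upper neighbour with probability equal to the value's fractional distance from the lower one. This makes the expected result equal to the input for finite |x| <= 448. It assumes the E4M3 format. Python's random.Random stands in for the CUDA generator state that StochasticRngStateResources manages. The helpers e4m3_step and stochastic_round_e4m3 are hypothetical.

```python
import math
import random

E4M3_MAX = 448.0


def e4m3_step(mag: float) -> float:
    """Spacing of E4M3 values at magnitude mag (3 mantissa bits)."""
    if mag < 2.0 ** -6:
        return 2.0 ** -9             # subnormal range
    _, e = math.frexp(mag)           # mag = m * 2**e with 0.5 <= m < 1
    return 2.0 ** (e - 4)


def stochastic_round_e4m3(x: float, rng: random.Random) -> float:
    """Round finite x to one of its two E4M3 neighbours at random.

    The upper neighbour is chosen with probability equal to x's fractional
    distance from the lower one, so for |x| <= 448 the expected result is x.
    """
    mag = min(abs(x), E4M3_MAX)      # saturate out-of-range magnitudes
    step = e4m3_step(mag)
    lower = math.floor(mag / step) * step
    frac = (mag - lower) / step      # in [0, 1)
    q = lower + step if rng.random() < frac else lower
    return math.copysign(q, x)


# A seeded generator plays the role of saved RNG state: the same seed
# reproduces the same rounding decisions.
rng = random.Random(1234)
samples = [stochastic_round_e4m3(1.03, rng) for _ in range(20000)]
mean = sum(samples) / len(samples)   # close to 1.03; round-to-nearest would always give 1.0
```

Round-to-nearest would bias every 1.03 toward 1.0. Stochastic rounding trades per-element noise for an unbiased average, which is why reproducible RNG state matters when the same computation is replayed, for example under graph capture.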
Usage
Central quantization implementation: the PyTorch extension routes all FP8 cast operations through this file, which makes it essential to TE's mixed-precision training pipeline.
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/csrc/extensions/cast.cpp
- Lines: 1-1249
Signature
```cpp
namespace transformer_engine::pytorch {

py::object quantize(at::Tensor input, py::handle quantizer,
                    py::handle output_tensor = py::none(),
                    bool noop = false);

at::Tensor dequantize(at::Tensor input, py::handle tensor_obj,
                      at::ScalarType otype);

py::list multi_tensor_quantize_impl(
    std::vector<at::Tensor> input_list,
    std::vector<py::handle> quantizer_list, ...);  // remaining parameters elided

py::object stochastic_quantize(at::Tensor input, py::handle quantizer, ...);

py::object quantize_with_amax_update(at::Tensor input, py::handle quantizer);

}  // namespace transformer_engine::pytorch
```
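The dispatch rule of multi_tensor_quantize_impl (one fused launch for FP8 delayed scaling, per-tensor quantization otherwise) can be sketched in Python. Everything here is a hypothetical stand-in: the quantizer classes, multi_quantize, and the launches list that only records which path ran. A plain scale multiplication replaces real FP8 casting. One plausible reason the fused path is limited to delayed scaling is that those scales are already known before the kernel runs, whereas current scaling must first reduce each tensor to its amax. Treat that as an interpretation, not a statement about TE internals.

```python
from dataclasses import dataclass
from typing import List, Sequence, Union

E4M3_MAX = 448.0


@dataclass
class DelayedScalingQuantizer:      # hypothetical: scale precomputed from amax history
    scale: float


@dataclass
class CurrentScalingQuantizer:      # hypothetical: scale derived from the tensor itself
    pass


Quantizer = Union[DelayedScalingQuantizer, CurrentScalingQuantizer]


def quantize_one(x: Sequence[float], q: Quantizer) -> List[float]:
    """Per-tensor fallback path (FP8 rounding omitted for brevity)."""
    if isinstance(q, DelayedScalingQuantizer):
        scale = q.scale
    else:
        amax = max(abs(v) for v in x)
        scale = E4M3_MAX / amax if amax > 0 else 1.0
    return [v * scale for v in x]


def multi_quantize(inputs: Sequence[Sequence[float]],
                   quantizers: Sequence[Quantizer],
                   launches: List[str]) -> List[List[float]]:
    """Batch the casts into one launch when every quantizer uses delayed scaling."""
    if len(inputs) != len(quantizers):
        raise ValueError("inputs and quantizers must have the same length")
    if all(isinstance(q, DelayedScalingQuantizer) for q in quantizers):
        launches.append("fused")    # all scales known up front: one batched pass
        return [[v * q.scale for v in x] for x, q in zip(inputs, quantizers)]
    outputs = []
    for x, q in zip(inputs, quantizers):
        launches.append("single")   # fallback: one launch per tensor
        outputs.append(quantize_one(x, q))
    return outputs
```

A single non-delayed quantizer in the list is enough to send the whole batch down the per-tensor path in this sketch.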
Import
```cpp
#include "../extensions.h"
#include "common.h"
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input | at::Tensor | Yes | High-precision input tensor |
| quantizer | py::handle | Yes | Quantizer defining the target format |
| output_tensor | py::handle | No | Pre-allocated output tensor to reuse |
| noop | bool | No | If true, skip quantization (passthrough) |
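The interplay of output_tensor and noop in the table above can be modeled in a few lines. This hypothetical Python sketch assumes two things. First, a supplied output buffer is overwritten in place rather than reallocated. Second, noop leaves an already-populated output untouched, for instance when a cached FP8 copy of a weight should not be re-cast. The function quantize_into and its parameters are illustrative names, not TE's API, and FP8 rounding is omitted.

```python
from typing import List, Optional


def quantize_into(values: List[float], scale: float,
                  output: Optional[List[float]] = None,
                  noop: bool = False) -> List[float]:
    """Hypothetical mirror of quantize(input, quantizer, output_tensor, noop).

    Only the buffer-reuse and no-op handling is modeled; FP8 rounding is omitted.
    """
    if output is not None and len(output) != len(values):
        raise ValueError("output buffer must match the input length")
    if noop:
        if output is None:
            # Our own choice: with nothing to pass through, refuse instead of guessing.
            raise ValueError("noop requires an existing output buffer")
        return output                 # previous contents are left untouched
    if output is None:
        output = [0.0] * len(values)  # no buffer supplied: allocate one
    for i, v in enumerate(values):
        output[i] = v * scale         # write in place, reusing the caller's storage
    return output
```

Returning the same buffer object in both the write and no-op cases lets callers hold a single reference across steps regardless of whether a cast actually happened.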
Outputs
| Name | Type | Description |
|---|---|---|
| quantized | py::object | Quantized tensor in the target low-precision format |
| dequantized | at::Tensor | High-precision tensor reconstructed from quantized data |
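Dequantized data is an approximation of the original because the FP8 payload keeps only a few mantissa bits. Consider FP8 E4M3 (3 mantissa bits) with per-tensor current scaling s = 448 / amax and round-to-nearest, and ignore error in the high-precision arithmetic. Under these assumptions, every element whose scaled value lands in the E4M3 normal range obeys a relative error bound:

$$
\hat{x} = \frac{Q_{\mathrm{E4M3}}(s\,x)}{s}, \qquad s = \frac{448}{\operatorname{amax}(x)}, \qquad
\left|\hat{x} - x\right| \le 2^{-4}\,|x| \quad \text{for } 2^{-6} \le s\,|x| \le 448.
$$

The bound holds because E4M3 values in [2^e, 2^(e+1)) are spaced 2^(e-3) apart, so rounding moves a value by at most 2^(e-4), which is at most 2^-4 of its magnitude. Dividing by s scales the error and the value equally. For example, with amax = 1.75 (so s = 256), x = 0.3 scales to 76.8, rounds to 80, and dequantizes to 0.3125. That is a relative error of about 4.2%, within the 6.25% bound.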
Usage Examples
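```python
import torch
import transformer_engine_torch as tex

# input_tensor, fp8_quantizer, tensor_obj, input_list and quantizer_list
# are assumed to have been created elsewhere.

# Quantize a tensor to FP8
fp8_tensor = tex.quantize(input_tensor, fp8_quantizer)
# Dequantize back to high precision
output = tex.dequantize(fp8_tensor, tensor_obj, torch.float32)
# Batch quantize multiple tensors (Python binding of multi_tensor_quantize_impl)
quantized_list = tex.multi_tensor_quantize(input_list, quantizer_list)
```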