

Implementation:NVIDIA TransformerEngine PyTorch Ext Cast

From Leeroopedia


Sources: TransformerEngine
Domains: Deep_Learning, PyTorch, Quantization
Last Updated: 2026-02-07 14:00 GMT

Overview

Implements quantize, dequantize, multi-tensor quantize, and stochastic quantization operations that convert between high-precision and low-precision (FP8/MXFP8/NVFP4) tensor formats.
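
As a rough numerical illustration of this conversion, here is a pure-Python sketch of per-tensor current scaling into FP8 E4M3 and back. It assumes the E4M3 variant whose largest finite value is 448, saturates out-of-range values, and rounds to nearest even. The scale maps the tensor's absolute maximum (amax) onto 448, and the stored inverse scale undoes it on dequantization. The function names are hypothetical, and this is not TE's CUDA implementation; E5M2, MXFP8, and NVFP4 are not modeled.

```python
import math

E4M3_MAX = 448.0  # largest finite FP8 E4M3 value (assumed "fn" variant)


def round_to_e4m3(x: float) -> float:
    """Round a finite float to the nearest E4M3 value (ties to even),
    saturating at +/-E4M3_MAX."""
    if x == 0.0:
        return x
    mag = abs(x)
    _, exp = math.frexp(mag)        # mag = m * 2**exp with 0.5 <= m < 1
    e = max(exp - 1, -6)            # binade exponent; below 2**-6 the grid is subnormal
    step = math.ldexp(1.0, e - 3)   # 3 mantissa bits -> spacing 2**(e - 3)
    q = round(mag / step) * step    # Python's round() ties to even
    return math.copysign(min(q, E4M3_MAX), x)


def quantize_per_tensor(values, amax=None):
    """Current scaling: map the tensor's amax onto E4M3_MAX, then round.
    A caller may pass a precomputed amax instead of scanning the data."""
    if amax is None:
        amax = max(abs(v) for v in values)
    scale = E4M3_MAX / amax if amax > 0.0 else 1.0
    data = [round_to_e4m3(v * scale) for v in values]
    return data, 1.0 / scale        # keep the inverse scale for dequantization


def dequantize_per_tensor(data, scale_inv):
    """Undo the scaling; the FP8 rounding error remains."""
    return [q * scale_inv for q in data]
```

Keeping `scale_inv` next to the 8-bit data is what lets dequantization recover the original range; with 3 mantissa bits, the relative rounding error of a normal, non-saturated E4M3 value is at most 2^-4.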

Description

quantize converts a PyTorch tensor using the provided quantizer, handling an existing amax for current scaling, reuse of a pre-allocated output tensor, and no-op flags. dequantize converts quantized tensors back to high precision via nvte_dequantize. multi_tensor_quantize_impl batches several tensor quantizations: for FP8 delayed scaling it uses a fused kernel (nvte_multi_quantize), and otherwise it falls back to quantizing each tensor individually. The file also provides stochastic_quantize, which manages RNG state for stochastic rounding during FP8 quantization, and quantize_with_amax_update, which computes amax and quantizes in a single step. StochasticRngStateResources manages CUDA generator states for graph-safe stochastic rounding.
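
Stochastic rounding rounds up or down at random, with the probability of rounding up equal to the fractional distance above the lower grid point, so the rounded value is unbiased in expectation. The sketch below shows that idea on a uniform grid of spacing `step` (a simplification; real FP8 grids are non-uniform). Its hypothetical `RngState` derives each call's random stream from a seed plus an advancing offset, in the spirit of counter-based generators such as Philox. It does not reproduce how stochastic_quantize or StochasticRngStateResources actually manage CUDA generator state.

```python
import math
import random


class RngState:
    """Hypothetical RNG state: a fixed seed plus an offset that advances on
    every call, so each call gets a distinct but reproducible stream."""

    def __init__(self, seed: int):
        self.seed = seed
        self.offset = 0

    def next_stream(self, n: int) -> random.Random:
        rng = random.Random(self.seed * (1 << 32) + self.offset)
        self.offset += n            # reserve n draws for this call
        return rng


def stochastic_round(values, step, state):
    """Round each value to a multiple of `step`, rounding up with probability
    equal to the fractional distance above the lower multiple."""
    rng = state.next_stream(len(values))
    out = []
    for v in values:
        lower = math.floor(v / step)
        frac = v / step - lower     # in [0, 1)
        out.append((lower + (1 if rng.random() < frac else 0)) * step)
    return out
```

Because the stream depends only on `(seed, offset)`, replaying the same sequence of calls reproduces the same rounding decisions, while consecutive calls never reuse random numbers.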

Usage

This file is the central quantization implementation: all FP8 cast operations flow through it, which makes it essential to TE's mixed-precision training pipeline.
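
In a training pipeline, FP8 scales are often derived with delayed scaling, the case the description singles out for the fused multi-tensor kernel: the scale used at one step comes from amax values recorded at earlier steps. The hypothetical `DelayedScaler` below sketches that bookkeeping, assuming E4M3 (max 448), a fixed-length amax history reduced with max, and a power-of-two `margin`. Its `quantize_with_amax_update` method records the amax in the same pass that quantizes, echoing the single-step idea of quantize_with_amax_update, but it is not TE's implementation.

```python
import math
from collections import deque

E4M3_MAX = 448.0


def round_to_e4m3(x: float) -> float:
    """Nearest E4M3 value (ties to even), saturating at +/-E4M3_MAX."""
    if x == 0.0:
        return x
    _, exp = math.frexp(abs(x))
    step = math.ldexp(1.0, max(exp - 1, -6) - 3)
    return math.copysign(min(round(abs(x) / step) * step, E4M3_MAX), x)


class DelayedScaler:
    """Hypothetical per-tensor delayed-scaling state."""

    def __init__(self, history_len: int = 16, margin: int = 0):
        self.history = deque(maxlen=history_len)  # recent amax values
        self.margin = margin
        self.scale = 1.0                          # scale for the next call

    def quantize_with_amax_update(self, values):
        # One pass: quantize with the current (delayed) scale and track amax.
        amax = 0.0
        data = []
        for v in values:
            amax = max(amax, abs(v))
            data.append(round_to_e4m3(v * self.scale))
        scale_inv = 1.0 / self.scale
        # Record this tensor's amax and derive the scale for the next call.
        self.history.append(amax)
        hist_amax = max(self.history)
        if hist_amax > 0.0 and math.isfinite(hist_amax):
            self.scale = (E4M3_MAX / hist_amax) / (2 ** self.margin)
        return data, scale_inv
```

Because the scale lags behind the data, a tensor whose magnitude jumps past the recorded history saturates at ±448 until the history catches up; a nonzero `margin` trades some precision for headroom against such jumps.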

Code Reference

Source Location

Repository: NVIDIA/TransformerEngine
File: transformer_engine/pytorch/csrc/extensions/cast.cpp
Lines: 1-1249

Signature

namespace transformer_engine::pytorch {

py::object quantize(at::Tensor input, py::handle quantizer,
                    py::handle output_tensor = py::none(),
                    bool noop = false);

at::Tensor dequantize(at::Tensor input, py::handle tensor_obj,
                      at::ScalarType otype);

py::list multi_tensor_quantize_impl(
    std::vector<at::Tensor> input_list,
    std::vector<py::handle> quantizer_list, ...);

py::object stochastic_quantize(at::Tensor input, py::handle quantizer, ...);
py::object quantize_with_amax_update(at::Tensor input, py::handle quantizer);

}
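
The dispatch performed by multi_tensor_quantize_impl, as the description puts it (a fused kernel when the tensors use FP8 delayed scaling, individual quantization otherwise), can be sketched in plain Python. `Quantizer`, its `recipe` labels, and the `launches` log are hypothetical stand-ins; the assumption that every quantizer in the batch must be delayed-scaling to take the fused path is also illustrative, and the per-tensor step only applies the scale without FP8 rounding.

```python
from dataclasses import dataclass


@dataclass
class Quantizer:
    """Hypothetical stand-in for a quantizer; only the recipe and scale matter."""
    recipe: str          # e.g. "delayed", "current", "mxfp8" (illustrative labels)
    scale: float = 1.0


def quantize_one(values, q):
    # Placeholder per-tensor quantization: applies the scale, skips FP8 rounding.
    return [v * q.scale for v in values]


def multi_tensor_quantize(inputs, quantizers, launches):
    """Quantize a batch, logging each (simulated) kernel launch in `launches`."""
    if len(inputs) != len(quantizers):
        raise ValueError("expected one quantizer per input tensor")
    if quantizers and all(q.recipe == "delayed" for q in quantizers):
        launches.append("fused")   # one batched launch covers every tensor
        return [quantize_one(x, q) for x, q in zip(inputs, quantizers)]
    outputs = []
    for x, q in zip(inputs, quantizers):
        launches.append("single")  # fallback: one launch per tensor
        outputs.append(quantize_one(x, q))
    return outputs
```

The payoff of the fused path is fewer kernel launches when many small tensors share the same recipe; the results are the same either way.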

Import

#include "../extensions.h"
#include "common.h"

I/O Contract

Inputs

Name | Type | Required | Description
input | at::Tensor | Yes | High-precision input tensor
quantizer | py::handle | Yes | Quantizer defining the target format
output_tensor | py::handle | No | Pre-allocated output tensor to reuse
noop | bool | No | If true, skip quantization (passthrough)

Outputs

Name | Type | Description
quantized | py::object | Quantized tensor in the target low-precision format (returned by quantize)
dequantized | at::Tensor | High-precision tensor reconstructed from quantized data (returned by dequantize)
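
For the MXFP8 target, the quantized output carries one shared power-of-two scale per block of elements instead of a single per-tensor scale. The sketch below follows the OCP Microscaling (MX) convention as an assumption: 32-element blocks, E4M3 elements, and a shared exponent of floor(log2(block amax)) minus 8, the exponent of E4M3's top binade. It stores scales as plain integer exponents, does not clamp them to the E8M0 range, and is not TE's kernel.

```python
import math

E4M3_MAX = 448.0
E4M3_EMAX = 8    # 448 = 1.75 * 2**8, so E4M3's top binade has exponent 8
BLOCK = 32       # elements sharing one scale (assumed MX block size)


def round_to_e4m3(x: float) -> float:
    """Nearest E4M3 value (ties to even), saturating at +/-E4M3_MAX."""
    if x == 0.0:
        return x
    _, exp = math.frexp(abs(x))
    step = math.ldexp(1.0, max(exp - 1, -6) - 3)
    return math.copysign(min(round(abs(x) / step) * step, E4M3_MAX), x)


def mxfp8_quantize(values):
    """Return E4M3 element values plus one power-of-two exponent per block."""
    data, exps = [], []
    for start in range(0, len(values), BLOCK):
        block = values[start:start + BLOCK]
        amax = max(abs(v) for v in block)
        # floor(log2(amax)) - E4M3_EMAX; an all-zero block just uses 2**0.
        e = math.frexp(amax)[1] - 1 - E4M3_EMAX if amax > 0.0 else 0
        exps.append(e)
        data.extend(round_to_e4m3(math.ldexp(v, -e)) for v in block)
    return data, exps


def mxfp8_dequantize(data, exps):
    """Multiply each element by its block's 2**exponent."""
    return [math.ldexp(q, exps[i // BLOCK]) for i, q in enumerate(data)]
```

Since the shared scale is a power of two, applying it only shifts exponents; the cost is that a block's largest values can land above 448 after scaling and saturate.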

Usage Examples

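import torch
import transformer_engine_torch as tex

# Placeholders assumed to be defined elsewhere: input_tensor (a high-precision
# torch tensor), fp8_quantizer, tensor_obj, input_list and quantizer_list.

# Quantize a tensor to the quantizer's target format (FP8 here)
fp8_tensor = tex.quantize(input_tensor, fp8_quantizer)

# Dequantize back to high precision (FP32 here)
output = tex.dequantize(fp8_tensor, tensor_obj, torch.float32)

# Batch-quantize several tensors, one quantizer per input
quantized_list = tex.multi_tensor_quantize(input_list, quantizer_list)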
