
Implementation:NVIDIA TransformerEngine JAX XLA Quantization

From Leeroopedia


Field         Value
Sources       TransformerEngine
Domains       Deep_Learning, JAX, Quantization
Last Updated  2026-02-07 14:00 GMT

Overview

Implements XLA FFI handlers for quantization operations: fused dbias+quantize, grouped quantize, and dequantize, supporting FP8 and NVFP4 output formats with multiple scaling modes.

Description

GetDBiasQuantizeWorkspaceSizes queries nvte_quantize_dbias for its workspace requirements. It passes dummy tensor wrappers backed by dummy pointers and sets up the rowwise and columnwise output tensors according to the quantize layout.

DBiasQuantizeFFI flattens the input to 2D using flatten_axis and configures the output TensorWrapper for the selected scaling mode (delayed tensor, current tensor, MXFP8 block, or NVFP4). For NVFP4 quantization it also handles stochastic rounding and the Randomized Hadamard Transform (RHT). It then dispatches to nvte_quantize_dbias when bias reduction is needed and to nvte_quantize otherwise.

GroupedQuantizeFFI quantizes multiple tensors that are packed contiguously. DequantizeFFI converts FP8 data back to a higher-precision type. All handlers are registered with XLA_FFI_DEFINE_HANDLER_SYMBOL.
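The 2D flattening step can be pictured with a small standalone sketch. It assumes that the dimensions before flatten_axis are multiplied into the row count and the remaining dimensions into the column count, and that flatten_axis lies strictly inside the shape. The Shape2D struct and flatten_to_2d helper are hypothetical names used only for illustration; they are not part of TransformerEngine.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical illustration type: the (rows, cols) view of an N-D tensor.
struct Shape2D {
  std::size_t rows;
  std::size_t cols;
};

// Hypothetical helper (not a TransformerEngine API): collapse an N-D shape
// into 2D by splitting the dimension list at flatten_axis.
// Assumption: 0 < flatten_axis < dims.size().
Shape2D flatten_to_2d(const std::vector<std::size_t>& dims,
                      std::size_t flatten_axis) {
  assert(flatten_axis > 0 && flatten_axis < dims.size());
  Shape2D shape{1, 1};
  for (std::size_t i = 0; i < dims.size(); ++i) {
    if (i < flatten_axis) {
      shape.rows *= dims[i];  // leading dimensions fold into rows
    } else {
      shape.cols *= dims[i];  // trailing dimensions fold into columns
    }
  }
  return shape;
}
```

Under these assumptions, a (8, 128, 512) activation with flatten_axis = 2 is treated as a 1024 x 512 matrix, so a bias gradient reduced over rows would have 512 entries.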

This extension provides the quantization primitives that convert higher-precision tensors to FP8/FP4 formats with fused bias gradient computation, essential for FP8 training workflows in JAX.
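As a rough CPU reference for what a fused dbias+quantize step computes, the sketch below walks a row-major 2D input once and produces a column-wise sum (the bias gradient), the absolute maximum, and scaled values clamped to the FP8 E4M3 range. It is an illustration under simplifying assumptions, not the GPU kernel: the scaled values stay in float and are not rounded to FP8, a single per-tensor scale is assumed, and QuantizeResult and reference_dbias_quantize are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

constexpr float kFp8E4M3Max = 448.0f;  // largest finite E4M3 magnitude

// Hypothetical result bundle for this illustration only.
struct QuantizeResult {
  std::vector<float> scaled;  // x * scale, clamped; NOT rounded to FP8
  std::vector<float> dbias;   // per-column sum over all rows
  float scale_inv;            // 1 / scale, used later to dequantize
  float amax;                 // max |x| over the whole tensor
};

// Hypothetical CPU reference (not a TransformerEngine API).
// Assumptions: row-major input of size rows * cols, one positive scale.
QuantizeResult reference_dbias_quantize(const std::vector<float>& input,
                                        std::size_t rows, std::size_t cols,
                                        float scale) {
  assert(input.size() == rows * cols);
  assert(scale > 0.0f);
  QuantizeResult r{std::vector<float>(input.size()),
                   std::vector<float>(cols, 0.0f), 1.0f / scale, 0.0f};
  for (std::size_t i = 0; i < rows; ++i) {
    for (std::size_t j = 0; j < cols; ++j) {
      const float x = input[i * cols + j];
      r.dbias[j] += x;                          // bias gradient reduction
      r.amax = std::max(r.amax, std::fabs(x));  // track absolute maximum
      const float y = x * scale;
      r.scaled[i * cols + j] =
          std::min(std::max(y, -kFp8E4M3Max), kFp8E4M3Max);
    }
  }
  return r;
}
```

Doing the reduction and the cast in one pass over the input is the point of the fusion: the higher-precision tensor is read once instead of twice.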

Usage

This C++ extension is invoked internally by the Python-side QuantizePrimitive, DBiasQuantizePrimitive, and GroupedQuantizePrimitive in transformer_engine.jax.cpp_extensions.quantization. Users do not call these FFI handlers directly.

Code Reference

Source Location

Repository: NVIDIA/TransformerEngine
File: transformer_engine/jax/csrc/extensions/quantization.cpp
Lines: 1-498

Signature

namespace transformer_engine {
namespace jax {

pybind11::tuple GetDBiasQuantizeWorkspaceSizes(
    size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype,
    DType scale_dtype, JAXX_Scaling_Mode scaling_mode,
    JAXX_Quantize_Layout q_layout);

Error_Type DBiasQuantizeFFI(
    cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
    Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf,
    Result_Type dbias_buf, Result_Type scale_inv_buf,
    Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
    Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
    JAXX_Quantize_Layout quantize_layout, bool is_dbias,
    bool stochastic_rounding, int rht_sign_mask);

Error_Type GroupedQuantizeFFI(...);
Error_Type DequantizeFFI(...);

} // namespace jax
} // namespace transformer_engine

Import

#include "transformer_engine/cast.h"
#include "transformer_engine/hadamard_transform.h"
#include "transformer_engine/recipe.h"
#include "../extensions.h"

I/O Contract

Inputs

Name                 Type                  Required  Description
input_buf            Buffer_Type           Yes       Input tensor buffer to quantize
scale_buf            Buffer_Type           Yes       Scale factor buffer
amax_buf             Buffer_Type           Yes       Amax buffer for delayed scaling
scaling_mode         JAXX_Scaling_Mode     Yes       Quantization scaling mode
quantize_layout      JAXX_Quantize_Layout  Yes       Output quantization layout
is_dbias             bool                  Yes       Whether to compute bias gradient
stochastic_rounding  bool                  No        Whether to use stochastic rounding (NVFP4)
rht_sign_mask        int                   No        RHT sign bitmask for NVFP4 quantization
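To make the difference between per-tensor and block scaling modes concrete, the sketch below derives one power-of-two inverse scale per block of 32 consecutive elements, which is the block size used by the MX formats. The rounding rule (the smallest power of two that brings the block maximum within the E4M3 range) is an assumption chosen for illustration and may differ from what the library kernels do. Both function names are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

constexpr std::size_t kBlockSize = 32;  // elements sharing one scale
constexpr float kE4M3Max = 448.0f;      // largest finite E4M3 magnitude

// Hypothetical helper: smallest power of two s such that amax / s <= 448.
// Assumption: this rounding rule is illustrative, not the library's.
float power_of_two_scale_inv(float amax) {
  if (amax <= 0.0f) return 1.0f;  // all-zero block: any scale works
  const int exponent =
      static_cast<int>(std::ceil(std::log2(amax / kE4M3Max)));
  return std::ldexp(1.0f, exponent);  // exact 2^exponent
}

// Hypothetical helper: one inverse scale per block of 32 elements in a row.
std::vector<float> block_scale_invs(const std::vector<float>& row) {
  assert(row.size() % kBlockSize == 0);
  std::vector<float> scale_invs;
  scale_invs.reserve(row.size() / kBlockSize);
  for (std::size_t start = 0; start < row.size(); start += kBlockSize) {
    float amax = 0.0f;
    for (std::size_t k = 0; k < kBlockSize; ++k) {
      amax = std::max(amax, std::fabs(row[start + k]));
    }
    scale_invs.push_back(power_of_two_scale_inv(amax));
  }
  return scale_invs;
}
```

With block scaling, an outlier only coarsens the resolution of its own 32-element block, whereas a per-tensor scale would have to accommodate it everywhere.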

Outputs

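Name                   Type         Description
output_buf             Result_Type  Quantized output (rowwise)
colwise_output_buf     Result_Type  Column-wise quantized output
dbias_buf              Result_Type  Bias gradient (if is_dbias)
scale_inv_buf          Result_Type  Inverse scale factor for the rowwise output
colwise_scale_inv_buf  Result_Type  Inverse scale factor for the column-wise output
updated_amax_buf       Result_Type  Updated amax value
workspace_buf          Result_Type  Scratch workspace, sized via GetDBiasQuantizeWorkspaceSizes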
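The inverse scale is what a consumer needs to map quantized values back to the original range. A minimal sketch, assuming a single per-tensor inverse scale and quantized values held as float for simplicity (reference_dequantize is a hypothetical name, not the DequantizeFFI implementation):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference: x is approximately q * scale_inv.
// Assumption: one per-tensor scale_inv; quantized values stored as float.
std::vector<float> reference_dequantize(const std::vector<float>& quantized,
                                        float scale_inv) {
  assert(scale_inv > 0.0f);
  std::vector<float> restored(quantized.size());
  for (std::size_t i = 0; i < quantized.size(); ++i) {
    restored[i] = quantized[i] * scale_inv;  // undo the forward scaling
  }
  return restored;
}
```

Because the forward pass multiplies by the scale and the backward mapping multiplies by its inverse, values that were neither clamped nor rounded come back unchanged; in real FP8 the result is only an approximation of the input.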

Usage Examples

// This FFI handler is called internally by JAX's XLA compilation pipeline.
// Users interact with it through the Python API:
//   from transformer_engine.jax.cpp_extensions.quantization import quantize
//   quantized = quantize(tensor, fp8_quantizer)
