Implementation:NVIDIA TransformerEngine JAX XLA Quantization
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, JAX, Quantization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Implements XLA FFI handlers for quantization operations: fused dbias+quantize, grouped quantize, and dequantize, supporting FP8 and NVFP4 output formats with multiple scaling modes.
Description
GetDBiasQuantizeWorkspaceSizes queries nvte_quantize_dbias for its workspace requirements. It passes tensor wrappers backed by dummy pointers and sets up the rowwise and/or columnwise output tensors according to the quantize layout. DBiasQuantizeFFI performs the following steps:
- flattens the input to 2D using flatten_axis;
- configures the output TensorWrapper for the selected scaling mode (delayed tensor scaling, current tensor scaling, MXFP8 block scaling, or NVFP4);
- handles stochastic rounding and the Randomized Hadamard Transform (RHT) for NVFP4 quantization;
- dispatches to nvte_quantize_dbias when bias reduction is requested (is_dbias) and to nvte_quantize otherwise.

GroupedQuantizeFFI quantizes multiple tensors packed contiguously in one buffer. DequantizeFFI converts FP8 data back to a higher-precision type. All handlers are registered with XLA_FFI_DEFINE_HANDLER_SYMBOL.
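
A minimal sketch of the 2D flattening step follows. It assumes three things: the dimensions before flatten_axis fold into rows, the remaining dimensions fold into columns, and a negative flatten_axis counts from the end, as in Python. The helper name flatten_to_2d is hypothetical and is not part of TransformerEngine.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <utility>
#include <vector>

// Hypothetical helper: collapse an N-D shape into (rows, cols) around flatten_axis.
// Assumption: dims before flatten_axis form the rows, the rest form the columns,
// and a negative flatten_axis counts from the end (Python style).
std::pair<std::size_t, std::size_t> flatten_to_2d(const std::vector<std::size_t>& shape,
                                                  int flatten_axis) {
  const int ndim = static_cast<int>(shape.size());
  if (flatten_axis < 0) flatten_axis += ndim;  // e.g. -1 -> ndim - 1
  assert(flatten_axis >= 0 && flatten_axis <= ndim);
  const auto split = shape.begin() + flatten_axis;
  // An empty side contributes a size of 1 (empty product).
  const std::size_t rows = std::accumulate(shape.begin(), split, std::size_t{1},
                                           std::multiplies<std::size_t>());
  const std::size_t cols = std::accumulate(split, shape.end(), std::size_t{1},
                                           std::multiplies<std::size_t>());
  return {rows, cols};
}
```

For example, under these assumptions a (4, 8, 16) activation with flatten_axis = -1 becomes a 32-by-16 matrix. Its bias gradient would then have 16 entries.
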
This extension provides the quantization primitives that convert higher-precision tensors to FP8 or NVFP4. They can optionally fuse the bias-gradient computation into the same pass, which makes them essential for FP8 training workflows in JAX.
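
The fused dbias + quantize step can be illustrated with a plain C++ reference for tensor scaling on a row-major rows-by-cols input. In one pass, the sketch does three things:
- multiplies every element by the supplied scale and rounds it to FP8 E4M3;
- computes the bias gradient as the column sum over rows;
- records amax as the largest absolute input value.

This is a sketch under stated assumptions, not TransformerEngine's CUDA kernel. round_to_e4m3 emulates E4M3 rounding in float (round-to-nearest-even, saturating at 448). The names dbias_quantize_reference and DBiasQuantizeResult are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Emulates rounding to FP8 E4M3 (max finite 448) in float: round-to-nearest-even
// via std::nearbyint under the default rounding mode, saturating instead of overflowing.
float round_to_e4m3(float v) {
  constexpr float kMax = 448.0f;
  if (std::isnan(v)) return v;
  const float a = std::min(std::fabs(v), kMax);
  if (a == 0.0f) return std::copysign(0.0f, v);
  int bexp = 0;
  std::frexp(a, &bexp);                        // a = m * 2^bexp, m in [0.5, 1)
  const int e = std::max(bexp - 1, -6);        // exponents below -6 reuse the subnormal spacing
  const float step = std::ldexp(1.0f, e - 3);  // spacing with 3 mantissa bits
  const float r = std::nearbyint(a / step) * step;
  return std::copysign(std::min(r, kMax), v);
}

// Hypothetical result type; FP8 values are kept in float for readability.
struct DBiasQuantizeResult {
  std::vector<float> q;      // E4M3-rounded x * scale, row-major
  std::vector<float> dbias;  // one entry per column
  float amax;                // max |x|, reported as the updated amax
  float scale_inv;           // 1 / scale, stored alongside q for dequantization
};

// Hypothetical reference for fused dbias + tensor-scaled quantize on a
// row-major rows-by-cols input: q = e4m3(x * scale), dbias[j] = sum_i x[i][j].
DBiasQuantizeResult dbias_quantize_reference(const std::vector<float>& x, std::size_t rows,
                                             std::size_t cols, float scale) {
  assert(x.size() == rows * cols && scale > 0.0f);
  DBiasQuantizeResult out{std::vector<float>(x.size()), std::vector<float>(cols, 0.0f), 0.0f,
                          1.0f / scale};
  for (std::size_t i = 0; i < rows; ++i) {
    for (std::size_t j = 0; j < cols; ++j) {
      const float v = x[i * cols + j];
      out.dbias[j] += v;                            // bias gradient reduces over rows
      out.amax = std::max(out.amax, std::fabs(v));  // the same pass also tracks amax
      out.q[i * cols + j] = round_to_e4m3(v * scale);
    }
  }
  return out;
}
```

Fusing the reduction with the cast means the input is read once instead of twice. This is the motivation the section gives for the combined dbias+quantize handler.
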
Usage
This C++ extension is invoked internally by the Python-side QuantizePrimitive, DBiasQuantizePrimitive, and GroupedQuantizePrimitive in transformer_engine.jax.cpp_extensions.quantization. Users do not call these FFI handlers directly.
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/jax/csrc/extensions/quantization.cpp
- Lines: 1-498
Signature
```cpp
namespace transformer_engine {
namespace jax {

pybind11::tuple GetDBiasQuantizeWorkspaceSizes(
    size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype,
    DType scale_dtype, JAXX_Scaling_Mode scaling_mode,
    JAXX_Quantize_Layout q_layout);

Error_Type DBiasQuantizeFFI(
    cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
    Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf,
    Result_Type dbias_buf, Result_Type scale_inv_buf,
    Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
    Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode,
    JAXX_Quantize_Layout quantize_layout, bool is_dbias,
    bool stochastic_rounding, int rht_sign_mask);

Error_Type GroupedQuantizeFFI(...);
Error_Type DequantizeFFI(...);

}  // namespace jax
}  // namespace transformer_engine
```
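
GetDBiasQuantizeWorkspaceSizes follows a common query-then-allocate pattern. The library is first called with placeholder pointers only to learn how much scratch memory it needs. The real call later receives a buffer of that size, presumably workspace_buf here. The sketch below shows the pattern with entirely hypothetical stand-ins (WorkspaceDesc, fake_quantize_dbias, and a made-up sizing rule). It does not reproduce the nvte_quantize_dbias API.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for illustration; this is NOT the TransformerEngine API.
struct WorkspaceDesc {
  void* ptr = nullptr;    // nullptr means "query only"
  std::size_t bytes = 0;  // filled in by the query, consumed by the real call
};

// Hypothetical op with a made-up sizing rule. In query mode it only reports
// the scratch bytes it needs; otherwise it checks the buffer and "runs".
bool fake_quantize_dbias(std::size_t rows, std::size_t cols, WorkspaceDesc& ws) {
  // Made-up rule: one float per column for every 64-row tile (partial column sums).
  const std::size_t needed = cols * sizeof(float) * ((rows + 63) / 64);
  if (ws.ptr == nullptr) {
    ws.bytes = needed;  // query: report the size, do no work
    return true;
  }
  return ws.bytes >= needed;  // a real implementation would launch kernels here
}

// Phase 1 queries the size (like the dummy-pointer call); phase 2 allocates and runs.
std::size_t query_then_run(std::size_t rows, std::size_t cols) {
  WorkspaceDesc ws;
  fake_quantize_dbias(rows, cols, ws);
  std::vector<unsigned char> buffer(ws.bytes);  // stand-in for an XLA-allocated workspace_buf
  ws.ptr = buffer.data();
  const bool ok = fake_quantize_dbias(rows, cols, ws);
  assert(ok);
  (void)ok;
  return ws.bytes;
}
```

Computing the size ahead of time lets the caller declare the workspace as an ordinary output buffer. The handler then never has to allocate device memory itself.
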
Import
```cpp
#include "transformer_engine/cast.h"
#include "transformer_engine/hadamard_transform.h"
#include "transformer_engine/recipe.h"
#include "../extensions.h"
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input_buf | Buffer_Type | Yes | Input tensor buffer to quantize |
| scale_buf | Buffer_Type | Yes | Scale factor buffer |
| amax_buf | Buffer_Type | Yes | Amax buffer for delayed scaling |
| scaling_mode | JAXX_Scaling_Mode | Yes | Quantization scaling mode |
| quantize_layout | JAXX_Quantize_Layout | Yes | Output quantization layout |
| is_dbias | bool | Yes | Whether to compute bias gradient |
| stochastic_rounding | bool | No | Whether to use stochastic rounding (NVFP4) |
| rht_sign_mask | int | No | RHT sign bitmask for NVFP4 quantization |
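
The scaling_mode input selects among the modes named in the Description. In the MXFP8 block mode, each block of 32 consecutive values shares one power-of-two scale. The minimal sketch below follows the OCP Microscaling (MX) v1.0 reference recipe:
- the shared exponent is floor(log2(amax)) minus 8, the largest normal exponent of E4M3;
- the shared exponent is stored as a biased E8M0 byte;
- each element is divided by the shared scale and rounded to E4M3, saturating at 448.

TransformerEngine's kernel may derive the exponent differently. quantize_mx_block, MxBlock, and round_to_e4m3 are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// E4M3 rounding emulated in float (round-to-nearest-even, saturate at 448).
float round_to_e4m3(float v) {
  constexpr float kMax = 448.0f;
  if (std::isnan(v)) return v;
  const float a = std::min(std::fabs(v), kMax);
  if (a == 0.0f) return std::copysign(0.0f, v);
  int bexp = 0;
  std::frexp(a, &bexp);                        // a = m * 2^bexp, m in [0.5, 1)
  const int e = std::max(bexp - 1, -6);        // exponents below -6 reuse the subnormal spacing
  const float step = std::ldexp(1.0f, e - 3);  // spacing with 3 mantissa bits
  return std::copysign(std::min(std::nearbyint(a / step) * step, kMax), v);
}

// Hypothetical MXFP8 block: up to 32 values sharing one E8M0 scale.
struct MxBlock {
  std::uint8_t scale_e8m0;   // shared scale X = 2^(scale_e8m0 - 127)
  std::vector<float> elems;  // E4M3 values of v / X, kept in float for readability
};

// OCP MX v1.0-style recipe: shared_exp = floor(log2(amax)) - emax(E4M3), emax = 8.
MxBlock quantize_mx_block(const std::vector<float>& v) {
  assert(!v.empty() && v.size() <= 32);
  constexpr int kEmaxE4M3 = 8;  // 448 = 1.75 * 2^8
  float amax = 0.0f;
  for (float x : v) amax = std::max(amax, std::fabs(x));
  int shared_exp = 0;  // an all-zero block keeps X = 1
  if (amax > 0.0f) {
    int bexp = 0;
    std::frexp(amax, &bexp);              // amax = m * 2^bexp, m in [0.5, 1)
    shared_exp = (bexp - 1) - kEmaxE4M3;  // floor(log2(amax)) - emax
  }
  shared_exp = std::clamp(shared_exp, -127, 127);  // E8M0 codes 0..254 (255 is NaN)
  MxBlock out{static_cast<std::uint8_t>(shared_exp + 127), {}};
  for (float x : v) {
    // Scaled magnitudes can approach 512, so round_to_e4m3 saturates them at 448.
    out.elems.push_back(round_to_e4m3(std::ldexp(x, -shared_exp)));
  }
  return out;
}
```

For example, a block whose largest magnitude is 3.0 gets shared exponent floor(log2 3) - 8 = -7 (E8M0 byte 120). The values 1.0 and -3.0 are therefore stored as 128 and -384.
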
Outputs
| Name | Type | Description |
|---|---|---|
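| output_buf | Result_Type | Quantized output (rowwise) |
| colwise_output_buf | Result_Type | Column-wise quantized output |
| dbias_buf | Result_Type | Bias gradient (if is_dbias) |
| scale_inv_buf | Result_Type | Inverse scale factor |
| colwise_scale_inv_buf | Result_Type | Inverse scale factor for the column-wise output |
| updated_amax_buf | Result_Type | Updated amax value |
| workspace_buf | Result_Type | Scratch workspace sized by GetDBiasQuantizeWorkspaceSizes |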
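
For the tensor-scaling modes, a minimal sketch shows how the outputs relate. It assumes that colwise_output_buf holds the transpose of the rowwise data, that both copies share one inverse scale, and that dequantization multiplies each value by scale_inv. Block-scaled modes (MXFP8, NVFP4) keep per-block scales and are not covered. transpose and dequantize are hypothetical helpers.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: column-wise copy as the transpose of row-major rows-by-cols data
// (an assumption that applies to tensor scaling only).
std::vector<float> transpose(const std::vector<float>& a, std::size_t rows, std::size_t cols) {
  assert(a.size() == rows * cols);
  std::vector<float> t(a.size());
  for (std::size_t i = 0; i < rows; ++i) {
    for (std::size_t j = 0; j < cols; ++j) {
      t[j * rows + i] = a[i * cols + j];  // element (i, j) moves to (j, i)
    }
  }
  return t;
}

// Hypothetical helper: tensor-scaled dequantization, x = q * scale_inv.
std::vector<float> dequantize(const std::vector<float>& q, float scale_inv) {
  std::vector<float> x(q.size());
  for (std::size_t k = 0; k < q.size(); ++k) x[k] = q[k] * scale_inv;
  return x;
}
```
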
Usage Examples
```cpp
// XLA invokes this FFI handler as a custom call inside compiled JAX programs;
// users do not call it directly. They reach it through the Python API, where
// fp8_quantizer stands for an already-constructed quantizer object:
//   from transformer_engine.jax.cpp_extensions.quantization import quantize
//   quantized = quantize(tensor, fp8_quantizer)
```