Implementation:NVIDIA TransformerEngine ONNX Extensions
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Quantization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Defines custom torch.ops extensions and their ONNX symbolic functions to enable ONNX export of TransformerEngine models with FP8 and MXFP8 quantization for TensorRT inference.
Description
Registers custom PyTorch operators (tex::gemm_inf, tex::fp8_quantize, tex::fp8_dequantize, tex::cs_fp8_quantize, tex::mxfp8_quantize, tex::mxfp8_dequantize, etc.) using @torch.library.custom_op. Each operator has a real implementation and a fake-tensor implementation, the latter used for shape and dtype inference during tracing. Each operator also has a corresponding ONNX symbolic function, written with onnxscript, that maps it to a TensorRT custom op (TRT_FP8QuantizeLinear, TRT_FP8DequantizeLinear, etc.), since standard ONNX does not natively support FP8 quantization.
Usage
Required when exporting TE (TransformerEngine) models that use FP8 to ONNX for deployment with TensorRT. Without these extensions, the FP8 operations cannot be represented in the exported ONNX graph.
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/onnx_extensions.py
- Lines: 1-404
Signature
def onnx_gemm(weight: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor: ...
@torch.library.custom_op("tex::gemm_inf", ...)
def torch_onnx_gemm_inf_op(weight, inp, bias): ...
@torch.library.custom_op("tex::fp8_quantize", ...)
def onnx_quantize_fp8_op(tensor: torch.Tensor, scale: float) -> torch.Tensor: ...
@torch.library.custom_op("tex::fp8_dequantize", ...)
def onnx_dequantize_fp8_op(tensor: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor: ...
@torch.library.custom_op("tex::cs_fp8_quantize", ...)
def onnx_cs_quantize_fp8_op(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: ...
@torch.library.custom_op("tex::mxfp8_quantize", ...)
def onnx_quantize_mxfp8_op(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: ...
Import
from transformer_engine.pytorch.onnx_extensions import (
    onnx_gemm,
    onnx_quantize_fp8_op,
    onnx_dequantize_fp8_op,
)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| tensor | torch.Tensor | Yes | Input tensor for quantization/dequantization |
| scale | float | No | Scale factor for FP8 quantization |
| scale_inv | torch.Tensor | No | Inverse scale for FP8 dequantization |
| weight | torch.Tensor | No | Weight tensor for GEMM |
| inp | torch.Tensor | No | Input tensor for GEMM |
Outputs
| Name | Type | Description |
|---|---|---|
| quantized | torch.Tensor | FP8 quantized tensor (for quantize ops) |
| dequantized | torch.Tensor | High-precision reconstructed tensor (for dequantize ops) |
| gemm_output | torch.Tensor | Result of the GEMM operation |
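To make the roles of scale and scale_inv concrete, here is a pure-Python sketch of a scaled FP8 round trip. It assumes the E4M3 format (3 mantissa bits, largest finite value 448) with saturation and round-to-nearest-even. The helper names round_to_e4m3, fp8_quantize, and fp8_dequantize are illustrative only and are not part of TransformerEngine, whose ops run on tensors rather than Python lists.

```python
import math

E4M3_MAX = 448.0  # largest finite magnitude in FP8 E4M3


def round_to_e4m3(x: float) -> float:
    """Round x to the nearest E4M3-representable value, saturating at +/-448."""
    if x == 0.0 or math.isnan(x):
        return x
    x = max(-E4M3_MAX, min(E4M3_MAX, x))
    _, exp = math.frexp(abs(x))   # abs(x) = m * 2**exp with 0.5 <= m < 1
    exp = max(exp - 1, -6)        # binade exponent; -6 is the smallest normal exponent
    step = 2.0 ** (exp - 3)       # spacing between neighbours with 3 mantissa bits
    # Python's round() rounds halves to even, matching round-to-nearest-even.
    return math.copysign(round(abs(x) / step) * step, x)


def fp8_quantize(values, scale):
    """Scale each value, then round it onto the E4M3 grid."""
    return [round_to_e4m3(v * scale) for v in values]


def fp8_dequantize(quantized, scale_inv):
    """Undo the scaling; rounding and saturation errors remain."""
    return [q * scale_inv for q in quantized]


values = [0.75, -2.5, 300.0]
scale = 2.0
quantized = fp8_quantize(values, scale)             # [1.5, -5.0, 448.0]
restored = fp8_dequantize(quantized, 1.0 / scale)   # [0.75, -2.5, 224.0]
print(quantized, restored)
```

The last element shows why the scale matters: 300.0 * 2.0 exceeds 448 and saturates, so the round trip returns 224.0 instead of 300.0. A scale chosen so that the largest magnitude maps to at most 448 avoids this.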
Usage Examples
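import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.export import te_translation_table
# Export a TE model to ONNX with FP8 support.
# `model` and `sample_input` are placeholders for a TE module and a matching input.
# Assumes te.onnx_export is the export-mode context manager and te_translation_table
# maps the custom tex:: ops to their ONNX symbolic functions; check your TE version.
with torch.inference_mode(), te.onnx_export(enabled=True):
    torch.onnx.export(model, (sample_input,), "model.onnx",
                      dynamo=True, custom_translation_table=te_translation_table)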