

Implementation:NVIDIA TransformerEngine ONNX Extensions

From Leeroopedia


Field          Value
Sources        TransformerEngine
Domains        Deep_Learning, PyTorch, Quantization
Last Updated   2026-02-07 14:00 GMT

Overview

Defines custom torch.ops operators and their ONNX symbolic functions so that TransformerEngine (TE) models using FP8 or MXFP8 quantization can be exported to ONNX for TensorRT inference.

Description

Registers custom PyTorch operators (tex::gemm_inf, tex::fp8_quantize, tex::fp8_dequantize, tex::mxfp8_quantize, tex::mxfp8_dequantize, etc.) with @torch.library.custom_op, giving each one both a real (eager) implementation and a fake-tensor implementation that only describes output shapes and dtypes for tracing. Each operator also has an ONNX symbolic function, written with onnxscript, that maps it to a TensorRT custom op (TRT_FP8QuantizeLinear, TRT_FP8DequantizeLinear, etc.), since standard ONNX offers only limited FP8 support (FP8 tensor types were introduced in opset 19) and does not directly express all of these quantization schemes.

Usage

Required when exporting TE models that use FP8 or MXFP8 quantization to ONNX for deployment with TensorRT. Without these extensions, TE's FP8 operations cannot be represented in an ONNX graph.

Code Reference

Source Location

Repository: NVIDIA/TransformerEngine
File: transformer_engine/pytorch/onnx_extensions.py
Lines: 1-404

Signature

import torch
from typing import Tuple

def onnx_gemm(weight: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor: ...

@torch.library.custom_op("tex::gemm_inf", ...)
def torch_onnx_gemm_inf_op(weight, inp, bias): ...

@torch.library.custom_op("tex::fp8_quantize", ...)
def onnx_quantize_fp8_op(tensor: torch.Tensor, scale: float) -> torch.Tensor: ...

@torch.library.custom_op("tex::fp8_dequantize", ...)
def onnx_dequantize_fp8_op(tensor: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor: ...

@torch.library.custom_op("tex::cs_fp8_quantize", ...)
def onnx_cs_quantize_fp8_op(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: ...

@torch.library.custom_op("tex::mxfp8_quantize", ...)
def onnx_quantize_mxfp8_op(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: ...
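The two-tensor tuple returned by tex::mxfp8_quantize presumably holds the FP8 data and its per-block scales. The pure-Python sketch below illustrates that layout under an assumption: it follows the OCP Microscaling (MX) v1.0 recipe, where each 32-element block shares a power-of-two scale 2^(floor(log2(amax)) - 8), stored as a biased E8M0 exponent, and elements are rounded to FP8 E4M3. TE's kernels may choose the shared exponent differently, and mxfp8_quantize/mxfp8_dequantize here are hypothetical helpers operating on plain lists instead of tensors.

```python
import math

E4M3_MAX = 448.0  # largest finite FP8 E4M3 value
E4M3_EMAX = 8     # exponent of that value: 448 = 1.75 * 2**8
BLOCK = 32        # MX block size (OCP MX v1.0)
E8M0_BIAS = 127   # E8M0 stores a bare power-of-two exponent with bias 127


def round_to_e4m3(v: float) -> float:
    """Round to the nearest FP8 E4M3 value (round-half-to-even, saturating at 448)."""
    if v == 0.0 or math.isnan(v):
        return v
    a = min(abs(v), E4M3_MAX)
    # floor(log2(a)), clamped at the minimum normal exponent so subnormals share its spacing
    exp = max(math.frexp(a)[1] - 1, -6)
    step = 2.0 ** (exp - 3)                     # spacing with 3 mantissa bits
    q = min(round(a / step) * step, E4M3_MAX)   # Python's round() is half-to-even
    return math.copysign(q, v)


def mxfp8_quantize(values):
    """Hypothetical helper: return (fp8_values, e8m0_scales), one scale per block."""
    data, scales = [], []
    for start in range(0, len(values), BLOCK):
        block = values[start:start + BLOCK]
        amax = max(abs(v) for v in block)
        # Shared exponent floor(log2(amax)) - emax; an all-zero block uses 2**0.
        shared = math.frexp(amax)[1] - 1 - E4M3_EMAX if amax > 0 else 0
        scales.append(shared + E8M0_BIAS)
        data.extend(round_to_e4m3(v / 2.0 ** shared) for v in block)
    return data, scales


def mxfp8_dequantize(data, scales):
    """Hypothetical helper: multiply each element by its block's power-of-two scale."""
    return [v * 2.0 ** (scales[i // BLOCK] - E8M0_BIAS) for i, v in enumerate(data)]
```

Because the scale is a pure power of two, dequantization is an exact exponent shift; all rounding error comes from the E4M3 cast of each element.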

Import

from transformer_engine.pytorch.onnx_extensions import (
    onnx_gemm,
    onnx_quantize_fp8_op,
    onnx_dequantize_fp8_op,
)

I/O Contract

Inputs

Name       Type          Required  Description
tensor     torch.Tensor  Yes       Input tensor for quantization/dequantization
scale      float         No        Scale factor for FP8 quantization
scale_inv  torch.Tensor  No        Inverse scale for FP8 dequantization
weight     torch.Tensor  No        Weight tensor for GEMM
inp        torch.Tensor  No        Input tensor for GEMM
bias       torch.Tensor  No        Bias tensor for GEMM
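A pure-Python sketch of how scale and scale_inv typically enter the computation, assuming TE's convention that the input is multiplied by scale before the cast to FP8 E4M3 and the FP8 values are multiplied by scale_inv (= 1/scale) to dequantize. Plain lists and floats stand in for tensors. current_scaling_quantize is a hypothetical stand-in for tex::cs_fp8_quantize, assuming "cs" means current scaling, i.e. the scale is derived from the tensor's own amax instead of being passed in.

```python
import math

E4M3_MAX = 448.0  # largest finite FP8 E4M3 value


def round_to_e4m3(v: float) -> float:
    """Round to the nearest FP8 E4M3 value (round-half-to-even, saturating at 448)."""
    if v == 0.0 or math.isnan(v):
        return v
    a = min(abs(v), E4M3_MAX)
    exp = max(math.frexp(a)[1] - 1, -6)         # floor(log2(a)), min normal exponent -6
    step = 2.0 ** (exp - 3)                     # spacing with 3 mantissa bits
    q = min(round(a / step) * step, E4M3_MAX)   # Python's round() is half-to-even
    return math.copysign(q, v)


def fp8_quantize(tensor, scale: float):
    """Multiply by scale, then round/saturate each element to E4M3."""
    return [round_to_e4m3(v * scale) for v in tensor]


def fp8_dequantize(q, scale_inv: float):
    """Reconstruct high-precision values from FP8 data and the inverse scale."""
    return [v * scale_inv for v in q]


def current_scaling_quantize(tensor):
    """Hypothetical: derive the scale from this tensor's amax, return (fp8, scale_inv)."""
    amax = max((abs(v) for v in tensor), default=0.0)
    scale = E4M3_MAX / amax if amax > 0 else 1.0  # map amax onto the top of the FP8 range
    return fp8_quantize(tensor, scale), 1.0 / scale
```

For example, quantizing [0.3] with scale 1.0 rounds it to 0.3125, the nearest E4M3 value, and any product above 448 saturates to 448.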

Outputs

Name         Type          Description
quantized    torch.Tensor  FP8 quantized tensor (for quantize ops)
dequantized  torch.Tensor  High-precision reconstructed tensor (for dequantize ops)
gemm_output  torch.Tensor  Result of the GEMM operation

Usage Examples

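import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.export import onnx_export, te_translation_table

# model: a TE model; sample_input: tuple of example inputs. Warm up under FP8 first.
with torch.no_grad(), te.fp8_autocast(enabled=True):
    model(*sample_input)
    # onnx_export is a context manager; tex:: ONNX translations go in custom_translation_table
    with onnx_export(enabled=True):
        torch.onnx.export(model, sample_input, "model.onnx", dynamo=True,
                          custom_translation_table=te_translation_table)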
