Implementation:NVIDIA TransformerEngine PyTorch Pybind

From Leeroopedia


Field | Value
Sources | TransformerEngine
Domains | Deep_Learning, PyTorch
Last Updated | 2026-02-07 14:00 GMT

Overview

pybind.cpp is the main pybind11 module definition file: it registers all C++ extension functions as Python-callable functions and initializes Python type references for the custom tensor types.

Description

The file sets up Python class references for the custom tensor types Float8Tensor, MXFP8Tensor, Float8BlockwiseQTensor, and NVFP4Tensor, together with their quantizer and storage counterparts. Each init function imports the corresponding Python module and extracts the PyTypeObject pointers it needs. init_extension() calls all four init functions (float8, mxfp8, float8blockwise, nvfp4). The PYBIND11_MODULE macro registers hundreds of functions, organized by category: quantize/dequantize; GEMM; activations (gelu, silu, relu, and others, plus their derivatives); normalization (layernorm and rmsnorm, forward and backward); softmax variants; attention (fused_attn_fwd/bwd); transpose; padding; permutation; RoPE; dropout; communication overlap classes; recipe management; multi-tensor optimizers; and version queries.

Usage

This file is the single entry point that makes the entire C++ extension library accessible from Python. It defines the transformer_engine_torch Python module.

Code Reference

Source Location

Repository: NVIDIA/TransformerEngine
File: transformer_engine/pytorch/csrc/extensions/pybind.cpp
Lines: 1-518

Signature

namespace transformer_engine::pytorch {

void init_extension();

PYBIND11_MODULE(transformer_engine_torch, m) {
    // Quantize/Dequantize
    m.def("quantize", &quantize, ...);
    m.def("dequantize", &dequantize, ...);

    // GEMM
    m.def("gemm", &gemm, ...);
    m.def("te_atomic_gemm", &te_atomic_gemm, ...);
    m.def("te_general_grouped_gemm", &te_general_grouped_gemm, ...);

    // Activations
    m.def("gelu", &gelu, ...);
    m.def("silu", &silu, ...);
    // ... (hundreds more)

    // Communication overlap classes
    py::class_<CommOverlapHelper>(m, "CommOverlapHelper") ...;
    py::class_<CommOverlap>(m, "CommOverlap") ...;
    py::class_<CommOverlapP2P>(m, "CommOverlapP2P") ...;
}

}  // namespace transformer_engine::pytorch
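
For readers less familiar with pybind11, the effect of each `m.def(name, &function, ...)` line can be approximated in pure Python: a callable is attached to a module object under a chosen name, with a docstring. The sketch below is an analogy only. `m_def` and `demo_extension` are hypothetical names, and the scalar `gelu` and `silu` functions stand in for the tensor kernels the real module binds.

```python
import math
import types


def m_def(module, name, fn, doc=""):
    """Rough analog of pybind11's m.def: expose `fn` on `module` as `name`."""
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    wrapper.__name__ = name
    wrapper.__doc__ = doc
    setattr(module, name, wrapper)


def gelu(x):
    # Exact GELU on a scalar, using the error function.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def silu(x):
    # SiLU on a scalar: x * sigmoid(x).
    return x / (1.0 + math.exp(-x))


# Hypothetical module standing in for the compiled extension.
demo = types.ModuleType("demo_extension")
m_def(demo, "gelu", gelu, "GeLU activation")
m_def(demo, "silu", silu, "SiLU activation")
```

After registration, `demo.gelu(1.0)` is called like any other module-level function, which is how the bound C++ functions appear from Python.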

Import

import transformer_engine_torch as tex
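
Because transformer_engine_torch is a compiled module, the import fails on machines where the extension has not been built. A guarded import, sketched below, lets calling code detect that case. `load_extension` is a hypothetical helper, and treating `ImportError` and `OSError` as the relevant failure modes is an assumption.

```python
import importlib
import importlib.util


def load_extension(name="transformer_engine_torch"):
    """Return the imported module, or None if it cannot be loaded."""
    # find_spec returns None when no top-level module of that name exists.
    if importlib.util.find_spec(name) is None:
        return None
    try:
        return importlib.import_module(name)
    except (ImportError, OSError):
        # Assumed failure modes, e.g. a missing shared library.
        return None


tex = load_extension()
if tex is None:
    print("transformer_engine_torch is not available in this environment")
```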

I/O Contract

Inputs

Name | Type | Required | Description
N/A | N/A | N/A | Module registration file; no direct runtime inputs

Outputs

Name | Type | Description
transformer_engine_torch | module | Python module exposing all C++ extension functions

Usage Examples

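import transformer_engine_torch as tex

# All C++ extensions are available as attributes of tex.
# The calls below are illustrative: the tensors and quantizer are assumed to
# exist already, and "..." marks arguments that are omitted here.
output = tex.gelu(input_tensor, quantizer)
result = tex.gemm(A, B, D, bias, quantizer, workspace, ...)
norm_out = tex.layernorm_fwd(input_tensor, weight, bias, eps, quantizer, ...)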
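
Since the module exposes a large number of functions, it can help to list what is actually registered in an installed build before relying on a particular name. The sketch below uses only standard introspection. `list_callables` and `filter_by_keyword` are hypothetical helpers, and the stdlib `math` module stands in for `tex` so the example runs without the extension.

```python
import math


def list_callables(module):
    """Return the sorted public callable names defined on `module`."""
    return sorted(
        name
        for name in dir(module)
        if not name.startswith("_") and callable(getattr(module, name))
    )


def filter_by_keyword(names, keyword):
    """Keep only the names that contain `keyword`."""
    return [name for name in names if keyword in name]


# Stand-in module; with the extension installed, pass `tex` instead.
names = list_callables(math)
print(filter_by_keyword(names, "sqrt"))
```

With the real module, `filter_by_keyword(list_callables(tex), "gemm")` would show the GEMM entry points exposed by that build.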
