Implementation:NVIDIA TransformerEngine PyTorch Pybind
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
pybind.cpp is the main pybind11 module definition file. It registers all C++ extension functions as Python-callable methods and initializes Python type references for the custom tensor types.
Description
The file initializes Python class references for Float8Tensor, MXFP8Tensor, Float8BlockwiseQTensor, and NVFP4Tensor, along with their quantizer and storage counterparts. It does this by importing the corresponding Python modules and extracting PyTypeObject pointers. init_extension() calls all four init functions (float8, mxfp8, float8blockwise, nvfp4).
The PYBIND11_MODULE macro registers hundreds of functions, organized by category:
- quantize/dequantize
- GEMM
- activations (gelu, silu, relu, etc. and their derivatives)
- normalization (layernorm, rmsnorm fwd/bwd)
- softmax variants
- attention (fused_attn_fwd/bwd)
- transpose, padding, permutation, RoPE, dropout
- communication overlap classes
- recipe management
- multi-tensor optimizers
- version queries
Usage
This file is the single entry point that makes the C++ extension library accessible from Python. It defines the transformer_engine_torch Python module.
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/csrc/extensions/pybind.cpp
- Lines: 1-518
Signature
namespace transformer_engine::pytorch {
void init_extension();
PYBIND11_MODULE(transformer_engine_torch, m) {
// Quantize/Dequantize
m.def("quantize", &quantize, ...);
m.def("dequantize", &dequantize, ...);
// GEMM
m.def("gemm", &gemm, ...);
m.def("te_atomic_gemm", &te_atomic_gemm, ...);
m.def("te_general_grouped_gemm", &te_general_grouped_gemm, ...);
// Activations
m.def("gelu", &gelu, ...);
m.def("silu", &silu, ...);
// ... (hundreds more)
// Communication overlap classes
py::class_<CommOverlapHelper>(m, "CommOverlapHelper") ...;
py::class_<CommOverlap>(m, "CommOverlap") ...;
py::class_<CommOverlapP2P>(m, "CommOverlapP2P") ...;
}
}
Import
import transformer_engine_torch as tex
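Because the module is a compiled extension, the import fails on machines where it has not been built. A minimal sketch of a guarded import plus a listing of what the module exposes follows. `load_optional_module` and `public_callables` are hypothetical helpers written for this example, not part of TransformerEngine, and the sketch assumes a failed load surfaces as ImportError.

```python
import importlib


def load_optional_module(name):
    """Return the imported module, or None if it cannot be imported."""
    try:
        return importlib.import_module(name)
    except ImportError:
        return None


def public_callables(module):
    """Sorted names of public callables (functions and classes) in a module."""
    return sorted(
        name
        for name in dir(module)
        if not name.startswith("_") and callable(getattr(module, name))
    )


tex = load_optional_module("transformer_engine_torch")
if tex is None:
    print("transformer_engine_torch is not available in this environment")
else:
    # Everything registered with m.def(...) or py::class_<...> shows up here.
    print(len(public_callables(tex)), "public callables registered")
```

The same listing helper works on any module, so it can be tried on a standard-library module first.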
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| N/A | N/A | N/A | Module registration file -- no direct runtime inputs |
Outputs
| Name | Type | Description |
|---|---|---|
| transformer_engine_torch | module | Python module exposing all C++ extension functions |
Usage Examples
import transformer_engine_torch as tex
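# All C++ extensions are available through tex.
# The calls below are schematic: arguments are placeholders and a trailing ... marks omitted parameters.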
output = tex.gelu(input_tensor, quantizer)
result = tex.gemm(A, B, D, bias, quantizer, workspace, ...)
norm_out = tex.layernorm_fwd(input, weight, bias, eps, quantizer, ...)
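The argument lists above are abbreviated. By default, pybind11 puts a generated signature on the first line of each bound function's docstring, so one way to recover the full parameter list is to read `__doc__`. The helper below is a sketch: `signature_line` is a hypothetical name, and `example` is a plain Python stand-in with a hand-written docstring so that the block runs without the extension installed.

```python
def signature_line(func):
    """Return the first non-empty docstring line, or "" if there is none."""
    doc = getattr(func, "__doc__", None) or ""
    for line in doc.splitlines():
        if line.strip():
            return line.strip()
    return ""


def example(x, scale=1.0):
    """example(x, scale=1.0) -> float

    Stand-in for a bound extension function.
    """
    return x * scale


print(signature_line(example))
```

With the extension installed, `signature_line(tex.gelu)` would be the equivalent call. For overloaded bindings the first line is generic and the individual overloads are listed further down the docstring.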