Implementation:NVIDIA TransformerEngine Common Init
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, Build_Infrastructure |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Framework-agnostic initialization module that handles discovery and loading of TransformerEngine's native shared libraries (core, torch, jax) and validates version consistency across TE packages.
Description
transformer_engine/common/__init__.py is the critical bootstrap module that bridges the Python and C++ layers of TransformerEngine. It searches for .so files in multiple candidate directories (editable installs, source builds, wheel installs), loads them via importlib, and performs version-matching assertions between the metapackage, the core library, and the framework extension packages.
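The candidate-directory search can be pictured with a minimal sketch. find_shared_object is a hypothetical helper, not TE's function. The page does not give TE's actual search order, so the directory list and the "first directory with a match wins" rule are assumptions.

```python
from pathlib import Path
from typing import Iterable, Optional


def find_shared_object(candidate_dirs: Iterable[Path], pattern: str) -> Optional[Path]:
    """Return the first file matching `pattern`, checking directories in order.

    Hypothetical helper: TE's real directory list and ordering may differ.
    """
    for directory in candidate_dirs:
        directory = Path(directory)
        if not directory.is_dir():
            continue  # this candidate does not exist for the current install type
        # sorted() makes the choice deterministic if several files match
        matches = sorted(directory.glob(pattern))
        if matches:
            return matches[0].resolve()  # absolute path to the located .so file
    return None  # caller decides how to report a missing library
```

Listing an editable-install or build directory before the installed package directory would let a freshly built library take precedence over an older installed copy; that ordering is an illustrative choice, not a documented TE rule.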
Key capabilities:
- Shared library discovery: The _get_shared_object_file function locates libtransformer_engine.so (core) or transformer_engine_{torch,jax}.so (framework extensions) in the directory that matches the installation method.
- Multi-install support: Editable installs, regular source builds, and PyPI wheel installs are handled transparently by searching multiple candidate directories.
- Version validation: sanity_checks_for_pypi_installation verifies that all TE component versions are consistent when installed from PyPI.
- Framework extension loading: load_framework_extension dynamically loads and returns the appropriate framework binding module.
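The version check can be sketched as follows, assuming versions are read with the standard importlib.metadata.version. The helper names are hypothetical and this is not TE's code. The sketch raises RuntimeError instead of using assert, so the check is not stripped when Python runs with -O.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Mapping


def installed_versions(package_names: Iterable[str]) -> Dict[str, str]:
    """Return {name: version} for the distributions that are installed (hypothetical helper)."""
    found: Dict[str, str] = {}
    for name in package_names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            pass  # component not installed, so there is nothing to compare
    return found


def assert_consistent_versions(versions: Mapping[str, str]) -> None:
    """Raise RuntimeError if the given packages do not all report the same version."""
    if len(set(versions.values())) > 1:
        details = ", ".join(f"{name}=={ver}" for name, ver in sorted(versions.items()))
        raise RuntimeError(f"Transformer Engine package versions do not match: {details}")
```

For example, assert_consistent_versions(installed_versions(["transformer_engine", "transformer_engine_cu12", "transformer_engine_torch"])) would compare whichever of those distributions are present; the exact set of package names TE checks is an assumption here.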
Usage
This module is imported automatically when import transformer_engine is executed. Its functions are not typically called directly by user code, but the module is essential for ensuring that the correct native libraries are loaded regardless of how TransformerEngine was installed.
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/common/__init__.py
- Lines: 1-384
Signature
```python
from pathlib import Path
from typing import Tuple


def _get_shared_object_file(library: str) -> Path:
    """Path to shared object file for a Transformer Engine library."""
    ...


def get_te_core_package_info() -> Tuple[str, Path]:
    """Get the name and path of the TE core package."""
    ...


def load_framework_extension(framework: str):
    """Load a TE framework extension and return the module."""
    ...


def sanity_checks_for_pypi_installation() -> None:
    """Validate version consistency across TE packages."""
    ...
```
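load_framework_extension is described as dynamically loading a binding module from a located .so file. The standard importlib recipe for importing a module from an explicit file path is sketched below. load_module_from_path is a hypothetical name, and whether TE registers the module in sys.modules in exactly this way is an assumption. For a compiled extension, module_name must match the module's PyInit_<name> entry point.

```python
import importlib.util
import sys
from pathlib import Path
from types import ModuleType


def load_module_from_path(module_name: str, path: Path) -> ModuleType:
    """Import the file at `path` as `module_name` without putting it on sys.path.

    spec_from_file_location picks a loader from the file suffix, so the same
    code handles a compiled extension (.so) or a plain .py file.
    """
    spec = importlib.util.spec_from_file_location(module_name, str(path))
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create an import spec for {path}")
    module = importlib.util.module_from_spec(spec)
    # Register before executing so a later `import module_name` reuses this object
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialized module behind on failure
        sys.modules.pop(module_name, None)
        raise
    return module
```

Registering the module before exec_module mirrors the pattern in the Python importlib documentation. Removing it on failure keeps a broken load from masking a later retry.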
Import
```python
from transformer_engine.common import load_framework_extension
from transformer_engine.common import get_te_core_package_info
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| library | str | Yes | Library identifier: "core", "torch", or "jax" |
Outputs
| Name | Type | Description |
|---|---|---|
| shared object path | Path | Absolute path to the located .so file |
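The input contract can be expressed as a small validation step that maps each identifier to a file-name pattern. shared_object_pattern is a hypothetical helper. The names follow this page's description, and the wildcard before .so for framework extensions is an assumption meant to tolerate ABI-tagged file names such as transformer_engine_torch.cpython-312-x86_64-linux-gnu.so.

```python
def shared_object_pattern(library: str) -> str:
    """Map a TE library identifier to a glob pattern for its shared object.

    Hypothetical helper: names follow this page's description; the '*' before
    '.so' for framework extensions is an assumption, not TE's documented rule.
    """
    if library == "core":
        return "libtransformer_engine.so"
    if library in ("torch", "jax"):
        return f"transformer_engine_{library}*.so"
    # Reject anything outside the documented identifiers, e.g. "pytorch"
    raise ValueError(f"Unsupported library {library!r}; expected 'core', 'torch' or 'jax'")
```

Failing fast on an unknown identifier gives a clearer error than a search that silently finds no file.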
Usage Examples
```python
# This module is imported automatically; direct use is rare
from transformer_engine.common import load_framework_extension

# Load the PyTorch framework extension (identifier "torch", not "pytorch")
torch_ext = load_framework_extension("torch")
```