Environment:NVIDIA TransformerEngine Python PyTorch Requirements
| Knowledge Sources | |
|---|---|
| Domains | Infrastructure, Deep_Learning |
| Last Updated | 2026-02-07 21:00 GMT |
Overview
Python 3.10+ environment with PyTorch 2.1+ and optional JAX/Flax support for running NVIDIA TransformerEngine modules.
Description
This environment defines the Python-level runtime dependencies for TransformerEngine. The core requirement is Python 3.10+ with PyTorch 2.1+. For JAX users, JAX >= 0.5.0 with Flax >= 0.7.1 is required. Additional packages include pydantic (recipe dataclass validation), einops (tensor rearrangement during inference), packaging (version comparison), and nvdlfw-inspect (debug and inspection framework). Optional dependencies such as flash-attn and triton extend functionality.
Usage
Use this environment for any Python-based usage of TransformerEngine, including PyTorch training, JAX training, FP8 quantization, and inference. This is the mandatory Python runtime prerequisite for all TE Python APIs.
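Before importing TE, it can be useful to probe which framework prerequisites the current interpreter can even see. A minimal sketch using only the standard library; `available_te_frameworks` is an illustrative helper, not part of the TE API:

```python
import importlib.util

def available_te_frameworks():
    """Report which TE framework prerequisites are importable (hypothetical helper)."""
    return {
        "pytorch": importlib.util.find_spec("torch") is not None,
        "jax": importlib.util.find_spec("jax") is not None,
    }
```

`find_spec` only inspects import machinery, so this check is cheap and does not trigger the heavy `torch`/`jax` imports themselves.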
System Requirements
| Category | Requirement | Notes |
|---|---|---|
| Python | >= 3.10.0 | RuntimeError raised if older |
| PyTorch | >= 2.1 | For `transformer_engine.pytorch` |
| JAX | >= 0.5.0 | For `transformer_engine.jax` (optional) |
| Flax | >= 0.7.1 | For JAX Flax modules (optional) |
Dependencies
Python Packages (PyTorch)
- `torch` >= 2.1
- `einops` (tensor manipulation for KV cache inference)
- `onnxscript` (ONNX export support)
- `onnx` (ONNX model format)
- `packaging` (version comparison utilities)
- `pydantic` (recipe dataclass validation)
- `nvdlfw-inspect` (debug and inspection framework)
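pydantic's role in recipe validation can be illustrated with a toy model. This is purely a sketch with hypothetical fields; the real recipe classes are provided by TE itself:

```python
from pydantic import BaseModel

class ToyRecipe(BaseModel):
    """Toy stand-in showing pydantic-style validation; not a real TE recipe class."""
    margin: int = 0
    amax_history_len: int = 1024
```

Constructing `ToyRecipe(margin="not-an-int")` raises a validation error, which is the behavior a recipe dataclass relies on to reject malformed FP8 configuration early.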
Python Packages (JAX, optional)
- `jax` >= 0.5.0
- `flax` >= 0.7.1
Optional Dependencies
- `flash-attn` >= 2.1.1 (FlashAttention-2 backend, >= 2.7.3 for Blackwell)
- `triton` (Triton kernel support for cross-entropy, permutation)
- `torchao` == 0.13 (for testing)
- `transformers` (for HuggingFace model integration examples)
Build Dependencies
- `setuptools` >= 61.0
- `wheel`
- `pybind11[global]`
- `pip`
Credentials
No credentials required for the Python runtime environment.
Quick Install
# Install TransformerEngine with PyTorch support
pip install "transformer_engine[pytorch]"
# Install with JAX support
pip install "transformer_engine[jax]"
# Install from source with a specific framework
NVTE_FRAMEWORK=pytorch pip install .
# Install optional dependencies for the full feature set
pip install "flash-attn>=2.1.1" triton
# For HuggingFace integration examples
pip install transformers accelerate
Code Evidence
Python minimum version check from `build_tools/utils.py:23-37`:
def min_python_version() -> Tuple[int]:
    """Minimum supported Python version."""
    return (3, 10, 0)

if sys.version_info < min_python_version():
    raise RuntimeError(
        f"Transformer Engine requires Python {min_python_version_str()} or newer, "
        f"but found Python {platform.python_version()}."
    )
PyTorch minimum version from `build_tools/pytorch.py:15-17`:
def install_requirements() -> List[str]:
    """Install dependencies for TE/PyTorch extensions."""
    return ["torch>=2.1", "einops", "onnxscript", "onnx", "packaging", "pydantic", "nvdlfw-inspect"]
CUDA 12.0+ requirement during PyTorch build from `build_tools/pytorch.py:68`:
if version < (12, 0):
    raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer")
Flash Attention version requirements from `transformer_engine/pytorch/attention/dot_product_attention/utils.py:114-118`:
class FlashAttentionUtils:
    version_required = PkgVersion("2.1.1")  # Pre-Blackwell minimum
    version_required_blackwell = PkgVersion("2.7.3")  # Blackwell (SM 10.0+)
    max_version = PkgVersion("2.8.3")  # Maximum supported
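These bounds translate directly into a support predicate. A sketch using the `packaging` dependency; `flash_attn_supported` is an illustrative helper, not TE API:

```python
from packaging.version import Version

# Bounds taken from FlashAttentionUtils above
MIN_FA = Version("2.1.1")
MIN_FA_BLACKWELL = Version("2.7.3")
MAX_FA = Version("2.8.3")

def flash_attn_supported(installed: str, blackwell: bool = False) -> bool:
    """Return whether an installed flash-attn version falls in the supported window."""
    lower = MIN_FA_BLACKWELL if blackwell else MIN_FA
    return lower <= Version(installed) <= MAX_FA
```

Note the window is closed on both ends: versions newer than the tested maximum are rejected, not just versions that are too old.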
Common Errors
| Error Message | Cause | Solution |
|---|---|---|
| `RuntimeError: Transformer Engine requires Python 3.10.0 or newer` | Python version too old | Upgrade to Python 3.10+ |
| `RuntimeError: Transformer Engine requires CUDA 12.0 or newer` | PyTorch built with old CUDA | Install PyTorch with CUDA 12.x support |
| `ImportError: flash_attn not found` | FlashAttention not installed | `pip install "flash-attn>=2.1.1"` |
| `Unsupported cuda version X.Y` | CUDA major version not 12 or 13 | Use PyTorch built with CUDA 12.x or 13.x |
| `ImportError: No module named 'transformer_engine_torch'` | C++ extension not built | Rebuild: `pip install . --no-build-isolation` |
Compatibility Notes
- PyTorch 2.1+: Required minimum. Newer versions may enable additional features.
- JAX 0.5.0+: Changed FFI module location from `jax.extend.ffi` to `jax.ffi`. TE handles both.
- Flash Attention: Optional but strongly recommended. Version 2.7.3+ required for Blackwell GPUs. Maximum tested version is 2.8.3.
- Triton: Can use either standard `triton` or `pytorch-triton` package (controlled by `NVTE_USE_PYTORCH_TRITON=1`).
- pydantic: Used for recipe dataclass validation. Required at runtime for FP8 configuration.
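The dual FFI location mentioned in the JAX note is the classic try-one-path-then-the-other import. A generic sketch of that pattern; `import_first_available` is an illustrative helper, and the JAX usage is shown only in a comment so the snippet stands alone:

```python
import importlib

def import_first_available(candidates):
    """Try module paths in order, returning the first that imports (illustrative helper)."""
    for path in candidates:
        try:
            return importlib.import_module(path)
        except ImportError:
            continue
    raise ImportError(f"none of {candidates} could be imported")

# TE-style usage (hedged): prefer the new location, fall back to the old one.
# ffi = import_first_available(["jax.ffi", "jax.extend.ffi"])
```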
Related Pages
- Implementation:NVIDIA_TransformerEngine_TE_Linear
- Implementation:NVIDIA_TransformerEngine_TE_LayerNorm
- Implementation:NVIDIA_TransformerEngine_TE_DotProductAttention
- Implementation:NVIDIA_TransformerEngine_TE_TransformerLayer
- Implementation:NVIDIA_TransformerEngine_TE_Autocast
- Implementation:NVIDIA_TransformerEngine_DelayedScaling_Recipe
- Implementation:NVIDIA_TransformerEngine_TELlamaDecoderLayer
- Implementation:NVIDIA_TransformerEngine_TEGemmaDecoderLayer
- Implementation:NVIDIA_TransformerEngine_InferenceParams