Environment:NVIDIA TransformerEngine Python PyTorch Requirements

From Leeroopedia


Knowledge Sources
Domains Infrastructure, Deep_Learning
Last Updated 2026-02-07 21:00 GMT

Overview

Python 3.10+ environment with PyTorch 2.1+ and optional JAX/Flax support for running NVIDIA TransformerEngine modules.

Description

This environment defines the Python-level runtime dependencies for TransformerEngine. The core requirement is Python 3.10+ with PyTorch 2.1+. For JAX users, JAX >= 0.5.0 with Flax >= 0.7.1 is required. Additional packages include pydantic (for recipe dataclass validation), einops (for tensor rearrangement in inference), packaging (for version comparison), and nvdlfw-inspect (for debugging and inspection). Optional dependencies such as flash-attn and triton extend functionality.

Usage

Use this environment for any Python-based usage of TransformerEngine, including PyTorch training, JAX training, FP8 quantization, and inference. This is the mandatory Python runtime prerequisite for all TE Python APIs.

System Requirements

Category Requirement Notes
Python >= 3.10.0 RuntimeError raised if older
PyTorch >= 2.1 For `transformer_engine.pytorch`
JAX >= 0.5.0 For `transformer_engine.jax` (optional)
Flax >= 0.7.1 For JAX Flax modules (optional)
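The version floors above can be verified before importing TE. The sketch below uses only the standard library; the `parse_version` and `check_requirements` helpers are illustrative, not part of TE's API, and `parse_version` deliberately drops rc/local suffixes (e.g. `+cu121`) for a coarse comparison:

```python
import sys
from importlib.metadata import version, PackageNotFoundError

def parse_version(text):
    """'2.1.0' -> (2, 1, 0); rc/local suffixes like '+cu121' are dropped."""
    parts = []
    for piece in text.split("."):
        number = ""
        for ch in piece:
            if ch.isdigit():
                number += ch
            else:
                break  # stop at the first non-digit ('0+cu121' -> '0')
        if not number:
            break
        parts.append(int(number))
    return tuple(parts)

def check_requirements():
    """Report which of the version floors above the current env satisfies."""
    report = {"python": sys.version_info >= (3, 10, 0)}
    for pkg, floor in [("torch", (2, 1)), ("jax", (0, 5, 0)), ("flax", (0, 7, 1))]:
        try:
            report[pkg] = parse_version(version(pkg)) >= floor
        except PackageNotFoundError:
            report[pkg] = None  # package not installed (JAX/Flax are optional)
    return report
```

A `None` entry distinguishes "not installed" from "installed but too old", which matters because JAX and Flax are optional.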

Dependencies

Python Packages (PyTorch)

  • `torch` >= 2.1
  • `einops` (tensor manipulation for KV cache inference)
  • `onnxscript` (ONNX export support)
  • `onnx` (ONNX model format)
  • `packaging` (version comparison utilities)
  • `pydantic` (recipe dataclass validation)
  • `nvdlfw-inspect` (debug and inspection framework)

Python Packages (JAX, optional)

  • `jax` >= 0.5.0
  • `flax` >= 0.7.1

Optional Dependencies

  • `flash-attn` >= 2.1.1 (FlashAttention-2 backend, >= 2.7.3 for Blackwell)
  • `triton` (Triton kernel support for cross-entropy, permutation)
  • `torchao` == 0.13 (for testing)
  • `transformers` (for HuggingFace model integration examples)
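Since these packages are optional, code that uses them is typically feature-detected at import time rather than assumed present. A minimal sketch of that pattern (the `optional_import` helper is illustrative, not TE's actual mechanism):

```python
from importlib import import_module

def optional_import(name):
    """Return the module if importable, else None, so callers can branch."""
    try:
        return import_module(name)
    except ImportError:
        return None

# Branch on availability instead of assuming the backend exists.
flash_attn = optional_import("flash_attn")
triton = optional_import("triton")
use_flash_backend = flash_attn is not None
```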

Build Dependencies

  • `setuptools` >= 61.0
  • `wheel`
  • `pybind11[global]`
  • `pip`

Credentials

No credentials required for the Python runtime environment.

Quick Install

# Install TransformerEngine with PyTorch support
# (quotes prevent the shell from globbing the [...] extras)
pip install "transformer_engine[pytorch]"

# Install with JAX support
pip install "transformer_engine[jax]"

# Install from source with a specific framework
NVTE_FRAMEWORK=pytorch pip install .

# Install optional dependencies for the full feature set
# (quotes prevent '>' from being treated as a redirection)
pip install "flash-attn>=2.1.1" triton

# For HuggingFace integration examples
pip install transformers accelerate

Code Evidence

Python minimum version check from `build_tools/utils.py:23-37`:

def min_python_version() -> Tuple[int]:
    """Minimum supported Python version."""
    return (3, 10, 0)

if sys.version_info < min_python_version():
    raise RuntimeError(
        f"Transformer Engine requires Python {min_python_version_str()} or newer, "
        f"but found Python {platform.python_version()}."
    )

PyTorch minimum version from `build_tools/pytorch.py:15-17`:

def install_requirements() -> List[str]:
    """Install dependencies for TE/PyTorch extensions."""
    return ["torch>=2.1", "einops", "onnxscript", "onnx", "packaging", "pydantic", "nvdlfw-inspect"]

CUDA 12.0+ requirement during PyTorch build from `build_tools/pytorch.py:68`:

if version < (12, 0):
    raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer")

Flash Attention version requirements from `transformer_engine/pytorch/attention/dot_product_attention/utils.py:114-118`:

class FlashAttentionUtils:
    version_required = PkgVersion("2.1.1")          # Pre-Blackwell minimum
    version_required_blackwell = PkgVersion("2.7.3") # Blackwell (SM 10.0+)
    max_version = PkgVersion("2.8.3")                # Maximum supported
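The bounds above imply gating logic of roughly the following shape. This sketch compares plain version tuples instead of `PkgVersion` objects and is not TE's actual code; the function name and `is_blackwell` flag are illustrative:

```python
def flash_attn_supported(installed, is_blackwell=False):
    """Check an installed flash-attn version tuple against the TE bounds."""
    minimum = (2, 7, 3) if is_blackwell else (2, 1, 1)  # floors from FlashAttentionUtils
    maximum = (2, 8, 3)                                 # highest supported version
    return minimum <= installed <= maximum
```

Note that the floor is architecture-dependent: a flash-attn build that is fine on Hopper can still be rejected on Blackwell (SM 10.0+).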

Common Errors

Error Message Cause Solution
`RuntimeError: Transformer Engine requires Python 3.10.0 or newer` Python version too old Upgrade to Python 3.10+
`RuntimeError: Transformer Engine requires CUDA 12.0 or newer` PyTorch built with old CUDA Install PyTorch with CUDA 12.x support
`ImportError: flash_attn not found` FlashAttention not installed `pip install flash-attn>=2.1.1`
`Unsupported cuda version X.Y` CUDA major version not 12 or 13 Use PyTorch built with CUDA 12.x or 13.x
`ImportError: No module named 'transformer_engine_torch'` C++ extension not built Rebuild: `pip install . --no-build-isolation`

Compatibility Notes

  • PyTorch 2.1+: Required minimum. Newer versions may enable additional features.
  • JAX 0.5.0+: Changed FFI module location from `jax.extend.ffi` to `jax.ffi`. TE handles both.
  • Flash Attention: Optional but strongly recommended. Version 2.7.3+ required for Blackwell GPUs. Maximum tested version is 2.8.3.
  • Triton: Can use either standard `triton` or `pytorch-triton` package (controlled by `NVTE_USE_PYTORCH_TRITON=1`).
  • pydantic: Used for recipe dataclass validation. Required at runtime for FP8 configuration.
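The JAX FFI relocation noted above is the kind of change typically absorbed by trying the new attribute first and falling back to the old path. A generic sketch, exercised here with stand-in namespaces rather than JAX itself (the `resolve_ffi` helper is illustrative, not TE's code):

```python
from types import SimpleNamespace

def resolve_ffi(jax_module):
    """Return jax.ffi when present (JAX >= 0.5), else jax.extend.ffi."""
    ffi = getattr(jax_module, "ffi", None)
    if ffi is not None:
        return ffi
    return jax_module.extend.ffi  # pre-0.5 location

# Stand-ins for the two JAX layouts (JAX itself is not imported here).
new_layout = SimpleNamespace(ffi="jax.ffi")
old_layout = SimpleNamespace(extend=SimpleNamespace(ffi="jax.extend.ffi"))
```

With the real library this would be called as `resolve_ffi(jax)` after `import jax`.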
