Implementation:NVIDIA TransformerEngine JIT
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Optimization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Provides JIT compilation utilities and fused activation function implementations using torch.compile and NVFuser for reduced kernel launch overhead and memory traffic.
Description
The jit_fuser decorator resolves to lazy_compile (a deferred torch.compile) when PyTorch >= 2.0 and NVTE_TORCH_COMPILE=1, so compilation and kernel fusion happen on the first call rather than at decoration time. dropout_fuser uses the same torch.compile path for PyTorch >= 2.2 and otherwise falls back to torch.jit.script. no_torch_dynamo is a decorator that disables Torch Dynamo tracing for specific functions (except during ONNX export) to work around compilation issues. set_jit_fusion_options configures NVFuser or legacy fuser flags depending on the PyTorch version.
The file also implements fused activation functions decorated with @jit_fuser: bias+GeLU (forward and backward), bias+SiLU, bias+ReLU, bias+quick GeLU, bias+SReLU, and L2 normalization, together with their gradient computations.
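The deferred-compilation pattern described above can be sketched in plain Python. This is a minimal illustration, not the library's code: `make_lazy_compile` and `select_fuser` are hypothetical names, the `compiler` argument is a stand-in for something like torch.compile, and treating an unset NVTE_TORCH_COMPILE as enabled is an assumption. The PyTorch version check is omitted.

```python
import functools
import os
from typing import Any, Callable, Mapping, Optional


def make_lazy_compile(compiler: Callable[[Callable], Callable]) -> Callable:
    """Build a decorator that defers `compiler(func)` until the first call.

    `compiler` is a hypothetical stand-in for a real compiler entry point.
    """

    def lazy_compile(func: Callable) -> Callable:
        compiled: Optional[Callable] = None

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            nonlocal compiled
            if compiled is None:
                # Compile once, on first use, then reuse the result.
                compiled = compiler(func)
            return compiled(*args, **kwargs)

        return wrapper

    return lazy_compile


def select_fuser(
    compiler: Callable[[Callable], Callable],
    env: Optional[Mapping[str, str]] = None,
) -> Callable:
    """Return the lazy decorator if the flag is on, else an identity decorator."""
    env = os.environ if env is None else env
    # Assumption: an unset variable counts as enabled ("1").
    enabled = bool(int(env.get("NVTE_TORCH_COMPILE", "1")))
    if enabled:
        return make_lazy_compile(compiler)
    return lambda func: func
```

Deferring the compile step keeps decoration cheap: nothing is compiled for functions that are decorated but never called.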
Usage
Import and use the fused activation functions throughout transformer layers for reduced memory traffic and kernel launch overhead. Use set_jit_fusion_options at initialization time.
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/jit.py
- Lines: 1-401
Signature
def lazy_compile(func): ...
def set_jit_fusion_options() -> None: ...
def no_torch_dynamo(func): ...
@jit_fuser
def bias_gelu_fused_(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor: ...
@jit_fuser
def gelu_fused_(inp: torch.Tensor) -> torch.Tensor: ...
@jit_fuser
def bgrad_dgelu_fused_(grad_output, inp, bias) -> Tuple: ...
@jit_fuser
def bias_dropout_add_fused_train_(x, bias, residual, prob) -> torch.Tensor: ...
def get_bias_dropout_add(training: bool) -> Callable: ...
Import
from transformer_engine.pytorch.jit import (
    set_jit_fusion_options,
    no_torch_dynamo,
    jit_fuser,
    bias_gelu_fused_,
    get_bias_dropout_add,
)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| inp | torch.Tensor | Yes | Input tensor for activation |
| bias | torch.Tensor | No | Bias tensor to fuse with activation |
| residual | torch.Tensor | No | Residual tensor for dropout+add |
| prob | float | No | Dropout probability |
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Result of the fused operation (activation, dropout+add, etc.) |
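To make the bias+GeLU contract concrete, here is a scalar reference in plain Python. It assumes the widely used tanh approximation of GeLU; this page does not state which formulation the library uses, so treat the constants and the function names (`bias_gelu_ref`, `bias_gelu_grad_ref`, `bias_grad_ref`) as illustrative only.

```python
import math
from typing import List

SQRT_2_OVER_PI = math.sqrt(2.0 / math.pi)  # about 0.79788456
COEFF = 0.044715  # cubic coefficient of the tanh approximation (assumed)


def bias_gelu_ref(x: float, bias: float) -> float:
    """Forward: GeLU(x + bias) using the tanh approximation."""
    y = x + bias
    u = SQRT_2_OVER_PI * y * (1.0 + COEFF * y * y)
    return 0.5 * y * (1.0 + math.tanh(u))


def bias_gelu_grad_ref(grad_out: float, x: float, bias: float) -> float:
    """Backward: grad_out times d/dy of the approximation at y = x + bias."""
    y = x + bias
    u = SQRT_2_OVER_PI * y * (1.0 + COEFF * y * y)
    t = math.tanh(u)
    du_dy = SQRT_2_OVER_PI * (1.0 + 3.0 * COEFF * y * y)
    dgelu = 0.5 * (1.0 + t) + 0.5 * y * (1.0 - t * t) * du_dy
    return grad_out * dgelu


def bias_grad_ref(grad_rows: List[List[float]]) -> List[float]:
    """Bias gradient: sum the input gradient over the batch (row) dimension."""
    return [sum(col) for col in zip(*grad_rows)]
```

Because the bias is added before the activation, the gradient with respect to the input and the gradient with respect to the bias share the same elementwise term; the bias gradient is that term reduced over the batch.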
Usage Examples
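import torch
from transformer_engine.pytorch.jit import (
    set_jit_fusion_options,
    bias_gelu_fused_,
    get_bias_dropout_add,
)

# Illustrative inputs; shapes, device, and dropout probability are assumptions
input_tensor = torch.randn(8, 1024, device="cuda")
bias = torch.randn(1024, device="cuda")
x = torch.randn(8, 1024, device="cuda")
residual = torch.randn(8, 1024, device="cuda")
dropout_prob = 0.1

# Initialize JIT fusion options once, at startup
set_jit_fusion_options()

# Fused bias + GeLU
output = bias_gelu_fused_(input_tensor, bias)

# Get the fused bias + dropout + add function for training mode
bias_dropout_add_fn = get_bias_dropout_add(training=True)
output = bias_dropout_add_fn(x, bias, residual, dropout_prob)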
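For readers who want the semantics of the bias + dropout + add step without a GPU, the following plain-Python reference may help. It assumes the conventional definition, residual + dropout(x + bias), with inverted-dropout scaling during training; the names `bias_dropout_add_ref` and `get_bias_dropout_add_ref` are hypothetical and only mirror the shape of the factory shown above.

```python
import random
from typing import Callable, List, Optional


def bias_dropout_add_ref(
    x: List[float],
    bias: List[float],
    residual: List[float],
    prob: float,
    training: bool,
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Elementwise residual + dropout(x + bias), unfused reference."""
    if rng is None:
        rng = random.Random()
    out = []
    for xi, bi, ri in zip(x, bias, residual):
        y = xi + bi
        if training and prob > 0.0:
            # Inverted dropout: zero with probability `prob`, else rescale.
            keep = rng.random() >= prob
            y = y / (1.0 - prob) if keep else 0.0
        out.append(ri + y)
    return out


def get_bias_dropout_add_ref(training: bool) -> Callable:
    """Factory that binds the training flag, like the API used above."""

    def fn(x, bias, residual, prob, rng=None):
        return bias_dropout_add_ref(x, bias, residual, prob, training, rng)

    return fn
```

Binding `training` in a factory keeps the per-call signature identical in both modes, so the caller selects the variant once instead of branching on every forward pass.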