Implementation:NVIDIA TransformerEngine PyTorch Ext Activation
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Quantization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Implements fused forward and backward passes for activation functions (GeLU, SiLU, ReLU, QuickGELU, SReLU, and their gated variants geglu, swiglu, reglu, qgeglu, sreglu) with optional FP8, MXFP8, or NVFP4 quantization.
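The element-wise math behind the listed activations can be sketched in plain Python as a reference. This is a minimal sketch under stated assumptions, not TransformerEngine code: it assumes the tanh approximation of GeLU, QuickGELU as x * sigmoid(1.702 * x), SReLU as squared ReLU, and gated variants that apply the activation to the first half of the last dimension and multiply by the second half. All function names are hypothetical.

```python
import math

# Hypothetical scalar reference versions of the activations (not the TE API).

def _sigmoid(z: float) -> float:
    # Numerically stable logistic function (avoids overflow in math.exp).
    if z >= 0.0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def gelu(x: float) -> float:
    # Assumed tanh approximation of GeLU.
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def silu(x: float) -> float:
    return x * _sigmoid(x)

def relu(x: float) -> float:
    return max(x, 0.0)

def qgelu(x: float) -> float:
    # QuickGELU, assumed here to be x * sigmoid(1.702 * x).
    return x * _sigmoid(1.702 * x)

def srelu(x: float) -> float:
    # Squared ReLU.
    return relu(x) ** 2

def gated(act, row: list) -> list:
    # Assumed layout: row = [a_0..a_{H-1}, g_0..g_{H-1}], output_i = act(a_i) * g_i,
    # so the output is half as long as the input row.
    if len(row) % 2 != 0:
        raise ValueError("gated activations need an even last dimension")
    h = len(row) // 2
    return [act(a) * g for a, g in zip(row[:h], row[h:])]

# A swiglu-style result for one row of length 4 (two outputs).
print(gated(silu, [1.0, -2.0, 3.0, 0.5]))
```

Under these assumptions, `gated(silu, row)` plays the role of swiglu on one row and `gated(gelu, row)` that of geglu.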
Description
Uses two templated helpers, activation_helper (forward) and dactivation_helper (backward), that are parameterized by NVTE kernel function pointers. The forward helper selects one of four implementation paths based on the quantizer type:
- UNFUSED: compute the activation, then quantize it in a separate step.
- FULLY_FUSED: a single fused kernel computes and quantizes the activation (delayed scaling and MXFP8).
- FUSED_ACTIVATION_AMAX_FP8: compute the activation together with its amax, then quantize using the resulting scale (FP8 current scaling).
- FUSED_ACTIVATION_AMAX_NVFP4: the same activation-plus-amax approach, followed by NVFP4 quantization.
Each public function (gelu, geglu, relu, reglu, silu, swiglu, qgelu, qgeglu, srelu, sreglu, and the backward functions such as dgelu) instantiates the matching helper with the corresponding NVTE activation kernel.
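The path selection and the two-step current-scaling flow can be illustrated with a short Python sketch. The quantizer labels, the `choose_impl` and `act_amax_then_scale` helpers, the scale rule `scale = 448 / amax` (448 being the largest finite FP8 E4M3 value), and the fallback scale of 1.0 for an all-zero tensor are assumptions made for illustration; the real helper is C++ and dispatches on TransformerEngine quantizer types.

```python
from enum import Enum, auto

class Impl(Enum):
    UNFUSED = auto()                      # activation, then a separate quantize step
    FULLY_FUSED = auto()                  # single kernel: activation + quantize
    FUSED_ACTIVATION_AMAX_FP8 = auto()    # activation + amax, then quantize (current scaling)
    FUSED_ACTIVATION_AMAX_NVFP4 = auto()  # same two-step flow, NVFP4 output

# Hypothetical string labels standing in for TE's quantizer types.
_PATHS = {
    "fp8_delayed_scaling": Impl.FULLY_FUSED,
    "mxfp8": Impl.FULLY_FUSED,
    "fp8_current_scaling": Impl.FUSED_ACTIVATION_AMAX_FP8,
    "nvfp4": Impl.FUSED_ACTIVATION_AMAX_NVFP4,
}

def choose_impl(quantizer_kind: str) -> Impl:
    # In this sketch, any unlisted quantizer kind falls back to the unfused path.
    return _PATHS.get(quantizer_kind, Impl.UNFUSED)

FP8_E4M3_MAX = 448.0  # largest finite FP8 E4M3 value

def act_amax_then_scale(act, xs):
    # Step 1: compute the activation and track its amax in the same pass.
    ys, amax = [], 0.0
    for x in xs:
        y = act(x)
        ys.append(y)
        amax = max(amax, abs(y))
    # Step 2: derive a per-tensor scale from amax (assumed rule: max / amax),
    # then scale and saturate. Rounding to the FP8 grid is omitted.
    scale = FP8_E4M3_MAX / amax if amax > 0.0 else 1.0
    scaled = [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, y * scale)) for y in ys]
    return scaled, scale

print(choose_impl("fp8_current_scaling"))
print(act_amax_then_scale(lambda x: max(x, 0.0), [-1.0, 2.0, 4.0]))
```

Keeping the choice inside the shared helper, keyed only on the quantizer, means the public functions differ only in which NVTE kernel they pass in.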
Usage
Called by the PyTorch activation operation classes to fuse activation and quantization, reducing memory traffic and kernel launches; on the fully fused path the whole operation is a single GPU kernel launch.
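A rough way to see the saving is to count bytes moved per element. The sketch below is back-of-the-envelope arithmetic, not a measurement: it assumes a BF16 input and intermediate (2 bytes each) and an FP8 output (1 byte), and ignores caching and scale/amax traffic. The constants and `bytes_per_element` are hypothetical.

```python
# Back-of-the-envelope bytes moved per element (assumed dtypes, no caching,
# scale/amax traffic ignored).
IN_BYTES = 2   # BF16 activation input
HP_BYTES = 2   # BF16 intermediate activation output
OUT_BYTES = 1  # FP8 quantized output

def bytes_per_element(fused: bool) -> int:
    if fused:
        # One kernel: read the input once, write the quantized output once.
        return IN_BYTES + OUT_BYTES
    # Activation kernel reads input and writes a BF16 result;
    # a separate quantize kernel reads that result back and writes FP8.
    return (IN_BYTES + HP_BYTES) + (HP_BYTES + OUT_BYTES)

ratio = bytes_per_element(fused=False) / bytes_per_element(fused=True)
print(f"unfused/fused traffic ratio: {ratio:.2f}")  # prints "unfused/fused traffic ratio: 2.33"
```

In this idealized model the unfused path moves 7 bytes per element against 3 for the fully fused path.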
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/csrc/extensions/activation.cpp
- Lines: 1-333
Signature
```cpp
namespace transformer_engine::pytorch {
py::object gelu(at::Tensor input, py::handle quantizer);
py::object geglu(at::Tensor input, py::handle quantizer);
py::object silu(at::Tensor input, py::handle quantizer);
py::object swiglu(at::Tensor input, py::handle quantizer);
py::object relu(at::Tensor input, py::handle quantizer);
py::object reglu(at::Tensor input, py::handle quantizer);
py::object qgelu(at::Tensor input, py::handle quantizer);
py::object qgeglu(at::Tensor input, py::handle quantizer);
py::object srelu(at::Tensor input, py::handle quantizer);
py::object sreglu(at::Tensor input, py::handle quantizer);
at::Tensor dgelu(at::Tensor grad, at::Tensor input);
at::Tensor dsilu(at::Tensor grad, at::Tensor input);
at::Tensor drelu(at::Tensor grad, at::Tensor input);
// ... and gated backward variants
}
```
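The backward functions compute the element-wise chain rule, grad_input = grad * f'(input). A minimal scalar sketch of three of them, assuming the tanh approximation of GeLU and a zero derivative for ReLU at x == 0; the Python functions only mirror the C++ names and are hypothetical stand-ins, not the bindings. The forward definitions are repeated so the block is self-contained.

```python
import math

# Hypothetical scalar stand-ins for dgelu / dsilu / drelu:
# grad_input = grad * f'(input), element-wise.
K = math.sqrt(2.0 / math.pi)
C = 0.044715

def _sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0.0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def gelu(x: float) -> float:
    # Assumed tanh approximation of GeLU.
    return 0.5 * x * (1.0 + math.tanh(K * (x + C * x ** 3)))

def silu(x: float) -> float:
    return x * _sigmoid(x)

def dgelu(grad: float, x: float) -> float:
    t = math.tanh(K * (x + C * x ** 3))
    # d/dx [0.5 x (1 + t)] = 0.5 (1 + t) + 0.5 x (1 - t^2) K (1 + 3 C x^2)
    dact = 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * K * (1.0 + 3.0 * C * x * x)
    return grad * dact

def dsilu(grad: float, x: float) -> float:
    s = _sigmoid(x)
    # d/dx [x s(x)] = s + x s (1 - s)
    return grad * (s + x * s * (1.0 - s))

def drelu(grad: float, x: float) -> float:
    # This sketch uses derivative 0 at x == 0.
    return grad if x > 0.0 else 0.0

print(dgelu(1.0, 0.5), dsilu(1.0, 0.5), drelu(1.0, 0.5))
```

Passing the saved forward input, rather than the activation output, is what lets each backward function evaluate f'(input) directly.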
Import
```cpp
#include "../extensions.h"
#include "common.h"
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
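| input | at::Tensor | Yes | Input tensor to apply the activation to; for backward functions, the saved forward input |
| quantizer | py::handle | No | Optional quantizer for fused quantized output (FP8, MXFP8, or NVFP4); forward functions only |
| grad | at::Tensor | Backward only | Gradient w.r.t. the activation output |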
Outputs
| Name | Type | Description |
|---|---|---|
| output | py::object | Activated (and optionally quantized) tensor, returned by forward functions |
| grad_input | at::Tensor | Gradient w.r.t. the forward input, returned by backward functions |
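The tensor shapes implied by this contract can be summarized in a small helper. It is a hypothetical sketch that assumes, as for GLU-style activations generally, that gated variants consume a last dimension of 2H and produce H, and that grad_input always has the shape of the forward input.

```python
def activation_output_shape(input_shape: tuple, gated: bool) -> tuple:
    # Hypothetical helper: forward output shape for a given input shape.
    # Gated variants are assumed to split the last dimension 2H -> H.
    if not input_shape:
        raise ValueError("expected at least one dimension")
    if not gated:
        return tuple(input_shape)
    last = input_shape[-1]
    if last % 2 != 0:
        raise ValueError("gated activations need an even last dimension")
    return tuple(input_shape[:-1]) + (last // 2,)

def backward_shapes(input_shape: tuple, gated: bool) -> tuple:
    # grad matches the forward output; grad_input matches the forward input.
    return activation_output_shape(input_shape, gated), tuple(input_shape)

print(activation_output_shape((8, 1024), gated=True))
```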
Usage Examples
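```python
import transformer_engine_torch as tex

# input_tensor, fp8_quantizer, grad_output, and saved_input are placeholders:
# CUDA tensors and a TE quantizer object created by the calling code.

# Fused GeLU with FP8 quantization
output_fp8 = tex.gelu(input_tensor, fp8_quantizer)

# GeLU backward: gradient w.r.t. the saved forward input
grad_input = tex.dgelu(grad_output, saved_input)
```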