Implementation:NVIDIA TransformerEngine JAX Activation
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, JAX, Activation |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Provides optimized activation functions (GELU, SiLU, ReLU, etc.) with optional FP8 quantization support for the JAX backend.
Description
The public activation() function delegates to _activation, which uses jax.custom_vjp to define custom forward and backward rules. The forward rule calls tex.act_lu (a C++ extension primitive) with optional quantization, then dequantizes the output. The backward rule calls tex.dact_lu to compute activation gradients. The custom VJP ensures correct gradient flow through the fused activation+quantization pipeline.
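The custom-VJP pattern described above can be illustrated with a minimal pure-JAX sketch. It does not call TransformerEngine: `fake_quantize`, `fake_dequantize`, `quantized_gelu`, and the `scale` argument are hypothetical stand-ins for the `tex.act_lu` / `tex.dact_lu` primitives and the quantizer, and the rounding step only imitates low-precision quantization.

```python
from functools import partial

import jax
import jax.numpy as jnp


def fake_quantize(y, scale):
    # Hypothetical stand-in for FP8 quantization: scale, round, and clip.
    return jnp.clip(jnp.round(y * scale), -448.0, 448.0)


def fake_dequantize(q, scale):
    # Undo the scaling so the caller sees values in the original range.
    return q / scale


def _gelu(v):
    return jax.nn.gelu(v, approximate=True)


@partial(jax.custom_vjp, nondiff_argnums=(1,))
def quantized_gelu(x, scale):
    out, _ = _quantized_gelu_fwd(x, scale)
    return out


def _quantized_gelu_fwd(x, scale):
    # Forward rule: activation, then quantize and dequantize the result.
    out = fake_dequantize(fake_quantize(_gelu(x), scale), scale)
    # Save the input as the residual needed by the backward rule.
    return out, (x,)


def _quantized_gelu_bwd(scale, residuals, grad_out):
    # Backward rule: differentiate the unquantized activation at the saved
    # input, so the rounding step does not zero out the gradient.
    (x,) = residuals
    _, gelu_vjp = jax.vjp(_gelu, x)
    (grad_x,) = gelu_vjp(grad_out)
    return (grad_x,)


quantized_gelu.defvjp(_quantized_gelu_fwd, _quantized_gelu_bwd)

# Usage: the gradient follows GELU even though the forward pass rounds.
x = jnp.array([-1.0, 0.5, 2.0])
out = quantized_gelu(x, 16.0)
grad = jax.grad(lambda v: quantized_gelu(v, 16.0).sum())(x)
```

Without the custom rule, automatic differentiation through `jnp.round` would yield zero gradients almost everywhere, which is one reason a fused activation and quantization step is usually paired with an explicit backward rule.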
This is a core functional building block used by higher-level modules (LayerNormMLP, Flax modules) to apply activation functions with FP8 quantization in a differentiable manner.
Usage
Use this function when applying activation functions with optional FP8 quantization in transformer MLP layers. It is typically called internally by layernorm_mlp and the Flax LayerNormMLP module.
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/jax/activation.py
- Lines: 1-109
Signature
def activation(
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
    act_params: Optional[tex.activation.ActivationParams] = None,
) -> jnp.ndarray: ...
Import
from transformer_engine.jax.activation import activation
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| x | jnp.ndarray | Yes | Input tensor to apply activations to |
| activation_type | Sequence[Union[str, Callable]] | Yes | Activation functions, e.g., ("gelu",) or ("silu", "linear") |
| quantizer | Optional[Quantizer] | No | Optional quantizer for FP8 quantized output |
| act_params | Optional[ActivationParams] | No | Optional activation parameters (for clamped SwiGLU) |
Outputs
| Name | Type | Description |
|---|---|---|
| output | jnp.ndarray | Activated tensor (dequantized back to original precision) |
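To make the meaning of a multi-entry `activation_type` such as ("silu", "linear") concrete, the following is a rough pure-JAX reference sketch of a gated activation. It is not the TransformerEngine implementation: `reference_activation` and `_ACTIVATIONS` are hypothetical names, and the input layout (one slice per activation along the second-to-last axis) is an assumption made for this illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical lookup table for this sketch; names mirror the strings used above.
_ACTIVATIONS = {
    "gelu": jax.nn.gelu,
    "silu": jax.nn.silu,
    "relu": jax.nn.relu,
    "linear": lambda v: v,
}


def reference_activation(x, activation_type):
    # Assumed layout: x has shape (..., len(activation_type), hidden).
    n = len(activation_type)
    if x.shape[-2] != n:
        raise ValueError(f"expected x.shape[-2] == {n}, got {x.shape[-2]}")
    out = None
    for i, name in enumerate(activation_type):
        # Apply the i-th activation to the i-th slice, then multiply the branches.
        branch = _ACTIVATIONS[name](x[..., i, :])
        out = branch if out is None else out * branch
    return out


# SwiGLU-style gating: silu(gate) * value, with gate and value stacked on axis -2.
x = jnp.arange(12.0).reshape(2, 2, 3) - 5.0
swiglu_out = reference_activation(x, ("silu", "linear"))
```

With a single-entry tuple such as ("gelu",), the same function reduces to an ordinary element-wise activation on an input whose second-to-last axis has size 1.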
Usage Examples
from transformer_engine.jax.activation import activation

# x is assumed to be an existing jnp.ndarray and fp8_quantizer an existing Quantizer instance.
# Apply GELU activation
output = activation(x, activation_type=("gelu",))

# Apply SwiGLU (gated activation) with FP8 quantization
output = activation(x, activation_type=("silu", "linear"), quantizer=fp8_quantizer)