Implementation:NVIDIA TransformerEngine JAX Cpp Activation
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, JAX, Activation |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Implements JAX custom primitives for fused activation functions with optional quantization, including forward and backward passes that call into the TE C++ backend via FFI.
Description
This module maps activation names, given as tuples of strings such as ("gelu",) or ("silu", "linear"), to NVTE_Activation_Type enums. ActLuPrimitive implements the forward pass as a custom JAX primitive with abstract, lowering, and partition methods for XLA compilation and SPMD sharding. DActLuDBiasQuantizePrimitive and DActLuQuantizePrimitive implement the fused backward and quantize operations. ClampedSwigluParams carries the limit and alpha parameters for the clamped SwiGLU variant. All primitives derive from BasePrimitive and are registered with JAX through register_primitive.
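The name-to-enum lookup described above can be pictured with a small standalone sketch. It does not import TransformerEngine: `ActivationKind`, its numeric values, `_ACTIVATION_TABLE`, and `resolve_activation` are hypothetical stand-ins for illustration, and only four of the supported activations are shown.

```python
from enum import Enum
from typing import Sequence, Union


class ActivationKind(Enum):
    """Hypothetical stand-in for NVTE_Activation_Type; values are arbitrary."""
    GELU = 0
    GEGLU = 1
    SILU = 2
    SWIGLU = 3


# One name selects a plain activation; a trailing "linear" selects the gated form.
_ACTIVATION_TABLE = {
    ("gelu",): ActivationKind.GELU,
    ("gelu", "linear"): ActivationKind.GEGLU,
    ("silu",): ActivationKind.SILU,
    ("silu", "linear"): ActivationKind.SWIGLU,
}


def resolve_activation(activation_type: Union[str, Sequence[str]]) -> ActivationKind:
    """Normalize the names to a lowercase tuple and look up the enum member."""
    # Accept a bare string as shorthand for a one-element tuple.
    if isinstance(activation_type, str):
        key = (activation_type.lower(),)
    else:
        key = tuple(name.lower() for name in activation_type)
    try:
        return _ACTIVATION_TABLE[key]
    except KeyError:
        raise ValueError(f"Unsupported activation_type: {key!r}") from None
```

Keying the table on tuples lets a single lookup distinguish plain from gated variants, which is presumably why the public functions take a sequence rather than a single string.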
This is the core low-level activation layer that enables fused activation+quantization kernels on GPU, reducing memory traffic and improving performance for transformer MLP layers.
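To make the fusion concrete, here is a pure-Python sketch of the two steps that a fused activation+quantization kernel would combine: a gated SiLU (SwiGLU) activation followed by per-tensor scaling into the FP8 E4M3 range. This is a simplification under stated assumptions: `swiglu` and `scale_for_fp8` are hypothetical helper names, the scaling is a basic amax-based scheme, and the rounding to actual FP8 values is omitted. The library's quantizers and scaling modes may differ.

```python
import math
from typing import List, Tuple

FP8_E4M3_MAX = 448.0  # largest finite magnitude representable in FP8 E4M3


def silu(v: float) -> float:
    """SiLU: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def swiglu(gate: List[float], linear: List[float]) -> List[float]:
    """Gated activation: silu(gate) multiplied elementwise by the linear branch."""
    return [silu(g) * l for g, l in zip(gate, linear)]


def scale_for_fp8(values: List[float]) -> Tuple[List[float], float]:
    """Scale values so the largest magnitude maps to FP8_E4M3_MAX.

    Returns the scaled (not yet rounded) values and the inverse scale
    needed to recover the original magnitudes.
    """
    amax = max(abs(v) for v in values)
    scale = FP8_E4M3_MAX / amax if amax > 0.0 else 1.0
    # Clamp to guard against floating-point overshoot at the boundary.
    scaled = [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v * scale)) for v in values]
    return scaled, 1.0 / scale


activated = swiglu([0.0, 1.0], [2.0, 3.0])
scaled, scale_inv = scale_for_fp8(activated)
```

Run as separate steps, the activated tensor is written out and read back before scaling. A fused kernel avoids that round trip, which is the memory-traffic saving the paragraph above refers to.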
Usage
Use this module indirectly through the higher-level activation() function in transformer_engine.jax.activation or via LayerNormMLP. Direct usage is for advanced cases where fine-grained control over the activation primitive is needed, such as custom MLP implementations with FP8 quantization.
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/jax/cpp_extensions/activation.py
- Lines: 1-1656
Signature
class ClampedSwigluParams:
    limit: float = ...
    alpha: float = ...

class ActivationParams:
    activation_type: Sequence[Union[str, Callable]] = ...
    clamped_swiglu_params: Optional[ClampedSwigluParams] = None

class ActLuPrimitive(BasePrimitive): ...
class DActLuDBiasQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive): ...
class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive): ...

def act_lu(
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Quantizer = None,
    act_params: Optional[ActivationParams] = None,
) -> Union[jnp.ndarray, ScaledTensor]: ...

def quantize_dact_dbias(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Quantizer = None,
    act_params: Optional[ActivationParams] = None,
) -> Tuple: ...

def dact_lu(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    act_params: Optional[ActivationParams] = None,
) -> jnp.ndarray: ...
Import
from transformer_engine.jax.cpp_extensions.activation import act_lu, quantize_dact_dbias, dact_lu
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| x | jnp.ndarray | Yes | Input tensor for activation |
| activation_type | Sequence[Union[str, Callable]] | Yes | Activation function name(s), e.g., ("gelu",) or ("silu", "linear") |
| quantizer | Quantizer | No | Optional quantizer for FP8 output |
| act_params | ActivationParams | No | Optional activation parameters (e.g., clamped SwiGLU limits) |
Outputs
| Name | Type | Description |
|---|---|---|
| output | Union[jnp.ndarray, ScaledTensor] | Activated tensor, optionally quantized to FP8 |
Usage Examples
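import jax.numpy as jnp
from transformer_engine.jax.cpp_extensions.activation import act_lu, dact_lu
# Placeholder inputs. Assumed layout: x is (..., ACT_DIM, K), with ACT_DIM = 1 (plain) or 2 (gated); dz is (..., K).
x = jnp.ones((8, 1, 128), dtype=jnp.bfloat16)
x_gated = jnp.ones((8, 2, 128), dtype=jnp.bfloat16)
dz = jnp.ones((8, 128), dtype=jnp.bfloat16)
# Forward: apply GELU activation
output = act_lu(x, activation_type=("gelu",))
# Forward with SwiGLU (gated activation)
output = act_lu(x_gated, activation_type=("silu", "linear"))
# Backward: compute activation gradient
dgrad = dact_lu(dz, x, activation_type=("gelu",))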
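For intuition about what the backward functions compute, the following pure-Python sketch works through the math for a plain SiLU activation: the input gradient is the upstream gradient times the activation derivative, and the bias gradient is that result summed over the leading (row) dimension. The helpers `dact_silu` and `dbias_from` are hypothetical names for illustration. The sketch covers neither gated activations nor the quantization that `quantize_dact_dbias` can fuse in.

```python
import math
from typing import List


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


def silu(v: float) -> float:
    return v * sigmoid(v)


def silu_grad(v: float) -> float:
    """Derivative of v * sigmoid(v) with respect to v."""
    s = sigmoid(v)
    return s * (1.0 + v * (1.0 - s))


def dact_silu(dz: List[List[float]], x: List[List[float]]) -> List[List[float]]:
    """Elementwise chain rule: dx = dz * silu'(x)."""
    return [
        [g * silu_grad(v) for g, v in zip(dz_row, x_row)]
        for dz_row, x_row in zip(dz, x)
    ]


def dbias_from(dx: List[List[float]]) -> List[float]:
    """Bias gradient: sum dx over rows, leaving one value per column."""
    return [sum(col) for col in zip(*dx)]


dz = [[1.0, 2.0], [3.0, 4.0]]
x = [[0.0, 0.0], [0.0, 0.0]]
dx = dact_silu(dz, x)   # silu'(0) is 0.5, so every entry of dz is halved
db = dbias_from(dx)
```

Computing the bias gradient in the same pass as the input gradient means the `dx` tensor does not have to be re-read for the reduction, which is the kind of saving the fused backward primitives are built around.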