
Implementation:NVIDIA TransformerEngine JAX Activation

From Leeroopedia


Sources: TransformerEngine
Domains: Deep_Learning, JAX, Activation
Last Updated: 2026-02-07 14:00 GMT

Overview

Provides optimized activation functions (GELU, SiLU, ReLU, etc.) with optional FP8 quantization support for the JAX backend.

Description

The public activation() function delegates to _activation, which uses jax.custom_vjp to define custom forward and backward rules. The forward rule calls tex.act_lu (a C++ extension primitive) with the optional quantizer. When the result is quantized, the rule dequantizes it so that callers receive a regular array. The backward rule calls tex.dact_lu to compute the activation gradient. Defining the VJP explicitly ensures correct gradient flow through the fused activation-plus-quantization pipeline.

This is a core functional building block. Higher-level components such as layernorm_mlp and the Flax LayerNormMLP module use it to apply activation functions, optionally with FP8 quantization, in a differentiable manner.

Usage

Use this function when applying activation functions with optional FP8 quantization in transformer MLP layers. It is typically called internally by layernorm_mlp and the Flax LayerNormMLP module.

Code Reference

Source Location

Repository: NVIDIA/TransformerEngine
File: transformer_engine/jax/activation.py
Lines: 1-109

Signature

def activation(
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
    act_params: Optional[tex.activation.ActivationParams] = None,
) -> jnp.ndarray: ...

Import

from transformer_engine.jax.activation import activation

I/O Contract

Inputs

Name | Type | Required | Description
x | jnp.ndarray | Yes | Input tensor to apply activations to
activation_type | Sequence[Union[str, Callable]] | Yes | Activation functions, e.g., ("gelu",) or ("silu", "linear")
quantizer | Optional[Quantizer] | No | Optional quantizer for FP8-quantized output
act_params | Optional[ActivationParams] | No | Optional activation parameters (for clamped SwiGLU)
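To illustrate what activation_type means, here is an unquantized plain-JAX reference sketch that handles string names only. Several parts of it are assumptions made for illustration:

- the name-to-function table;
- the choice of GELU approximation;
- the gated layout, in which the input is split into two equal halves along the last axis, the first half is passed through the first function, and the result is multiplied elementwise by the second half passed through the second function.

The real tex.act_lu kernel may expect a different layout, so check the library source before relying on shapes.

```python
import jax
import jax.numpy as jnp

# Hypothetical name-to-function table; the GELU approximation is an assumption.
_ACTS = {
    "gelu": lambda t: jax.nn.gelu(t, approximate=True),
    "silu": jax.nn.silu,
    "relu": jax.nn.relu,
    "linear": lambda t: t,
}


def reference_activation(x, activation_type):
    """Illustrative reference for activation_type semantics (not TE's kernel)."""
    n = len(activation_type)
    assert n in (1, 2), "sketch covers plain and two-part gated activations"
    assert x.shape[-1] % n == 0
    if n == 1:
        # Plain activation: applied elementwise, shape unchanged.
        return _ACTS[activation_type[0]](x)
    # Assumed gated layout: first half is the gate, second half the linear part.
    gate, lin = jnp.split(x, 2, axis=-1)
    return _ACTS[activation_type[0]](gate) * _ACTS[activation_type[1]](lin)


x = jnp.array([[1.0, -2.0, 3.0, 4.0]])
print(reference_activation(x, ("silu", "linear")).shape)  # (1, 2)
```

Under this assumed layout, a gated activation such as ("silu", "linear") (SwiGLU) halves the last dimension, while a single activation such as ("gelu",) preserves the input shape.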

Outputs

Name | Type | Description
output | jnp.ndarray | Activated tensor (dequantized back to original precision)

Usage Examples

from transformer_engine.jax.activation import activation

# x: an input jnp.ndarray; fp8_quantizer: an existing Quantizer instance (both assumed to be defined)

# Apply GELU activation
output = activation(x, activation_type=("gelu",))

# Apply SwiGLU (SiLU-gated linear unit) with FP8-quantized output
output = activation(x, activation_type=("silu", "linear"), quantizer=fp8_quantizer)
