Implementation:NVIDIA TransformerEngine JAX Cpp Activation

From Leeroopedia


Field         Value
Sources       TransformerEngine
Domains       Deep_Learning, JAX, Activation
Last Updated  2026-02-07 14:00 GMT

Overview

Implements JAX custom primitives for fused activation functions with optional quantization, including forward and backward passes that call into the TE C++ backend via FFI.

Description

This module maps tuples of activation names, such as ("gelu",) for a plain activation or ("silu", "linear") for a gated one, to NVTE_Activation_Type enums. ActLuPrimitive implements the forward pass as a custom JAX primitive, with abstract, lowering, and partition methods for XLA compilation and SPMD sharding. DActLuDBiasQuantizePrimitive and DActLuQuantizePrimitive implement the fused backward-plus-quantize operations, with and without the bias-gradient (dbias) reduction respectively. ClampedSwigluParams holds the limit and alpha parameters of the clamped SwiGLU activation. All primitives build on BasePrimitive and are registered with JAX through register_primitive.
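The name-to-enum lookup can be pictured as a dictionary keyed by tuples of activation names. The sketch below is illustrative only: `ActivationKind`, `ACTIVATION_TABLE`, and `resolve_activation` are hypothetical stand-ins, not the module's actual `NVTE_Activation_Type` values or lookup code, and the table lists only a few entries.

```python
from enum import Enum, auto
from typing import Sequence, Tuple


class ActivationKind(Enum):
    """Hypothetical stand-in for the backend's NVTE_Activation_Type enum."""

    GELU = auto()
    GEGLU = auto()
    SILU = auto()
    SWIGLU = auto()


# Keys are tuples: one name for a plain activation, two names for a gated one.
ACTIVATION_TABLE = {
    ("gelu",): ActivationKind.GELU,
    ("gelu", "linear"): ActivationKind.GEGLU,
    ("silu",): ActivationKind.SILU,
    ("silu", "linear"): ActivationKind.SWIGLU,
}


def resolve_activation(activation_type: Sequence[str]) -> ActivationKind:
    """Normalize the sequence to a tuple and look up its enum value."""
    key: Tuple[str, ...] = tuple(activation_type)
    if key not in ACTIVATION_TABLE:
        raise ValueError(f"Unsupported activation_type: {key}")
    return ACTIVATION_TABLE[key]
```

Keying on the whole tuple lets a single lookup distinguish a plain activation from its gated variant, which is presumably why the names are passed as a sequence rather than a single string.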

This is the core low-level activation layer that enables fused activation+quantization kernels on GPU, reducing memory traffic and improving performance for transformer MLP layers.
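For intuition about what the fused kernel replaces, the sketch below writes a plain or gated activation as separate JAX operations. It assumes the input stacks its branches along the second-to-last axis, giving shape (..., ACT_DIM, K) with ACT_DIM equal to the number of activation names. The helper `unfused_act_lu` and the table `_ACT_FNS` are hypothetical, not part of TransformerEngine, and the sketch performs no quantization.

```python
from functools import reduce
import operator

import jax
import jax.numpy as jnp

# Hypothetical name-to-function table, for this sketch only.
_ACT_FNS = {
    "gelu": lambda v: jax.nn.gelu(v, approximate=True),
    "silu": jax.nn.silu,
    "linear": lambda v: v,
}


def unfused_act_lu(x, activation_type):
    """Apply each named function to its branch and multiply the results.

    x is assumed to have shape (..., ACT_DIM, K),
    where ACT_DIM == len(activation_type).
    """
    if x.shape[-2] != len(activation_type):
        raise ValueError("axis -2 must hold one slice per activation name")
    branches = [
        _ACT_FNS[name](x[..., i, :]) for i, name in enumerate(activation_type)
    ]
    # Elementwise product of all branches; result has shape (..., K).
    return reduce(operator.mul, branches)
```

Written this way, the activation, the gating product, and any later cast to FP8 are distinct steps over the tensor; the fused primitive is described as combining them to cut memory traffic.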

Usage

Use this module indirectly through the higher-level activation() function in transformer_engine.jax.activation or via LayerNormMLP. Direct usage is for advanced cases where fine-grained control over the activation primitive is needed, such as custom MLP implementations with FP8 quantization.

Code Reference

Source Location

Repository
NVIDIA/TransformerEngine
File
transformer_engine/jax/cpp_extensions/activation.py
Lines
1-1656

Signature

class ClampedSwigluParams:
    limit: float = ...
    alpha: float = ...

class ActivationParams:
    activation_type: Sequence[Union[str, Callable]] = ...
    clamped_swiglu_params: Optional[ClampedSwigluParams] = None

class ActLuPrimitive(BasePrimitive): ...
class DActLuDBiasQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive): ...
class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive): ...

def act_lu(
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Quantizer = None,
    act_params: Optional[ActivationParams] = None,
) -> Union[jnp.ndarray, ScaledTensor]: ...

def quantize_dact_dbias(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Quantizer = None,
    act_params: Optional[ActivationParams] = None,
) -> Tuple: ...

def dact_lu(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    act_params: Optional[ActivationParams] = None,
) -> jnp.ndarray: ...

Import

from transformer_engine.jax.cpp_extensions.activation import act_lu, quantize_dact_dbias, dact_lu

I/O Contract

Inputs

Name             Type                            Required  Description
x                jnp.ndarray                     Yes       Input tensor for activation
activation_type  Sequence[Union[str, Callable]]  Yes       Activation function name(s), e.g., ("gelu",) or ("silu", "linear")
quantizer        Quantizer                       No        Optional quantizer for FP8 output
act_params       ActivationParams                No        Optional activation parameters (e.g., clamped SwiGLU limits)

Outputs

Name    Type                              Description
output  Union[jnp.ndarray, ScaledTensor]  Activated tensor, optionally quantized to FP8
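To illustrate what "optionally quantized to FP8" can mean, here is a minimal per-tensor scaling sketch in plain JAX. It is a simplified scheme chosen for this example, not the library's Quantizer or ScaledTensor, which may use different scaling strategies; `quantize_per_tensor` and `dequantize` are hypothetical helpers.

```python
import jax.numpy as jnp

FP8_DTYPE = jnp.float8_e4m3fn


def quantize_per_tensor(x):
    """Scale x so its largest magnitude maps to the FP8 maximum, then cast."""
    fp8_max = float(jnp.finfo(FP8_DTYPE).max)
    amax = jnp.max(jnp.abs(x)).astype(jnp.float32)
    # Fall back to a scale of 1 for an all-zero tensor.
    scale = jnp.where(amax > 0, fp8_max / amax, 1.0)
    scaled = jnp.clip(x.astype(jnp.float32) * scale, -fp8_max, fp8_max)
    # Return the FP8 data and the factor that undoes the scaling.
    return scaled.astype(FP8_DTYPE), 1.0 / scale


def dequantize(data, scale_inv):
    """Recover an approximate float32 tensor from FP8 data and its inverse scale."""
    return data.astype(jnp.float32) * scale_inv
```

The pair returned here (low-precision data plus an inverse scale) mirrors the general idea of a scaled tensor: the consumer needs both pieces to interpret the values.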

Usage Examples

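import jax.numpy as jnp
from transformer_engine.jax.cpp_extensions.activation import act_lu, dact_lu

# Assumed input layout: (..., ACT_DIM, K), with ACT_DIM = 1 for plain
# activations and 2 for gated ones. Verify against your installed version.
x = jnp.ones((16, 1, 128), dtype=jnp.bfloat16)
x_gated = jnp.ones((16, 2, 128), dtype=jnp.bfloat16)

# Forward: apply GELU activation
output = act_lu(x, activation_type=("gelu",))

# Forward with SwiGLU (gated activation)
output_gated = act_lu(x_gated, activation_type=("silu", "linear"))

# Backward: dz is the upstream gradient, shaped like the forward output
dz = jnp.ones_like(output)
dgrad = dact_lu(dz, x, activation_type=("gelu",))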
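As a way to reason about what the backward call returns, the gradient of an unfused reference can be computed with `jax.vjp`. The sketch below does this for SwiGLU, assuming an input of shape (..., 2, K) with the gate branch first; `swiglu_reference` and `dswiglu_reference` are hypothetical helpers written for this example, not TransformerEngine functions.

```python
import jax
import jax.numpy as jnp


def swiglu_reference(x):
    """Unfused SwiGLU on an input of assumed shape (..., 2, K)."""
    return jax.nn.silu(x[..., 0, :]) * x[..., 1, :]


def dswiglu_reference(dz, x):
    """Gradient of swiglu_reference with respect to x, given upstream gradient dz."""
    _, vjp_fn = jax.vjp(swiglu_reference, x)
    (dx,) = vjp_fn(dz)
    return dx  # same shape as x
```

A reference like this can serve as a numerical check for a fused backward implementation, within the tolerance appropriate for the dtype in use.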
