Implementation:NVIDIA TransformerEngine JAX XLA Activation

From Leeroopedia


Field | Value
Sources | TransformerEngine
Domains | Deep_Learning, JAX, Activation
Last Updated | 2026-02-07 14:00 GMT

Overview

Implements XLA FFI handlers for forward and backward activation functions (GELU, GEGLU, SiLU, SwiGLU, ReLU, REGLU, QGELU, QGEGLU, SRELU, SREGLU, Clamped SwiGLU) with optional FP8 quantized output, exposing them as custom JAX operations.

Description

ActLuFFI extracts the input and output buffers from the XLA FFI call frame and wraps them in TensorWrapper objects with the appropriate shapes and scaling metadata (delayed tensor scaling, block scaling, or no scaling). It then dispatches to the corresponding nvte_* activation kernel (e.g., nvte_gelu, nvte_swiglu) via a switch on the activation enum. For FP8 output, the handler supports the rowwise and 2x2x (rowwise plus columnwise) quantize layouts and sets the scale and scale_inv buffers accordingly. A companion handler, DActLuDBiasQuantizeFFI, implements the fused backward pass (activation gradient + dbias + quantize). Both handlers are registered with XLA_FFI_DEFINE_HANDLER_SYMBOL, along with Initialize variants for CUDA graph capture.
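The switch-based dispatch described above can be sketched on the CPU as follows. This is an illustration only, not the handler's code: ActType, act_lu_row, and the scalar helpers are hypothetical names, the enum values do not correspond to NVTE_Activation_Type, the tanh form of GELU is an assumption, and the [a | b] split used for the gated variant is an assumed layout rather than the one the real kernels require.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical enum for illustration; values do NOT match NVTE_Activation_Type.
enum class ActType : int64_t { GELU = 0, SILU = 1, RELU = 2, SWIGLU = 3 };

// Scalar reference activations (tanh approximation of GELU is assumed here).
inline float gelu(float x) {
  const float k = 0.7978845608f;  // sqrt(2 / pi)
  return 0.5f * x * (1.0f + std::tanh(k * (x + 0.044715f * x * x * x)));
}
inline float silu(float x) { return x / (1.0f + std::exp(-x)); }
inline float relu(float x) { return x > 0.0f ? x : 0.0f; }

// Applies the activation selected by act_enum to one row of values.
// Gated variants treat the row as two halves [a | b] (assumed layout)
// and return act(a) * b, so the output has half the input width.
std::vector<float> act_lu_row(const std::vector<float>& in, int64_t act_enum) {
  std::vector<float> out;
  switch (static_cast<ActType>(act_enum)) {
    case ActType::GELU:
      for (float v : in) out.push_back(gelu(v));
      break;
    case ActType::SILU:
      for (float v : in) out.push_back(silu(v));
      break;
    case ActType::RELU:
      for (float v : in) out.push_back(relu(v));
      break;
    case ActType::SWIGLU: {
      assert(in.size() % 2 == 0);  // gated input needs two equal halves
      const std::size_t h = in.size() / 2;
      for (std::size_t i = 0; i < h; ++i) out.push_back(silu(in[i]) * in[i + h]);
      break;
    }
    default:
      throw std::invalid_argument("unsupported activation enum");
  }
  return out;
}
```

Passing the activation as a plain int64_t and casting it inside the switch mirrors how act_enum appears in the handler signature; an unknown value is rejected rather than silently falling through.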

This extension enables JAX models to use GPU-accelerated fused activation kernels with FP8 precision support, avoiding the overhead of separate activation, bias, and quantization steps.
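To illustrate what fusing buys in the backward pass, the following CPU sketch computes the activation gradient and the bias gradient in a single loop over the data, instead of materializing the gradient first and reducing it afterwards. It assumes SiLU as the activation and assumes that dbias is the column-wise sum of the activation gradient; DActResult and dsilu_dbias are hypothetical names, and the quantize step of the real fused kernel is left out.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical result type: per-element gradient plus per-column bias gradient.
struct DActResult {
  std::vector<float> dx;
  std::vector<float> dbias;
};

// Derivative of SiLU: s(x) * (1 + x * (1 - s(x))), with s the logistic sigmoid.
inline float dsilu(float x) {
  const float s = 1.0f / (1.0f + std::exp(-x));
  return s * (1.0f + x * (1.0f - s));
}

// One pass over a row-major [rows x cols] matrix:
//   dx[r][c] = dz[r][c] * dsilu(x[r][c])
//   dbias[c] = sum over r of dx[r][c]   (assumed definition of dbias)
DActResult dsilu_dbias(const std::vector<float>& dz, const std::vector<float>& x,
                       std::size_t rows, std::size_t cols) {
  assert(dz.size() == rows * cols && x.size() == rows * cols);
  DActResult out{std::vector<float>(rows * cols), std::vector<float>(cols, 0.0f)};
  for (std::size_t r = 0; r < rows; ++r) {
    for (std::size_t c = 0; c < cols; ++c) {
      const std::size_t i = r * cols + c;
      const float g = dz[i] * dsilu(x[i]);
      out.dx[i] = g;      // activation gradient
      out.dbias[c] += g;  // bias gradient accumulated in the same pass
    }
  }
  return out;
}
```

Each element of dz and x is read once; an unfused version would write dx to memory and read it back for the reduction.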

Usage

This C++ extension is invoked internally by the Python-side ActLuPrimitive and DActLuDBiasQuantizePrimitive in transformer_engine.jax.cpp_extensions.activation. Users do not call these FFI handlers directly.

Code Reference

Source Location

Repository
NVIDIA/TransformerEngine
File
transformer_engine/jax/csrc/extensions/activation.cpp
Lines
1-514

Signature

namespace transformer_engine {
namespace jax {

Error_Type ActLuFFI(
    cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
    Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf,
    Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
    Result_Type updated_amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
    JAXX_Quantize_Layout quantize_layout, ActivationConfig act_params,
    bool output_amax_when_no_scaling);

Error_Type DActLuDBiasQuantizeFFI(
    cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf,
    Buffer_Type scale_buf, Buffer_Type amax_buf,
    Result_Type output_buf, Result_Type colwise_output_buf,
    Result_Type dbias_buf, Result_Type scale_inv_buf,
    Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
    int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
    JAXX_Quantize_Layout quantize_layout, bool is_dbias,
    ActivationConfig act_params);

} // namespace jax
} // namespace transformer_engine

Import

#include "transformer_engine/activation.h"
#include "transformer_engine/cast.h"
#include "../extensions.h"

I/O Contract

Inputs

Name | Type | Required | Description
input_buf | Buffer_Type | Yes | Input tensor buffer
scale_buf | Buffer_Type | Yes | Scale factor buffer for quantization
amax_buf | Buffer_Type | Yes | Absolute maximum buffer for delayed scaling
act_enum | int64_t | Yes | Activation type enum value (NVTE_Activation_Type)
scaling_mode | JAXX_Scaling_Mode | Yes | Quantization scaling mode
quantize_layout | JAXX_Quantize_Layout | Yes | Quantization layout (rowwise, colwise, or both)

Outputs

Name | Type | Description
output_buf | Result_Type | Activated output (rowwise), optionally FP8
colwise_output_buf | Result_Type | Column-wise activated output for 2x2x layout
scale_inv_buf | Result_Type | Inverse scale factor for output
updated_amax_buf | Result_Type | Updated amax value
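How the scale, scale_inv, and amax values in these tables relate can be sketched for the per-tensor (delayed scaling) case as below. This is a simplified model under stated assumptions, not the kernel's behavior: QuantResult and quantize_delayed are hypothetical names, the quantized values are kept as float clamped to the FP8 E4M3 maximum of 448 rather than rounded to real FP8, and the amax update is assumed to be a running maximum of the absolute unquantized values.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical result type mirroring output_buf, scale_inv_buf, updated_amax_buf.
struct QuantResult {
  std::vector<float> q;  // scaled and clamped values (float stand-in for FP8)
  float scale_inv;       // 1 / scale, used later to dequantize
  float updated_amax;    // running max of |y| over the tensor
};

// Per-tensor quantization sketch: q = clamp(y * scale, -448, 448).
QuantResult quantize_delayed(const std::vector<float>& y, float scale, float amax_in) {
  assert(scale > 0.0f);                // a non-positive scale has no valid inverse here
  const float kFp8E4M3Max = 448.0f;    // largest finite FP8 E4M3 magnitude
  QuantResult out{{}, 1.0f / scale, amax_in};
  out.q.reserve(y.size());
  for (float v : y) {
    out.updated_amax = std::max(out.updated_amax, std::fabs(v));
    out.q.push_back(std::clamp(v * scale, -kFp8E4M3Max, kFp8E4M3Max));
  }
  return out;
}
```

A consumer would recover approximate values as q * scale_inv; values that were clamped, such as 500 with a scale of 2, cannot be recovered, which is why the updated amax is recorded for choosing a later scale.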

Usage Examples

// This FFI handler is called internally by JAX's XLA compilation pipeline.
// Users interact with it through the Python API:
//   from transformer_engine.jax.cpp_extensions.activation import act_lu
//   output = act_lu(x, activation_type=("gelu",))
