Implementation:NVIDIA TransformerEngine JAX XLA Activation
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, JAX, Activation |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Implements XLA FFI handlers for forward and backward activation functions (GELU, GEGLU, SiLU, SwiGLU, ReLU, REGLU, QGELU, QGEGLU, SRELU, SREGLU, Clamped SwiGLU) with optional FP8 quantized output, exposing them as custom JAX operations.
Description
ActLuFFI extracts input/output buffers from the XLA FFI call frame, constructs TensorWrapper objects with appropriate shapes and scaling metadata (delayed tensor scaling, block scaling, or no scaling), then dispatches to the corresponding nvte_* activation kernel (e.g., nvte_gelu, nvte_swiglu) via a switch on the activation enum. For FP8 output it supports the rowwise and 2x (rowwise + columnwise) quantize layouts and sets the scale and scale_inv buffers accordingly. A companion handler, DActLuDBiasQuantizeFFI, handles the fused backward pass (activation gradient + dbias + quantize). Both handlers are registered with XLA_FFI_DEFINE_HANDLER_SYMBOL, along with Initialize variants for CUDA graph capture.
This extension enables JAX models to use GPU-accelerated fused activation kernels with FP8 precision support, avoiding the overhead of separate activation, bias, and quantization steps.
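The dispatch step described above can be pictured with a small CPU-only sketch. It is not the TransformerEngine implementation: ActKind, act_lu_ref, and the *_ref helpers are hypothetical names introduced for illustration, the GELU uses the common tanh approximation, and the gated layout (activation half followed by gate half in one flat buffer) is an assumption made to keep the example short.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical enum for illustration only. The real handler switches on
// NVTE_Activation_Type, whose members and values are not reproduced here.
enum class ActKind { kGelu, kSilu, kRelu, kSwiglu };

// Reference math for a few activations (tanh-approximated GELU).
inline float gelu_ref(float x) {
  const float k = 0.7978845608f;  // sqrt(2 / pi)
  return 0.5f * x * (1.0f + std::tanh(k * (x + 0.044715f * x * x * x)));
}
inline float silu_ref(float x) { return x / (1.0f + std::exp(-x)); }
inline float relu_ref(float x) { return x > 0.0f ? x : 0.0f; }

// CPU stand-in for "switch on the activation enum, then call one kernel".
// For the gated case the input is assumed to hold the activation half
// followed by the gate half, so the output is half as long as the input.
std::vector<float> act_lu_ref(const std::vector<float>& in, ActKind kind) {
  std::vector<float> out;
  switch (kind) {
    case ActKind::kGelu:
      for (float v : in) out.push_back(gelu_ref(v));
      return out;
    case ActKind::kSilu:
      for (float v : in) out.push_back(silu_ref(v));
      return out;
    case ActKind::kRelu:
      for (float v : in) out.push_back(relu_ref(v));
      return out;
    case ActKind::kSwiglu: {
      assert(in.size() % 2 == 0);  // gated input needs two equal halves
      const std::size_t half = in.size() / 2;
      for (std::size_t i = 0; i < half; ++i) {
        out.push_back(silu_ref(in[i]) * in[half + i]);
      }
      return out;
    }
  }
  throw std::invalid_argument("unsupported activation");
}
```

In the actual handler each case would call a GPU kernel on TensorWrapper arguments instead of looping on the host; the sketch only shows the control flow of a single enum-driven entry point.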
Usage
This C++ extension is invoked internally by the Python-side ActLuPrimitive and DActLuDBiasQuantizePrimitive in transformer_engine.jax.cpp_extensions.activation. Users do not call these FFI handlers directly.
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/jax/csrc/extensions/activation.cpp
- Lines: 1-514
Signature
namespace transformer_engine {
namespace jax {
Error_Type ActLuFFI(
cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type updated_amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
JAXX_Quantize_Layout quantize_layout, ActivationConfig act_params,
bool output_amax_when_no_scaling);
Error_Type DActLuDBiasQuantizeFFI(
cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf,
Buffer_Type scale_buf, Buffer_Type amax_buf,
Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type dbias_buf, Result_Type scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
int64_t act_enum, JAXX_Scaling_Mode scaling_mode,
JAXX_Quantize_Layout quantize_layout, bool is_dbias,
ActivationConfig act_params);
} // namespace jax
} // namespace transformer_engine
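To make the fused backward pass behind DActLuDBiasQuantizeFFI more concrete, here is a minimal host-side sketch of the first two stages (activation gradient and dbias reduction). It is an illustration under stated assumptions, not the library's code: DActResult and dact_dbias_ref are hypothetical names, ReLU stands in for whichever activation is selected, the tensors are treated as row-major rows x cols matrices, and the quantize stage is left out.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical result type for illustration only.
struct DActResult {
  std::vector<float> dx;     // activation gradient, same shape as x
  std::vector<float> dbias;  // column-wise sum of dx, length cols
};

// Computes dx = dz * relu'(x) and dbias[j] = sum over rows of dx[i][j]
// in a single pass, mirroring the idea of fusing both steps.
DActResult dact_dbias_ref(const std::vector<float>& dz,
                          const std::vector<float>& x,
                          std::size_t rows, std::size_t cols) {
  assert(dz.size() == rows * cols && x.size() == rows * cols);
  DActResult r{std::vector<float>(rows * cols),
               std::vector<float>(cols, 0.0f)};
  for (std::size_t i = 0; i < rows; ++i) {
    for (std::size_t j = 0; j < cols; ++j) {
      const std::size_t idx = i * cols + j;
      // ReLU derivative: pass the upstream gradient where x > 0.
      const float grad = x[idx] > 0.0f ? dz[idx] : 0.0f;
      r.dx[idx] = grad;
      r.dbias[j] += grad;  // reduce over the row dimension
    }
  }
  return r;
}
```

Doing the reduction in the same loop that produces dx is the point of the fusion: the gradient is read once rather than being written out and re-read by a separate bias-gradient step.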
Import
#include "transformer_engine/activation.h"
#include "transformer_engine/cast.h"
#include "../extensions.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input_buf | Buffer_Type | Yes | Input tensor buffer |
| scale_buf | Buffer_Type | Yes | Scale factor buffer for quantization |
| amax_buf | Buffer_Type | Yes | Absolute maximum buffer for delayed scaling |
| act_enum | int64_t | Yes | Activation type enum value (NVTE_Activation_Type) |
| scaling_mode | JAXX_Scaling_Mode | Yes | Quantization scaling mode |
| quantize_layout | JAXX_Quantize_Layout | Yes | Quantization layout (rowwise, colwise, or both) |
Outputs
| Name | Type | Description |
|---|---|---|
| output_buf | Result_Type | Activated output (rowwise), optionally FP8 |
| colwise_output_buf | Result_Type | Column-wise activated output for the 2x layout |
| scale_inv_buf | Result_Type | Inverse scale factor for output |
| updated_amax_buf | Result_Type | Updated amax value |
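The relationship between the scale input and the scale_inv and updated amax outputs can be sketched as follows for per-tensor scaling. This is a simplified host-side illustration, not the kernel: QuantResult and quantize_delayed_ref are hypothetical names, values are kept as float after scaling and saturation (no actual FP8 rounding is performed), amax is computed over this tensor only, and 448 is the largest finite value of the FP8 E4M3 format.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical result type for illustration only.
struct QuantResult {
  std::vector<float> data;  // scaled and saturated values (still float here)
  float scale_inv;          // what a consumer multiplies by to dequantize
  float amax;               // largest absolute value seen before scaling
};

constexpr float kFp8E4M3Max = 448.0f;  // largest finite FP8 E4M3 value

// Per-tensor scaling sketch: multiply by a precomputed scale, saturate to
// the FP8 range, and report 1/scale plus the observed amax.
QuantResult quantize_delayed_ref(const std::vector<float>& x, float scale) {
  assert(scale > 0.0f);
  QuantResult r;
  r.scale_inv = 1.0f / scale;
  r.amax = 0.0f;
  r.data.reserve(x.size());
  for (float v : x) {
    r.amax = std::max(r.amax, std::fabs(v));
    const float scaled = v * scale;
    r.data.push_back(std::min(std::max(scaled, -kFp8E4M3Max), kFp8E4M3Max));
  }
  return r;
}
```

With delayed scaling, the scale applied here comes from amax values recorded on earlier steps, which is why the handler takes scale as an input and returns the updated amax as an output for later use.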
Usage Examples
// This FFI handler is called internally by JAX's XLA compilation pipeline.
// Users interact with it through the Python API:
// from transformer_engine.jax.cpp_extensions.activation import act_lu
// output = act_lu(x, activation_type=("gelu",))