
Implementation:NVIDIA TransformerEngine JAX XLA GEMM



Field | Value
Sources | TransformerEngine
Domains | Deep_Learning, JAX, Quantization
Last Updated | 2026-02-07 14:00 GMT

Overview

Implements XLA FFI handlers for single GEMM, collective GEMM (all-gather/reduce-scatter overlapped), and grouped GEMM operations with FP8/MXFP8/NVFP4 quantization support for JAX.

Description

xla_buffer_to_nvte_gemm_operand converts XLA FFI buffers into TE TensorWrapper objects with 2D collapsed shapes and appropriate scaling metadata (tensor scaling, MXFP8 block scaling with pre-swizzled scales, or NVFP4 with on-the-fly swizzle via nvte_swizzle_scaling_factors). CollectiveGemmInitFFI runs at the Prepare stage to initialize cuBLAS handles and user-buffer communication infrastructure for the requested collective operation. GemmFFI (the main handler) constructs LHS/RHS operands, sets up output/bias/GELU tensors, and dispatches to either nvte_gemm for standard GEMM or collective GEMM executors for communication-overlapped operations. Grouped GEMM handlers support batched matrix multiplications with device-to-host group-size transfers. Pointer alignment to 256-byte boundaries is enforced for swizzle workspace buffers.
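The 256-byte alignment mentioned above can be illustrated with a small, self-contained sketch. The helper names `align_up_addr` and `align_up_ptr` are hypothetical and are not taken from gemm.cpp; the sketch only shows the usual round-up-to-power-of-two technique, assuming the workspace is a raw byte buffer with enough slack to absorb the adjustment.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical helper (not from gemm.cpp): round an address up to the next
// multiple of `alignment`. `alignment` must be a non-zero power of two.
inline std::uintptr_t align_up_addr(std::uintptr_t addr,
                                    std::size_t alignment = 256) {
  assert(alignment != 0 && (alignment & (alignment - 1)) == 0);
  const std::uintptr_t mask = static_cast<std::uintptr_t>(alignment) - 1;
  // Adding the mask and clearing the low bits rounds up; an address that is
  // already aligned is returned unchanged.
  return (addr + mask) & ~mask;
}

// Hypothetical pointer wrapper around the integer version above.
inline std::uint8_t *align_up_ptr(std::uint8_t *ptr,
                                  std::size_t alignment = 256) {
  return reinterpret_cast<std::uint8_t *>(
      align_up_addr(reinterpret_cast<std::uintptr_t>(ptr), alignment));
}
```

With this scheme, a caller would typically over-allocate the workspace by up to `alignment - 1` bytes so that the aligned pointer still leaves the requested number of usable bytes.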

This is the largest and most complex C++ extension file, providing the core matrix multiplication operations that underpin all linear layers in JAX-based transformer models, with support for multiple quantization modes and multi-GPU tensor parallelism via communication overlap.

Usage

This C++ extension is invoked internally by the Python-side GemmPrimitive and GroupedGemmPrimitive in transformer_engine.jax.cpp_extensions.gemm. Users do not call these FFI handlers directly.

Code Reference

Source Location

Repository
NVIDIA/TransformerEngine
File
transformer_engine/jax/csrc/extensions/gemm.cpp
Lines
1-814

Signature

namespace transformer_engine {
namespace jax {

std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
    cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv,
    uint8_t *swizzle_scale_ptr, JAXX_Scaling_Mode scaling_mode,
    size_t axis_boundary, bool rowwise);

Error_Type CollectiveGemmInitFFI(
    cudaStream_t stream, int64_t collective_type, int64_t mesh_axis_size);

Error_Type GemmFFI(
    cudaStream_t stream,
    Buffer_Type lhs_buf, Buffer_Type lhs_scale_inv,
    Buffer_Type rhs_buf, Buffer_Type rhs_scale_inv,
    Result_Type output_buf, Result_Type bias_buf,
    Result_Type workspace_buf,
    int64_t lhs_axis_boundary, int64_t rhs_axis_boundary,
    JAXX_Scaling_Mode lhs_scaling_mode, JAXX_Scaling_Mode rhs_scaling_mode,
    bool lhs_is_rowwise, bool rhs_is_rowwise,
    bool fuse_bias, int64_t collective_type, int64_t mesh_axis_size);

Error_Type GroupedGemmFFI(...);

} // namespace jax
} // namespace transformer_engine
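
The Description says operands are converted to 2D collapsed shapes, and the signature carries an `axis_boundary` parameter. A minimal sketch of one plausible reading follows, assuming `axis_boundary` marks the split between the leading dimensions (flattened into rows) and the trailing dimensions (flattened into columns). The helper `collapse_to_2d` is hypothetical and is not the function used in gemm.cpp.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <utility>
#include <vector>

// Hypothetical helper (not from gemm.cpp): collapse an N-D shape into
// (rows, cols) by multiplying the dimensions before and after
// `axis_boundary`. An empty range on either side yields 1.
inline std::pair<std::size_t, std::size_t> collapse_to_2d(
    const std::vector<std::size_t> &shape, std::size_t axis_boundary) {
  assert(axis_boundary <= shape.size());
  const auto split =
      shape.begin() + static_cast<std::ptrdiff_t>(axis_boundary);
  const std::size_t rows = std::accumulate(
      shape.begin(), split, std::size_t{1}, std::multiplies<std::size_t>());
  const std::size_t cols = std::accumulate(
      split, shape.end(), std::size_t{1}, std::multiplies<std::size_t>());
  return {rows, cols};
}
```

Under this assumption, a [batch, sequence, hidden] activation of shape {8, 128, 1024} with `axis_boundary = 2` would be treated as a 1024 x 1024 matrix, which is the layout a 2D GEMM expects.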

Import

#include "transformer_engine/gemm.h"
#include "../extensions.h"
#include "cgemm_helper.h"
#include "nccl.h"

I/O Contract

Inputs

Name | Type | Required | Description
lhs_buf | Buffer_Type | Yes | Left-hand side GEMM operand buffer
lhs_scale_inv | Buffer_Type | Yes | Inverse scale for LHS (FP8)
rhs_buf | Buffer_Type | Yes | Right-hand side GEMM operand buffer
rhs_scale_inv | Buffer_Type | Yes | Inverse scale for RHS (FP8)
lhs_scaling_mode | JAXX_Scaling_Mode | Yes | Scaling mode for LHS operand
rhs_scaling_mode | JAXX_Scaling_Mode | Yes | Scaling mode for RHS operand
fuse_bias | bool | Yes | Whether to fuse bias addition with GEMM
collective_type | int64_t | No | Collective operation type for tensor parallelism

Outputs

Name | Type | Description
output_buf | Result_Type | GEMM output buffer
bias_buf | Result_Type | Bias output buffer (if fuse_bias)

Usage Examples

// This FFI handler is called internally by JAX's XLA compilation pipeline.
// Users interact with it through the Python API:
//   from transformer_engine.jax.cpp_extensions.gemm import gemm
//   output = gemm(lhs, rhs, dimension_numbers=((1,), (0,)))
