Implementation:NVIDIA TransformerEngine JAX XLA GEMM
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, JAX, Quantization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Implements XLA FFI handlers for single GEMM, collective GEMM (all-gather/reduce-scatter overlapped), and grouped GEMM operations with FP8/MXFP8/NVFP4 quantization support for JAX.
Description
xla_buffer_to_nvte_gemm_operand converts an XLA FFI buffer into a TE TensorWrapper whose shape is collapsed to 2D and which carries the scaling metadata its scaling mode needs: tensor scaling, MXFP8 block scaling with pre-swizzled scales, or NVFP4 with scales swizzled on the fly via nvte_swizzle_scaling_factors.

CollectiveGemmInitFFI runs at the FFI Prepare stage and initializes the cuBLAS handles and the user-buffer communication infrastructure for the requested collective operation.

GemmFFI, the main handler, constructs the LHS and RHS operands, sets up the output, bias, and GELU tensors, and then dispatches either to nvte_gemm for a standard GEMM or to a collective GEMM executor for a communication-overlapped GEMM.

The grouped GEMM handlers run multiple matrix multiplications (one per group) in a single call and transfer the group sizes from device to host. Pointers into the swizzle workspace buffers are aligned to 256-byte boundaries.
The file is the largest and most complex of the JAX C++ extensions. It provides the core matrix multiplication that underpins all linear layers in JAX-based transformer models, supporting multiple quantization modes and multi-GPU tensor parallelism through communication overlap.
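The 256-byte alignment of swizzle workspace pointers is typically achieved by rounding a raw pointer up to the next multiple of the alignment and over-allocating enough slack to absorb the shift. The sketch below illustrates that arithmetic; `align_up` and `padded_size` are hypothetical helper names, not TransformerEngine functions, and the exact way gemm.cpp carves its workspace is not shown here.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical helper, not TE's actual code: round `ptr` up to the next
// multiple of `alignment`, which must be a power of two (256 by default).
inline std::uint8_t *align_up(std::uint8_t *ptr, std::size_t alignment = 256) {
  assert(alignment != 0 && (alignment & (alignment - 1)) == 0);
  const auto addr = reinterpret_cast<std::uintptr_t>(ptr);
  const auto mask = static_cast<std::uintptr_t>(alignment) - 1;
  return reinterpret_cast<std::uint8_t *>((addr + mask) & ~mask);
}

// Hypothetical helper: bytes to request so that `nbytes` still fit after an
// arbitrary base pointer is aligned. In the worst case align_up moves the
// pointer forward by alignment - 1 bytes.
inline std::size_t padded_size(std::size_t nbytes, std::size_t alignment = 256) {
  return nbytes + alignment - 1;
}
```

Masking with `~(alignment - 1)` only works for power-of-two alignments, which is why the helper asserts that precondition.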
Usage
This C++ extension is invoked internally by the Python-side GemmPrimitive and GroupedGemmPrimitive in transformer_engine.jax.cpp_extensions.gemm. Users do not call these FFI handlers directly.
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/jax/csrc/extensions/gemm.cpp
- Lines: 1-814
Signature
```cpp
namespace transformer_engine {
namespace jax {

std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
    cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv,
    uint8_t *swizzle_scale_ptr, JAXX_Scaling_Mode scaling_mode,
    size_t axis_boundary, bool rowwise);

Error_Type CollectiveGemmInitFFI(
    cudaStream_t stream, int64_t collective_type, int64_t mesh_axis_size);

Error_Type GemmFFI(
    cudaStream_t stream,
    Buffer_Type lhs_buf, Buffer_Type lhs_scale_inv,
    Buffer_Type rhs_buf, Buffer_Type rhs_scale_inv,
    Result_Type output_buf, Result_Type bias_buf,
    Result_Type workspace_buf,
    int64_t lhs_axis_boundary, int64_t rhs_axis_boundary,
    JAXX_Scaling_Mode lhs_scaling_mode, JAXX_Scaling_Mode rhs_scaling_mode,
    bool lhs_is_rowwise, bool rhs_is_rowwise,
    bool fuse_bias, int64_t collective_type, int64_t mesh_axis_size);

Error_Type GroupedGemmFFI(...);

}  // namespace jax
}  // namespace transformer_engine
```
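The `axis_boundary` parameter of xla_buffer_to_nvte_gemm_operand controls how an N-D operand is collapsed to the 2D shape a GEMM needs. The sketch below assumes the boundary splits the dimensions into a leading group (multiplied into rows) and a trailing group (multiplied into columns); `collapse_to_2d` is a hypothetical name, and TE's own code may handle the shape differently.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <utility>
#include <vector>

// Hypothetical helper, not TE's actual code: collapse an N-D shape to 2D by
// multiplying the dims before `axis_boundary` into rows and the rest into cols.
inline std::pair<std::size_t, std::size_t> collapse_to_2d(
    const std::vector<std::size_t> &shape, std::size_t axis_boundary) {
  assert(axis_boundary <= shape.size());
  const auto mid = shape.begin() + static_cast<std::ptrdiff_t>(axis_boundary);
  const std::size_t rows = std::accumulate(shape.begin(), mid, std::size_t{1},
                                           std::multiplies<std::size_t>());
  const std::size_t cols = std::accumulate(mid, shape.end(), std::size_t{1},
                                           std::multiplies<std::size_t>());
  return {rows, cols};
}
```

For example, an activation of shape (8, 128, 1024) with `axis_boundary = 2` collapses to a 1024 x 1024 matrix: the batch and sequence dimensions fold into rows and the hidden dimension becomes the columns.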
Import
```cpp
#include "transformer_engine/gemm.h"
#include "../extensions.h"
#include "cgemm_helper.h"
#include "nccl.h"
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| lhs_buf | Buffer_Type | Yes | Left-hand side GEMM operand buffer |
| lhs_scale_inv | Buffer_Type | Yes | Inverse scaling factors for the LHS (per-tensor for FP8 tensor scaling, block scales for MXFP8/NVFP4) |
| rhs_buf | Buffer_Type | Yes | Right-hand side GEMM operand buffer |
| rhs_scale_inv | Buffer_Type | Yes | Inverse scaling factors for the RHS (per-tensor for FP8 tensor scaling, block scales for MXFP8/NVFP4) |
| lhs_scaling_mode | JAXX_Scaling_Mode | Yes | Scaling mode for the LHS operand |
| rhs_scaling_mode | JAXX_Scaling_Mode | Yes | Scaling mode for the RHS operand |
| fuse_bias | bool | Yes | Whether to fuse bias addition into the GEMM |
| collective_type | int64_t | No | Collective operation type for tensor parallelism |
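With per-tensor FP8 scaling, each operand has a single inverse scale, so the dequantization factor can be pulled out of the K-sum: (a_inv · A_q)(b_inv · B_q) = (a_inv · b_inv)(A_q B_q). The sketch below demonstrates this identity on the CPU, using `float` values as stand-ins for FP8 codes; `naive_gemm` and `scaled_gemm` are hypothetical names, not TE or cuBLAS APIs. MXFP8 and NVFP4 block scales change from one block of elements to the next, so they cannot be factored out of the whole sum in this way.

```cpp
#include <cstddef>
#include <vector>

// Naive row-major GEMM: D[m x n] = A[m x k] * B[k x n].
inline std::vector<float> naive_gemm(const std::vector<float> &a,
                                     const std::vector<float> &b,
                                     std::size_t m, std::size_t k, std::size_t n) {
  std::vector<float> d(m * n, 0.0f);
  for (std::size_t i = 0; i < m; ++i)
    for (std::size_t p = 0; p < k; ++p)
      for (std::size_t j = 0; j < n; ++j)
        d[i * n + j] += a[i * k + p] * b[p * n + j];
  return d;
}

// Hypothetical per-tensor scaled GEMM: `a_q`/`b_q` stand in for FP8 codes.
// Because each operand has one inverse scale, both are applied once to the
// accumulated result instead of dequantizing every input element.
inline std::vector<float> scaled_gemm(const std::vector<float> &a_q, float a_scale_inv,
                                      const std::vector<float> &b_q, float b_scale_inv,
                                      std::size_t m, std::size_t k, std::size_t n) {
  std::vector<float> d = naive_gemm(a_q, b_q, m, k, n);
  const float alpha = a_scale_inv * b_scale_inv;
  for (float &x : d) x *= alpha;
  return d;
}
```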
Outputs
| Name | Type | Description |
|---|---|---|
| output_buf | Result_Type | GEMM output buffer |
| bias_buf | Result_Type | Bias output buffer (if fuse_bias) |
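The shape of output_buf on each rank depends on collective_type and mesh_axis_size. The sketch below assumes the usual tensor-parallel layout in which both collectives act on the leading (row) dimension: an all-gather collects LHS row shards from every rank, and a reduce-scatter sums partial outputs and leaves each rank with its share of the rows. The `CollectiveOp` enum and `output_rows_per_rank` function are hypothetical; TE passes the collective as an `int64_t` whose encoding is not shown here.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical enum; TE encodes the collective as an int64_t collective_type.
enum class CollectiveOp { kNone, kAllGather, kReduceScatter };

// Rows of output_buf on one rank, assuming collectives act on the row dim.
inline std::size_t output_rows_per_rank(std::size_t local_lhs_rows, CollectiveOp op,
                                        std::size_t mesh_axis_size) {
  switch (op) {
    case CollectiveOp::kAllGather:
      // LHS row shards from every rank are gathered (overlapped with the GEMM).
      return local_lhs_rows * mesh_axis_size;
    case CollectiveOp::kReduceScatter:
      // Partial outputs are summed across ranks and each rank keeps
      // 1/mesh_axis_size of the rows.
      assert(local_lhs_rows % mesh_axis_size == 0);
      return local_lhs_rows / mesh_axis_size;
    case CollectiveOp::kNone:
    default:
      return local_lhs_rows;
  }
}
```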
Usage Examples
```cpp
// This FFI handler is called internally by JAX's XLA compilation pipeline.
// Users interact with it through the Python API:
//   from transformer_engine.jax.cpp_extensions.gemm import gemm
//   output = gemm(lhs, rhs, dimension_numbers=((1,), (0,)))
```