Implementation:NVIDIA TransformerEngine JAX Flax Module
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, JAX, Quantization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Defines Flax nn.Module wrappers around TE's core JAX operations, providing drop-in replacements for standard Flax layers with FP8 quantization and LoRA support.
Description
TransformerEngineBase manages FP8 quantization state (scales and amax history) through Flax variable collections. DenseGeneral wraps dense() with Flax parameter initialization and logical axis partitioning, while LayerNormDenseGeneral and LayerNormMLP fuse normalization with the linear layer(s) that follow it. Each module builds its quantizers through QuantizerFactory from the recipe configuration and passes them through the forward pass. make_dot_general_cls creates a custom replacement for lax.dot_general that performs the GEMM in FP8. LoRA support is integrated via _apply_low_rank_adaptation.
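The FP8 state kept by TransformerEngineBase (a scale plus a rolling amax history per tensor) follows the delayed-scaling idea: the scale used at the next step is derived from amax values recorded in earlier steps. The sketch below is a simplified, standalone model of that update in plain JAX, not TE's internal code. The helper names, the history length, the "max over history" reduction, and the margin handling are assumptions chosen for illustration; TE's DelayedScaling recipe exposes comparable knobs (amax_history_len, amax_compute_algo, margin).

```python
import jax.numpy as jnp

FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def update_amax_history(amax_history, x):
    """Hypothetical helper: shift the history by one and store |x|.max() in slot 0."""
    amax = jnp.max(jnp.abs(x)).astype(amax_history.dtype)
    return jnp.roll(amax_history, shift=1).at[0].set(amax)


def compute_scale(amax_history, margin=0):
    """Hypothetical helper: scale = fp8_max / max(history) / 2**margin."""
    amax = jnp.max(amax_history)
    scale = FP8_E4M3_MAX / amax / (2.0 ** margin)
    # Fall back to 1.0 until a non-zero amax has been observed.
    return jnp.where(amax > 0.0, scale, 1.0)


history = jnp.zeros((4,), dtype=jnp.float32)
for step_max in (2.0, 8.0, 4.0):
    history = update_amax_history(history, jnp.array([-step_max, 0.5]))

# The largest recorded amax (8.0) drives the scale: 448 / 8 = 56.
scale = compute_scale(history)
```

Reducing over the whole history rather than only the latest amax makes the scale conservative against short-lived spikes, at the cost of reacting more slowly when activations shrink.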
Additional modules include Softmax for fused softmax operations and LayerNorm for standalone normalization.
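As a reference for what a fused scaled and causally masked softmax computes (TE's SoftmaxType includes an upper-triangular masked variant), here is an unfused version in plain JAX. It is a sketch assuming the fused kernel matches softmax(scale_factor * logits) with key positions above the diagonal excluded; the function name and the use of the dtype's most negative value as the mask fill are illustrative choices, not TE's implementation.

```python
import jax
import jax.numpy as jnp


def scaled_causal_softmax_reference(logits, scale_factor=1.0):
    """Unfused reference: softmax(scale_factor * logits) with an upper-triangular mask.

    logits: [..., q_len, k_len]; key positions k > q are masked out.
    """
    q_len, k_len = logits.shape[-2], logits.shape[-1]
    causal = jnp.tril(jnp.ones((q_len, k_len), dtype=bool))
    scaled = logits * scale_factor
    # Fill masked slots with the most negative finite value so exp() gives 0.
    masked = jnp.where(causal, scaled, jnp.finfo(scaled.dtype).min)
    return jax.nn.softmax(masked, axis=-1)


# With all-zero logits each row is uniform over the positions it may attend to.
probs = scaled_causal_softmax_reference(jnp.zeros((2, 3, 3)), scale_factor=0.5)
```

A fused kernel avoids materializing the scaled and masked intermediates, but its output should agree with this reference up to floating-point tolerance.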
This is the primary API for Flax users: it provides FP8-accelerated transformer building blocks that plug into Flax's module system, parameter management, and logical axis partitioning.
Usage
Use these modules as drop-in replacements for standard Flax layers in transformer architectures. Enable FP8 quantization by initializing and applying the model inside the autocast context manager with the desired quantization recipe.
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/jax/flax/module.py
- Lines: 1-1440
Signature
```python
class Softmax(nn.Module):
    """Fused softmax module with multiple fusion types."""
    ...

class LayerNorm(nn.Module):
    """Layer normalization with optional FP8 quantization."""
    ...

class TransformerEngineBase(nn.Module):
    """Base class managing FP8 quantization state."""
    ...

class DenseGeneral(TransformerEngineBase):
    """Dense layer with FP8 support and LoRA."""
    features: Union[Iterable[int], int] = ...
    use_bias: bool = True
    enable_low_rank_adaptation: bool = False
    ...

class LayerNormDenseGeneral(TransformerEngineBase):
    """Fused LayerNorm + Dense with FP8 support."""
    features: Union[Iterable[int], int] = ...
    ...

class LayerNormMLP(TransformerEngineBase):
    """Fused LayerNorm + MLP block with FP8 support."""
    intermediate_dim: int = 2048
    activations: Sequence[Union[str, Callable]] = ("relu",)
    ...

def wrap_function_in_te_state_module(f, quantization_recipe, name=None): ...
def make_dot_general_cls(quantization_recipe): ...
```
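make_dot_general_cls produces a replacement for lax.dot_general. To illustrate the idea only, the sketch below defines a function with the same call shape as jax.lax.dot_general that fake-quantizes both operands to FP8 E4M3 with per-tensor scales before the matmul. The per-tensor "current" scaling, the helper names, and the quantize-then-dequantize emulation in float32 are assumptions; TE's version dispatches to real FP8 GEMM kernels using the quantizers from its recipe. In recent Flax versions, nn.Dense accepts such a callable through its dot_general attribute (or a class through dot_general_cls, which is the hook the name make_dot_general_cls suggests).

```python
import jax
import jax.numpy as jnp

FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def fake_quantize_fp8(x):
    """Hypothetical helper: scale into the E4M3 range, round to FP8, scale back."""
    amax = jnp.max(jnp.abs(x))
    scale = jnp.where(amax > 0, FP8_E4M3_MAX / amax, 1.0)
    x_fp8 = (x * scale).astype(jnp.float8_e4m3fn)
    return x_fp8.astype(jnp.float32) / scale


def fp8_dot_general(lhs, rhs, dimension_numbers, precision=None,
                    preferred_element_type=None):
    """Same call signature as jax.lax.dot_general, but with FP8-rounded operands."""
    out = jax.lax.dot_general(
        fake_quantize_fp8(lhs),
        fake_quantize_fp8(rhs),
        dimension_numbers,
        precision=precision,
        preferred_element_type=preferred_element_type,
    )
    return out.astype(lhs.dtype)


x = jax.random.normal(jax.random.PRNGKey(0), (4, 16))
w = jax.random.normal(jax.random.PRNGKey(1), (16, 8))
dims = (((1,), (0,)), ((), ()))  # contract x's axis 1 with w's axis 0, no batch dims
y_fp8 = fp8_dot_general(x, w, dims)
y_ref = jax.lax.dot_general(x, w, dims)
```

Comparing y_fp8 with y_ref shows the rounding error introduced by the 3-bit E4M3 mantissa; a real FP8 GEMM avoids the float32 round trip and gains speed instead.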
Import
from transformer_engine.jax.flax.module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| x | jnp.ndarray | Yes | Input tensor to the module |
| features | Union[Iterable[int], int] | Yes (init) | Output feature dimensions |
| use_bias | bool | No | Whether to include a bias term (default True) |
| enable_low_rank_adaptation | bool | No | Whether to enable LoRA (default False) |
| axis | Union[Iterable[int], int] | No | Contracting axis for the kernel (default -1) |
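With enable_low_rank_adaptation=True, DenseGeneral adds a trainable low-rank update to the kernel. The sketch below shows the standard LoRA formulation in plain JAX, y = xW + (alpha / r) * (xA)B, where B is zero-initialized so the adapted layer starts out identical to the base layer. The rank, the alpha / r scaling (no scaling when alpha is None), the initializers, and the helper names are assumptions taken from the common LoRA recipe rather than read from TE's _apply_low_rank_adaptation.

```python
import jax
import jax.numpy as jnp


def init_lora(key, in_features, out_features, rank):
    """Hypothetical initializer: A ~ small normal, B = 0 (standard LoRA choice)."""
    a = jax.random.normal(key, (in_features, rank)) * 0.01
    b = jnp.zeros((rank, out_features))
    return a, b


def dense_with_lora(x, kernel, lora_a, lora_b, alpha=None):
    """y = x @ W + scaling * (x @ A) @ B, with scaling = alpha / rank (1.0 if alpha is None)."""
    rank = lora_a.shape[-1]
    scaling = 1.0 if alpha is None else alpha / rank
    return x @ kernel + scaling * ((x @ lora_a) @ lora_b)


k_x, k_w, k_a = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k_x, (2, 64))
kernel = jax.random.normal(k_w, (64, 32))
lora_a, lora_b = init_lora(k_a, 64, 32, rank=8)

# B starts at zero, so the adapted output equals the base output.
y0 = dense_with_lora(x, kernel, lora_a, lora_b, alpha=16.0)
```

Only lora_a and lora_b would be trained in a typical fine-tuning setup, which keeps the number of updated parameters at rank * (in_features + out_features) instead of in_features * out_features.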
Outputs
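| Name | Type | Description |
|---|---|---|
| output | jnp.ndarray | Transformed output tensor. LayerNormDenseGeneral and LayerNormMLP return a pair (output, ln_output); ln_output is None unless return_layernorm_output is enabled |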
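The features/axis contract generalizes a plain dense layer: the input axes listed in axis are contracted against the leading kernel dimensions, and the output gains the features dimensions in their place. The following shape-level sketch in plain JAX reproduces that rule with jnp.tensordot. The kernel layout (contracted input dims followed by the features dims) is an assumption modeled on flax.linen.DenseGeneral, and the helper is hypothetical.

```python
import jax.numpy as jnp


def dense_general_reference(x, kernel, axis=-1):
    """Contract x over `axis` with the leading kernel dims (hypothetical helper).

    kernel shape: (*[x.shape[a] for a in axis], *features)
    """
    axis = (axis,) if isinstance(axis, int) else tuple(axis)
    axis = tuple(a % x.ndim for a in axis)  # normalize negative axes
    kernel_axes = tuple(range(len(axis)))
    return jnp.tensordot(x, kernel, axes=(axis, kernel_axes))


x = jnp.ones((2, 10, 8, 64))      # [batch, seq, heads, head_dim]
kernel = jnp.ones((8, 64, 512))   # contracts (heads, head_dim) into features=512
y = dense_general_reference(x, kernel, axis=(-2, -1))  # shape (2, 10, 512)
```

The same rule covers the default axis=-1 with a tuple of features, for example an input of shape (4, 16) and a kernel of shape (16, 3, 5) give an output of shape (4, 3, 5).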
Usage Examples
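```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.jax.flax.module import LayerNormMLP
from transformer_engine.jax.quantize import autocast
from transformer_engine.jax.sharding import MeshResource


class MyTransformerBlock(nn.Module):
    mlp_dim: int = 2048

    @nn.compact
    def __call__(self, x):
        # Fused LayerNorm + MLP; FP8 GEMMs are used when traced under autocast.
        # The fused module returns (output, layernorm_output).
        x, _ = LayerNormMLP(
            intermediate_dim=self.mlp_dim,
            activations=("gelu",),
        )(x, deterministic=True)  # disable intermediate dropout (no dropout RNG needed)
        return x


fp8_recipe = DelayedScaling()
mesh_resource = MeshResource()  # default: no mesh axes assigned
model = MyTransformerBlock()
input_data = jnp.ones((8, 128, 512), dtype=jnp.float32)  # [batch, seq, hidden]

# Initialize inside the same context so the FP8 metadata variables are created.
with autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource):
    variables = model.init(jax.random.PRNGKey(0), input_data)
    output = model.apply(variables, input_data)
```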
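When validating FP8 numerics it can help to compare against an unfused, full-precision computation. The sketch below writes out what a LayerNorm, Dense, GELU, Dense block computes in plain JAX. It assumes a pre-norm MLP with learned LayerNorm scale and bias, biases on both projections, no dropout, and jax.nn.gelu's default approximation; the parameter names are hypothetical and do not follow TE's variable layout, so weights would need to be mapped before comparing against LayerNormMLP output.

```python
import jax
import jax.numpy as jnp


def layernorm(x, scale, bias, eps=1e-6):
    """Normalize over the last axis, then apply learned scale and bias."""
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return (x - mean) * jax.lax.rsqrt(var + eps) * scale + bias


def layernorm_mlp_reference(x, p):
    """Unfused LayerNorm -> Dense -> GELU -> Dense (hypothetical param names)."""
    h = layernorm(x, p["ln_scale"], p["ln_bias"])
    h = jax.nn.gelu(h @ p["wi"] + p["bi"])
    return h @ p["wo"] + p["bo"]


hidden, mlp = 512, 2048
keys = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "ln_scale": jnp.ones((hidden,)),
    "ln_bias": jnp.zeros((hidden,)),
    "wi": jax.random.normal(keys[0], (hidden, mlp)) * hidden ** -0.5,
    "bi": jnp.zeros((mlp,)),
    "wo": jax.random.normal(keys[1], (mlp, hidden)) * mlp ** -0.5,
    "bo": jnp.zeros((hidden,)),
}
x = jax.random.normal(keys[2], (8, 128, hidden))
y = layernorm_mlp_reference(x, params)  # same shape as the input
```

The fused module performs the same sequence of operations but keeps the normalized activations and GEMM inputs in FP8 where the recipe allows, so differences against this reference should stay within FP8 rounding tolerance.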