
Implementation:NVIDIA TransformerEngine JAX Flax Module

From Leeroopedia


Field          Value
Sources        TransformerEngine
Domains        Deep_Learning, JAX, Quantization
Last Updated   2026-02-07 14:00 GMT

Overview

Defines Flax nn.Module wrappers around TE's core JAX operations, providing drop-in replacements for standard Flax layers with FP8 quantization and LoRA support.

Description

TransformerEngineBase manages FP8 quantization state (scaling factors and amax history) through Flax's variable collection system. DenseGeneral wraps dense() with Flax parameter initialization and logical axis partitioning. LayerNormDenseGeneral and LayerNormMLP fuse layer normalization with the linear layer(s) that follow it. Each module builds QuantizerFactory instances from the active recipe configuration and passes them to the underlying operations during the forward pass. make_dot_general_cls creates a custom replacement for lax.dot_general that performs FP8 GEMMs. LoRA support is integrated via _apply_low_rank_adaptation.
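To make the "scales, amax history" state concrete, here is a minimal sketch of delayed scaling in plain JAX. It is not TE's code: TE keeps this state in Flax variable collections and drives it from a recipe, whereas the sketch holds it in ordinary arrays. The helper names are hypothetical, and the max-over-history rule with a 2**margin back-off is an assumption loosely modeled on the margin and amax_compute_algo="max" options of a delayed-scaling recipe. The key property is that each step quantizes with a scale derived from earlier steps and only then records the current amax.

```python
import jax.numpy as jnp

# Largest finite magnitude of FP8 E4M3 (float8_e4m3fn)
FP8_E4M3_MAX = 448.0


def update_amax_history(amax_history, x):
    """Shift the history by one slot and record the current max |x| at index 0."""
    amax = jnp.max(jnp.abs(x)).astype(amax_history.dtype)
    return jnp.roll(amax_history, 1).at[0].set(amax)


def compute_scale(amax_history, margin=0):
    """Map the largest amax in the history onto the FP8 range, backed off by 2**margin."""
    amax = jnp.max(amax_history)
    scale = (FP8_E4M3_MAX / amax) / (2.0 ** margin)
    # No nonzero value observed yet: fall back to a scale of 1.0
    return jnp.where(amax > 0, scale, 1.0)


def quantize(x, scale):
    """Scale, clip to the representable range (E4M3 has no inf), and cast to FP8."""
    scaled = jnp.clip(x.astype(jnp.float32) * scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return scaled.astype(jnp.float8_e4m3fn)


def dequantize(q, scale):
    return q.astype(jnp.float32) / scale


def delayed_scaling_step(x, amax_history, scale, margin=0):
    """Quantize with the scale from earlier steps, then refresh state for the next step."""
    q = quantize(x, scale)
    amax_history = update_amax_history(amax_history, x)
    next_scale = compute_scale(amax_history, margin)
    return q, amax_history, next_scale


# Two steps on the same tensor: the first uses the initial scale of 1.0,
# the second uses the scale derived from the first step's amax.
history = jnp.zeros((16,), jnp.float32)
scale = jnp.float32(1.0)
x = jnp.array([0.5, -2.0, 1.0], jnp.float32)
q, history, scale = delayed_scaling_step(x, history, scale)
q, history, scale = delayed_scaling_step(x, history, scale)
```

Because the scale lags by one step, a sudden spike in magnitude is clipped on the step where it first appears; keeping a longer history trades responsiveness for stability.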

Additional modules include Softmax for fused softmax operations and LayerNorm for standalone normalization.
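The page does not spell out what the fused softmax computes. As a hedged reference, the unfused JAX sketch below implements three common attention-softmax variants: scaled, scaled with a mask, and scaled with a causal (upper-triangular) mask. The function names and the boolean mask convention (True means "drop this position") are assumptions of this sketch, not TE's API; a fused kernel would compute the same math in a single pass.

```python
import jax
import jax.numpy as jnp


def scaled_softmax(logits, scale):
    """softmax(scale * logits) over the last (key) axis."""
    return jax.nn.softmax(logits * scale, axis=-1)


def scaled_masked_softmax(logits, mask, scale):
    """Like scaled_softmax, but entries where `mask` is True are excluded.

    A large finite negative value is used instead of -inf, so a fully
    masked row becomes a uniform row rather than NaN.
    """
    neg = jnp.finfo(logits.dtype).min
    return jax.nn.softmax(jnp.where(mask, neg, logits * scale), axis=-1)


def scaled_upper_triang_masked_softmax(logits, scale):
    """Causal variant: query i may only attend to keys 0..i."""
    q_len, k_len = logits.shape[-2:]
    future = jnp.triu(jnp.ones((q_len, k_len), dtype=bool), k=1)
    return scaled_masked_softmax(logits, future, scale)


# Attention scores of shape (batch, q_len, k_len), scaled by 1/sqrt(head_dim)
scores = jnp.zeros((2, 3, 3))
probs = scaled_upper_triang_masked_softmax(scores, scale=1.0 / jnp.sqrt(64.0))
```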

This is the primary user-facing API for Flax users. It provides FP8-accelerated transformer building blocks that plug into Flax's module system, parameter management, and logical axis partitioning.

Usage

Use these modules as drop-in replacements for standard Flax layers in transformer architectures. Enable FP8 quantization by using the autocast context manager and passing the appropriate quantization recipe.

Code Reference

Source Location

Repository   NVIDIA/TransformerEngine
File         transformer_engine/jax/flax/module.py
Lines        1-1440

Signature

class Softmax(nn.Module):
    """Fused softmax module with multiple fusion types."""
    ...

class LayerNorm(nn.Module):
    """Layer normalization with optional FP8 quantization."""
    ...

class TransformerEngineBase(nn.Module):
    """Base class managing FP8 quantization state."""
    ...

class DenseGeneral(TransformerEngineBase):
    """Dense layer with FP8 support and LoRA."""
    features: Union[Iterable[int], int] = ...
    use_bias: bool = True
    enable_low_rank_adaptation: bool = False
    ...

class LayerNormDenseGeneral(TransformerEngineBase):
    """Fused LayerNorm + Dense with FP8 support."""
    features: Union[Iterable[int], int] = ...
    ...

class LayerNormMLP(TransformerEngineBase):
    """Fused LayerNorm + MLP block with FP8 support."""
    intermediate_dim: int = 2048
    activations: Sequence[Union[str, Callable]] = ("relu",)
    ...

def wrap_function_in_te_state_module(f, quantization_recipe, name=None): ...
def make_dot_general_cls(quantization_recipe): ...
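The page says only that make_dot_general_cls produces a replacement for lax.dot_general. The sketch below is a hypothetical, simulated stand-in that illustrates the idea: it keeps the lax.dot_general signature, round-trips both operands through FP8 E4M3 using simple per-tensor current scaling (not TE's recipe-driven quantizers), and accumulates in float32. It runs on any backend because the FP8 values are dequantized before the matmul; TE's real path would dispatch an actual FP8 GEMM.

```python
import jax.numpy as jnp
from jax import lax

FP8_E4M3_MAX = 448.0


def _fake_quant_e4m3(x):
    """Per-tensor current scaling: quantize to FP8 E4M3, then dequantize to float32."""
    x32 = x.astype(jnp.float32)
    amax = jnp.max(jnp.abs(x32))
    scale = jnp.where(amax > 0, FP8_E4M3_MAX / amax, 1.0)
    q = jnp.clip(x32 * scale, -FP8_E4M3_MAX, FP8_E4M3_MAX).astype(jnp.float8_e4m3fn)
    return q.astype(jnp.float32) / scale


def fp8_dot_general(lhs, rhs, dimension_numbers, precision=None,
                    preferred_element_type=None):
    """Same signature as lax.dot_general; simulates FP8 inputs with float32 accumulation."""
    out = lax.dot_general(
        _fake_quant_e4m3(lhs),
        _fake_quant_e4m3(rhs),
        dimension_numbers,
        precision=precision,
        preferred_element_type=jnp.float32,
    )
    # Fall back to the input dtype when no output type is requested
    out_dtype = lhs.dtype if preferred_element_type is None else preferred_element_type
    return out.astype(out_dtype)


# Contract the last axis of x with the first axis of w, as a dense layer does
x = jnp.array([[1.0, 2.0], [3.0, 4.0]], jnp.float32)
w = jnp.eye(2, dtype=jnp.float32)
y = fp8_dot_general(x, w, (((1,), (0,)), ((), ())))
```

Because it matches the lax.dot_general signature, a callable like this can be handed to layers that accept a custom dot_general (recent Flax versions expose such an attribute on nn.Dense). Expect small rounding differences: 3.0 in the example lands between two E4M3 grid points and is rounded.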

Import

from transformer_engine.jax.flax.module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP

I/O Contract

Inputs

Name                         Type                        Required     Description
x                            jnp.ndarray                 Yes          Input tensor to the module
features                     Union[Iterable[int], int]   Yes (init)   Output feature dimensions
use_bias                     bool                        No           Whether to include a bias term (default True)
enable_low_rank_adaptation   bool                        No           Whether to enable LoRA (default False)
axis                         Union[Iterable[int], int]   No           Contracting axis for the kernel (default -1)

Outputs

Name     Type          Description
output   jnp.ndarray   Transformed output tensor (LayerNormDenseGeneral and LayerNormMLP instead return a tuple whose first element is the output)
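The contract above can be made concrete with an unfused reference in plain JAX. The helper dense_general_reference and its lora_a, lora_b, and lora_scale parameters are hypothetical. The sketch assumes the kernel is shaped (*contracted_dims, *features), that the input axes named by axis are contracted against the kernel's leading dimensions, and that the LoRA path adds a scaled low-rank product (x A) B, the standard LoRA formulation commonly scaled by alpha / rank. TE's actual parameter layout and scaling rule may differ.

```python
import jax.numpy as jnp


def dense_general_reference(x, kernel, bias=None, axis=-1,
                            lora_a=None, lora_b=None, lora_scale=1.0):
    """Unfused reference for a DenseGeneral-style layer with an optional LoRA path.

    kernel has shape (*contracted_dims, *features); the non-contracted axes of x
    keep their order and the feature dimensions are appended at the end.
    """
    axes = (axis,) if isinstance(axis, int) else tuple(axis)
    axes = tuple(a % x.ndim for a in axes)          # normalize negative axes
    kernel_axes = tuple(range(len(axes)))           # leading kernel dims
    y = jnp.tensordot(x, kernel, axes=(axes, kernel_axes))
    if lora_a is not None and lora_b is not None:
        # lora_a: (*contracted_dims, rank), lora_b: (rank, *features)
        low_rank = jnp.tensordot(x, lora_a, axes=(axes, kernel_axes))
        y = y + lora_scale * jnp.tensordot(low_rank, lora_b, axes=1)
    if bias is not None:
        y = y + bias                                # broadcast over feature dims
    return y


# (batch, seq, hidden) -> (batch, seq, features) with a rank-2 adapter
x = jnp.ones((2, 3, 4))
kernel = jnp.ones((4, 5))
lora_a, lora_b = jnp.ones((4, 2)), jnp.ones((2, 5))
alpha, rank = 1.0, 2
y = dense_general_reference(x, kernel, bias=jnp.zeros((5,)),
                            lora_a=lora_a, lora_b=lora_b, lora_scale=alpha / rank)
```

Passing a tuple for axis contracts several input axes at once, and a multi-dimensional features shape simply becomes extra trailing kernel dimensions.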

Usage Examples

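import jax
import jax.numpy as jnp
import flax.linen as nn
from transformer_engine.common import recipe
from transformer_engine.jax.flax.module import DenseGeneral, LayerNormMLP
from transformer_engine.jax.quantize import autocast
from transformer_engine.jax.sharding import MeshResource

class MyTransformerBlock(nn.Module):
    hidden_dim: int = 512
    mlp_dim: int = 2048

    @nn.compact
    def __call__(self, x, deterministic: bool = True):
        # Dense projection to hidden_dim (FP8 GEMM when autocast is enabled)
        x = DenseGeneral(features=self.hidden_dim)(x)
        # Fused LayerNorm + MLP; returns (output, layernorm_output)
        x, _ = LayerNormMLP(
            intermediate_dim=self.mlp_dim,
            activations=("gelu",),
        )(x, deterministic=deterministic)
        return x

fp8_recipe = recipe.DelayedScaling()   # FP8 GEMMs need an FP8-capable GPU
mesh_resource = MeshResource()         # no sharding axes (single device)
model = MyTransformerBlock()
input_data = jnp.ones((4, 128, 512), jnp.bfloat16)

# Enable FP8 with autocast; initialize inside the context too, so the
# FP8 state (scales, amax history) is created along with the parameters
with autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource):
    variables = model.init(jax.random.PRNGKey(0), input_data)
    output = model.apply(variables, input_data)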
