
Workflow: NVIDIA Transformer Engine FP8 Training Quickstart

From Leeroopedia


Knowledge Sources
Domains: LLMs, FP8_Training, Mixed_Precision
Last Updated: 2026-02-07 21:00 GMT

Overview

End-to-end process for building a Transformer decoder layer with FP8 precision using NVIDIA Transformer Engine, progressively optimizing from pure PyTorch to fully fused TE modules.

Description

This workflow demonstrates the standard path for adopting Transformer Engine in a PyTorch-based Transformer training pipeline. Starting from a baseline implementation using native PyTorch modules, it shows how to incrementally replace components with TE equivalents, enable FP8 precision via the autocast context manager, and leverage fused operations for maximum performance. The process covers five stages of optimization: basic module replacement, attention backend replacement, FP8 enablement, kernel fusion, and using the ready-made TransformerLayer module.

Key outputs:

  • A Transformer decoder layer running with FP8 mixed precision
  • Performance benchmarks at each optimization stage
  • Understanding of TE module hierarchy and FP8 recipe configuration

Usage

Execute this workflow when you want to accelerate Transformer model training on NVIDIA Hopper, Ada, or Blackwell GPUs using FP8 precision. This is the recommended starting point for any new TE adoption, whether you are building a model from scratch or migrating an existing PyTorch Transformer implementation to use TE's optimized modules and FP8 support.

Execution Steps

Step 1: Establish Baseline Transformer Layer

Build a standard Transformer decoder layer using native PyTorch modules. This includes LayerNorm, Linear projections for Q/K/V, a dot-product attention implementation, an output projection, and a two-layer MLP with GELU activation. The baseline establishes a performance reference point for measuring speedups from TE optimizations.

Key considerations:

  • Use BF16 or FP16 as the baseline precision for fair comparison
  • Structure the layer with pre-norm residual connections (GPT-style)
  • Fuse Q, K, V into a single 3x-wide Linear projection for efficiency

Step 2: Replace PyTorch Modules With TE Equivalents

Swap standard PyTorch modules with their Transformer Engine counterparts. Replace torch.nn.Linear with te.Linear and torch.nn.LayerNorm with te.LayerNorm. These TE modules are drop-in replacements that internally manage FP8 scaling factors and quantization metadata, preparing the model for FP8 enablement in later steps.

What changes:

  • torch.nn.Linear becomes te.Linear
  • torch.nn.LayerNorm becomes te.LayerNorm
  • Model behavior remains identical at this stage (no FP8 yet)
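The swap is a one-line change per module. A minimal sketch (assumes `transformer_engine` is installed and a supported CUDA GPU is available; the sizes are illustrative):

```python
import torch.nn as nn
import transformer_engine.pytorch as te

hidden_size = 4096

# Before: native PyTorch modules
ln = nn.LayerNorm(hidden_size)
qkv = nn.Linear(hidden_size, 3 * hidden_size)

# After: drop-in TE equivalents with the same constructor arguments
ln = te.LayerNorm(hidden_size)
qkv = te.Linear(hidden_size, 3 * hidden_size)
```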

Step 3: Replace Attention With TE DotProductAttention

Replace the custom or framework-native attention implementation with te.DotProductAttention. This module automatically selects the best available backend (FlashAttention-2, FlashAttention-3, or cuDNN fused attention) based on hardware capability and input configuration. It supports causal masking, padding masking, and arbitrary attention masks.

Key considerations:

  • Set attn_mask_type to specify the masking behavior (e.g., "causal")
  • TE attention handles head dimension reshaping internally
  • Backend selection is automatic but can be controlled via environment variables
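Constructing the TE attention module might look like the sketch below (the head counts are illustrative; requires TE and a CUDA GPU). `kv_channels` is the per-head dimension, so the module does not need the full hidden size:

```python
import transformer_engine.pytorch as te

attention = te.DotProductAttention(
    num_attention_heads=32,
    kv_channels=128,            # per-head dimension
    attn_mask_type="causal",    # masking behavior for decoder-style attention
)

# Forward takes separate query/key/value tensors:
# out = attention(q, k, v)
```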

Step 4: Enable FP8 Precision With Autocast

Wrap the forward pass in a te.autocast context manager with an FP8 recipe to enable 8-bit floating point computation. Configure the recipe to control the FP8 format (E4M3 for forward, E5M2 for backward in HYBRID mode), scaling strategy (delayed scaling with amax history or current scaling), and other quantization parameters.

Key considerations:

  • The autocast context must wrap only the forward pass and must exit before backward
  • DelayedScaling uses a history window of amax values for stable scaling
  • Float8CurrentScaling computes the scale factor from the current tensor on each pass
  • FP8 requires Compute Capability 8.9+ (Ada, Hopper, Blackwell)

Step 5: Use Fused TE Modules

Replace separate LayerNorm and Linear/MLP modules with TE's fused equivalents: te.LayerNormLinear (fuses LayerNorm + Linear) and te.LayerNormMLP (fuses LayerNorm + two-layer MLP). Fused modules reduce kernel launch overhead and memory traffic by combining multiple operations into single GPU kernels.

What changes:

  • Separate LayerNorm + QKV Linear becomes te.LayerNormLinear
  • Separate LayerNorm + MLP becomes te.LayerNormMLP
  • Benefits scale with multi-GPU setups where kernel launch overhead is more significant
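The fused replacements mirror the unfused modules' shapes, so the swap is again mechanical. A construction-only sketch with illustrative sizes (requires TE and a CUDA GPU):

```python
import transformer_engine.pytorch as te

hidden_size, ffn_hidden_size = 4096, 16384

# LayerNorm + 3x-wide QKV projection fused into one module
ln_qkv = te.LayerNormLinear(hidden_size, 3 * hidden_size)

# LayerNorm + two-layer MLP fused into one module
ln_mlp = te.LayerNormMLP(hidden_size, ffn_hidden_size)
```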

Step 6: Use TransformerLayer Module

For the simplest integration, use te.TransformerLayer, a ready-made, fully optimized Transformer decoder layer. It includes all of the previous optimizations (fused modules, optimized attention, FP8 support) out of the box. Configure it with hidden_size, ffn_hidden_size, num_attention_heads, and mask-type parameters.

Key considerations:

  • TransformerLayer is the recommended module for most use cases
  • Supports GQA (Grouped Query Attention) via num_gqa_groups parameter
  • Supports RoPE via rotary_pos_emb parameter
  • Supports various normalization types (LayerNorm, RMSNorm)
  • Supports SwiGLU and other activation functions
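A minimal construction sketch with illustrative sizes (requires TE and a CUDA GPU); a layer built this way is a drop-in replacement for the hand-assembled layer from the earlier steps and runs under the same FP8 autocast context:

```python
import transformer_engine.pytorch as te

layer = te.TransformerLayer(
    hidden_size=4096,
    ffn_hidden_size=16384,
    num_attention_heads=32,
    self_attn_mask_type="causal",
)
```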

Execution Diagram

GitHub URL

Workflow Repository