
Principle:NVIDIA TransformerEngine Complete Transformer Layer

From Leeroopedia


Overview

A single optimized module that encapsulates an entire Transformer encoder/decoder layer (self-attention + feed-forward network + residual connections).

Description

A complete Transformer layer combines self-attention, optional cross-attention, feed-forward network, layer normalization, residual connections, and dropout into a single module. TransformerEngine's implementation fuses all sub-components for maximum GPU utilization and FP8 support.

Rather than composing individual TE modules (LayerNormLinear, DotProductAttention, LayerNormMLP) manually, the complete Transformer layer module handles all of the internal wiring, residual connections, dropout, and parallelism configuration automatically. This provides:

  • Maximum fusion opportunity: All sub-operations are coordinated for optimal kernel scheduling and memory reuse.
  • Unified FP8 recipe management: A single FP8 context wraps the entire layer, with per-GEMM scaling factors managed automatically.
  • Simplified tensor/sequence/context parallelism: All parallel strategies are configured through a consistent interface.
  • Communication-GEMM overlap: The layer can overlap tensor-parallel all-reduce/all-gather communication with GEMM computation for maximum throughput.
  • Support for architectural variants: Encoder and decoder modes, grouped query attention (GQA), parallel attention+MLP, and various attention mask types.

Theoretical Basis

The standard pre-norm Transformer layer follows this computation:

Self-Attention Sub-Layer:

output = x + Dropout(Attention(LayerNorm(x)))

Feed-Forward Sub-Layer:

output = output + Dropout(MLP(LayerNorm(output)))

Expanded:

  1. LayerNorm1: norm1 = LayerNorm(x)
  2. QKV Projection: Q, K, V = W_qkv * norm1
  3. Attention: attn = softmax(Q * K^T / sqrt(d_k)) * V
  4. Output Projection: proj = W_out * attn
  5. Residual + Dropout: h = x + Dropout(proj)
  6. LayerNorm2: norm2 = LayerNorm(h)
  7. MLP: mlp = W2 * activation(W1 * norm2 + b1) + b2
  8. Residual + Dropout: output = h + Dropout(mlp)
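The eight steps above can be sketched in plain numpy. This is a minimal single-head reference implementation of the pre-norm computation, not TransformerEngine's fused code: dropout is omitted for determinism, LayerNorm uses fixed gamma=1/beta=0, and ReLU stands in for the activation.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize over the feature dimension (gamma=1, beta=0 for brevity).
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def pre_norm_layer(x, W_qkv, W_out, W1, b1, W2, b2):
    """Steps 1-8 above; single head, dropout omitted."""
    d = x.shape[-1]
    norm1 = layer_norm(x)                            # 1. LayerNorm1
    qkv = norm1 @ W_qkv                              # 2. fused QKV projection
    Q, K, V = np.split(qkv, 3, axis=-1)
    attn = softmax(Q @ K.T / np.sqrt(d)) @ V         # 3. scaled dot-product attention
    proj = attn @ W_out                              # 4. output projection
    h = x + proj                                     # 5. residual
    norm2 = layer_norm(h)                            # 6. LayerNorm2
    mlp = np.maximum(norm2 @ W1 + b1, 0.0) @ W2 + b2 # 7. MLP (ReLU activation)
    return h + mlp                                   # 8. residual

rng = np.random.default_rng(0)
seq, d, ffn = 4, 8, 16
x = rng.standard_normal((seq, d))
out = pre_norm_layer(
    x,
    rng.standard_normal((d, 3 * d)) * 0.1,
    rng.standard_normal((d, d)) * 0.1,
    rng.standard_normal((d, ffn)) * 0.1, np.zeros(ffn),
    rng.standard_normal((ffn, d)) * 0.1, np.zeros(d),
)
```

With all weights zeroed, both sub-layers contribute nothing and the layer reduces to the identity, which is a quick sanity check of the residual wiring.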

For decoder mode, an additional cross-attention sub-layer is inserted between self-attention and MLP:

h2 = h + Dropout(CrossAttention(LayerNorm(h), encoder_output))

For parallel attention+MLP mode (used in architectures like Falcon), self-attention and MLP are computed in parallel from the same normalized input:

output = x + Dropout(Attention(LayerNorm1(x))) + Dropout(MLP(LayerNorm2(x)))
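The structural difference of the parallel mode is a single residual add over two independently normalized branches. A minimal numpy sketch, with hypothetical linear stand-ins for the real attention and MLP blocks:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

# Stand-ins (hypothetical): real attention/MLP blocks replaced by linear maps
# to keep the sketch focused on the residual structure.
def attention_fn(x, W):
    return x @ W

def mlp_fn(x, W):
    return x @ W

def parallel_block(x, W_attn, W_mlp):
    # Each branch reads its own normalized copy of x; one shared residual add.
    return x + attention_fn(layer_norm(x), W_attn) + mlp_fn(layer_norm(x), W_mlp)

rng = np.random.default_rng(1)
x = rng.standard_normal((4, 8))
out = parallel_block(x, rng.standard_normal((8, 8)) * 0.1, rng.standard_normal((8, 8)) * 0.1)
```

Note that, unlike the sequential form, the MLP branch never sees the attention output within the same layer.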

Usage

Use the complete Transformer layer module when:

  • You are constructing a Transformer model from scratch and want all TE optimizations out of the box.
  • You want FP8 training with automatic scaling factor management across all GEMMs in the layer.
  • You need tensor parallelism, sequence parallelism, or context parallelism with minimal configuration.
  • You want to overlap communication and computation for maximum throughput in distributed training.
  • You are building encoder, decoder, or encoder-decoder architectures (e.g., BERT, GPT, T5).
  • You need grouped query attention (GQA) as used in LLaMA-2, Mistral, and other modern architectures.

This module replaces hand-written layer implementations and serves as the primary building block for constructing Transformer models with TransformerEngine.
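A minimal construction sketch, assuming TransformerEngine's PyTorch API (`transformer_engine.pytorch.TransformerLayer`) and a CUDA GPU with FP8-capable hardware; argument names reflect the current API but should be checked against the installed version, and the sizes below are illustrative, not recommendations.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# One complete pre-norm layer: LayerNorm + self-attention + LayerNorm + MLP,
# with residual connections and dropout handled internally.
layer = te.TransformerLayer(
    hidden_size=4096,
    ffn_hidden_size=16384,
    num_attention_heads=32,
    num_gqa_groups=8,             # grouped query attention (LLaMA-2 style)
    layer_type="encoder",         # "decoder" inserts a cross-attention sub-layer
    self_attn_mask_type="causal",
    hidden_dropout=0.1,
    attention_dropout=0.1,
).cuda()

# A single FP8 context wraps the entire layer; per-GEMM scaling factors
# are managed automatically by the recipe.
recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16)
x = torch.randn(2048, 2, 4096, device="cuda", dtype=torch.bfloat16)  # [seq, batch, hidden]
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = layer(x)
```

Tensor/sequence parallelism is configured through additional constructor arguments (process groups and `sequence_parallel`-style flags) rather than by rewiring the layer by hand.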
