Implementation:NVIDIA TransformerEngine TE TransformerLayer

From Leeroopedia


Overview

Concrete tool: a complete, fully optimized Transformer layer module provided by TransformerEngine.

Description

te.TransformerLayer is the highest-level module in TransformerEngine, combining MultiheadAttention + LayerNormMLP + LayerNorm + residual connections + dropout into a single, fully optimized Transformer layer. It supports both encoder and decoder modes, FP8 quantization, tensor parallelism, sequence parallelism, context parallelism, and communication-GEMM overlap.

This module internally instantiates:

  • A MultiheadAttention module for self-attention (with fused QKV projection).
  • Optionally a second MultiheadAttention module for cross-attention (in decoder mode).
  • A LayerNormMLP module for the feed-forward network.
  • LayerNorm layers for normalization.
  • Dropout layers and residual connection logic.

All sub-modules are configured consistently and dispatch to optimized CUDA kernels.
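The composition above can be sketched in plain PyTorch. This is a simplified pre-LN reference for orientation only, assuming default settings (encoder mode, input-side LayerNorm, GELU MLP); it omits FP8, parallelism, masking, and the fused CUDA kernels that TransformerEngine actually dispatches to, and `SimpleTransformerLayer` is a hypothetical name, not part of the library:

```python
import torch
import torch.nn as nn

class SimpleTransformerLayer(nn.Module):
    """Minimal pre-LN encoder layer mirroring the sub-module composition above.
    A plain-PyTorch sketch for illustration, not the TransformerEngine code."""

    def __init__(self, hidden_size, ffn_hidden_size, num_heads,
                 hidden_dropout=0.1, eps=1e-5):
        super().__init__()
        self.ln1 = nn.LayerNorm(hidden_size, eps=eps)
        # nn.MultiheadAttention defaults to the same [s, b, h] layout as "sbhd"
        self.attn = nn.MultiheadAttention(hidden_size, num_heads)
        self.ln2 = nn.LayerNorm(hidden_size, eps=eps)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, ffn_hidden_size),
            nn.GELU(),
            nn.Linear(ffn_hidden_size, hidden_size),
        )
        self.drop = nn.Dropout(hidden_dropout)

    def forward(self, x):
        # Self-attention sub-block with residual connection
        h = self.ln1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + self.drop(attn_out)
        # Feed-forward sub-block with residual connection
        return x + self.drop(self.mlp(self.ln2(x)))

layer = SimpleTransformerLayer(hidden_size=64, ffn_hidden_size=256, num_heads=4)
x = torch.randn(8, 2, 64)   # [sequence, batch, hidden]
assert layer(x).shape == x.shape
```

The real module preserves this dataflow but fuses the LayerNorm into the adjacent GEMMs (LayerNormMLP, fused QKV projection) rather than running it as a separate kernel.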

Source

  • File: transformer_engine/pytorch/transformer.py
  • Class: TransformerLayer at lines L70-945
  • Constructor: __init__ at lines L299-353

Import

from transformer_engine.pytorch import TransformerLayer

or equivalently:

import transformer_engine.pytorch as te
te.TransformerLayer

Signature

class TransformerLayer(torch.nn.Module):
    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        num_attention_heads: int,
        num_gqa_groups: Optional[int] = None,
        layernorm_epsilon: float = 1e-5,
        hidden_dropout: float = 0.1,
        attention_dropout: float = 0.1,
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
        layer_number: Optional[int] = None,
        kv_channels: Optional[int] = None,
        self_attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        bottom_right_diagonal: Optional[bool] = None,
        enc_dec_attn_mask_type: str = "no_mask",
        enc_dec_bottom_right_diagonal: Optional[bool] = None,
        enc_dec_window_size: Optional[Tuple[int, int]] = None,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        params_dtype: Optional[torch.dtype] = None,
        get_rng_state_tracker: Optional[Callable] = None,
        fuse_wgrad_accumulation: bool = False,
        seq_length: Optional[int] = None,
        micro_batch_size: Optional[int] = None,
        sequence_parallel: bool = False,
        apply_residual_connection_post_layernorm: bool = False,
        output_layernorm: bool = False,
        parallel_attention_mlp: bool = False,
        layer_type: str = "encoder",
        drop_path_rate: float = 0.0,
        set_parallel_mode: bool = False,
        fuse_qkv_params: bool = False,
        rotary_pos_interleaved: bool = False,
        zero_centered_gamma: bool = False,
        qkv_weight_interleaved: bool = True,
        ub_tp_comm_overlap: bool = False,
        ub_overlap_ag: bool = True,
        ub_overlap_rs: bool = True,
        ub_overlap_rs_dgrad: bool = False,
        ub_bulk_dgrad: bool = True,
        ub_bulk_wgrad: bool = True,
        bias: bool = True,
        activation: str = "gelu",
        activation_params: Optional[dict] = None,
        normalization: str = "LayerNorm",
        device: Union[torch.device, str] = "cuda",
        attn_input_format: str = "sbhd",
        name: str = None,
        qk_norm_type: Optional[str] = None,
        qk_norm_eps: float = 1e-6,
        qk_norm_before_rope: bool = False,
        softmax_type: str = "vanilla",
    ) -> None:

I/O

  • Input:
    • hidden_states: torch.Tensor of shape [s, b, h] (sequence, batch, hidden_size) -- the primary input.
    • attention_mask: Optional[torch.Tensor] -- attention mask (used with "padding" or "arbitrary" mask types).
    • encoder_output: Optional[torch.Tensor] -- encoder hidden states for cross-attention (decoder mode only).
    • rotary_pos_emb: Optional[torch.Tensor] -- rotary positional embedding tensors.
    • inference_params: Optional -- parameters for inference-time KV caching.
  • Output: torch.Tensor of shape [s, b, h] -- the Transformer layer output.
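The `attn_input_format` parameter only selects the memory layout of `hidden_states`: with `"bshd"` the same data is passed batch-first instead. A minimal illustration in plain PyTorch (shapes chosen arbitrarily):

```python
import torch

s, b, h = 128, 4, 1024
x_sbhd = torch.randn(s, b, h)      # "sbhd": sequence-first, the default

# "bshd" is the batch-first layout; converting is a transpose of dims 0 and 1
x_bshd = x_sbhd.transpose(0, 1)
assert x_bshd.shape == (b, s, h)
```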

Key Parameters

  • hidden_size (int, required) -- Size of hidden representations (model dimension).
  • ffn_hidden_size (int, required) -- Intermediate size of the MLP (typically 4x hidden_size).
  • num_attention_heads (int, required) -- Number of attention heads.
  • num_gqa_groups (Optional[int], default None) -- Number of GQA groups. None defaults to MHA; set to 1 for MQA.
  • layernorm_epsilon (float, default 1e-5) -- Epsilon for LayerNorm numerical stability.
  • hidden_dropout (float, default 0.1) -- Dropout probability after FC2 and the output projection.
  • attention_dropout (float, default 0.1) -- Dropout probability in the attention softmax.
  • self_attn_mask_type (str, default "causal") -- Attention mask type: "no_mask", "padding", "causal", "padding_causal", "causal_bottom_right", "arbitrary".
  • activation (str, default "gelu") -- MLP activation: "gelu", "geglu", "silu", "swiglu", "relu", "srelu", "qgelu".
  • normalization (str, default "LayerNorm") -- Normalization type: "LayerNorm" or "RMSNorm".
  • layer_type (str, default "encoder") -- "encoder" (self-attention only) or "decoder" (self-attention plus cross-attention).
  • fuse_qkv_params (bool, default False) -- Whether to fuse the Q, K, V weight matrices into a single parameter.
  • parallel_attention_mlp (bool, default False) -- If True, attention and MLP run in parallel (Falcon architecture).
  • apply_residual_connection_post_layernorm (bool, default False) -- If True, the residual is taken from the LayerNorm output rather than its input.
  • output_layernorm (bool, default False) -- If True, apply LayerNorm on the output side instead of the input side.
  • zero_centered_gamma (bool, default False) -- If True, the LayerNorm gamma is centered at zero.
  • attn_input_format (str, default "sbhd") -- Attention input layout: "sbhd" or "bshd".
  • set_parallel_mode (bool, default False) -- If True, enables tensor parallelism for all sub-modules.
  • sequence_parallel (bool, default False) -- Whether to enable sequence parallelism.
  • tp_group (Optional[dist_group_type], default None) -- Tensor-parallel process group.
  • tp_size (int, default 1) -- Tensor-parallel world size.
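num_gqa_groups controls how query heads share key/value heads: consecutive query heads are assigned to the same KV group. A small sketch of that mapping (`gqa_kv_group` is an illustrative helper, not part of the TransformerEngine API):

```python
def gqa_kv_group(query_head: int, num_attention_heads: int,
                 num_gqa_groups: int) -> int:
    """Map a query head index to the key/value group it attends with.
    Illustrative helper only -- not part of the TransformerEngine API."""
    assert num_attention_heads % num_gqa_groups == 0
    heads_per_group = num_attention_heads // num_gqa_groups
    return query_head // heads_per_group

# 32 query heads sharing 8 KV groups: 4 query heads per group
assert gqa_kv_group(0, 32, 8) == 0    # first four query heads -> group 0
assert gqa_kv_group(31, 32, 8) == 7   # last four query heads -> group 7
# num_gqa_groups == 1 degenerates to MQA: every query head shares one KV head
assert gqa_kv_group(17, 32, 1) == 0
```

Setting num_gqa_groups equal to num_attention_heads recovers standard MHA, which is also what the default None means.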

Communication-GEMM overlap parameters (require ub_tp_comm_overlap=True):

  • ub_tp_comm_overlap (default False) -- Master switch to enable communication-GEMM overlap.
  • ub_overlap_ag (default True) -- Overlap all-gather with GEMM in the forward pass.
  • ub_overlap_rs (default True) -- Overlap reduce-scatter with GEMM in the forward pass.
  • ub_overlap_rs_dgrad (default False) -- Overlap reduce-scatter with the DGRAD GEMM in the backward pass.
  • ub_bulk_dgrad (default True) -- Bulk overlap for DGRAD communication.
  • ub_bulk_wgrad (default True) -- Bulk overlap for WGRAD communication.

Example

Basic causal decoder layer (GPT-style):

import torch
import transformer_engine.pytorch as te

layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    self_attn_mask_type="causal",
    activation="gelu",
    normalization="LayerNorm",
)
# Input in the default "sbhd" layout: [sequence, batch, hidden_size]
hidden_states = torch.randn(128, 2, 1024, device="cuda")
output = layer(hidden_states)

With grouped query attention (GQA) and SwiGLU (LLaMA-style):

import transformer_engine.pytorch as te

layer = te.TransformerLayer(
    hidden_size=4096,
    ffn_hidden_size=11008,
    num_attention_heads=32,
    num_gqa_groups=8,
    self_attn_mask_type="causal",
    activation="swiglu",
    normalization="RMSNorm",
    bias=False,
)
output = layer(hidden_states)
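The non-round ffn_hidden_size=11008 follows the common LLaMA-style sizing rule for gated MLPs such as SwiGLU: roughly two-thirds of 4x hidden_size, rounded up to a hardware-friendly multiple. A sketch of that arithmetic (the rule and the helper name are an assumption about how such sizes are chosen, not part of TransformerEngine):

```python
def swiglu_ffn_size(hidden_size: int, multiple_of: int = 256) -> int:
    """Illustrative gated-MLP sizing rule (an assumption, not a
    TransformerEngine API): 2/3 of 4*hidden, rounded up to a multiple
    so the GEMM dimensions stay hardware-friendly."""
    target = int(2 * (4 * hidden_size) / 3)
    return multiple_of * ((target + multiple_of - 1) // multiple_of)

# Reproduces the 11008 used in the LLaMA-style example above
assert swiglu_ffn_size(4096) == 11008
```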

Encoder-decoder (T5-style) with cross-attention:

import transformer_engine.pytorch as te

decoder_layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    layer_type="decoder",
    self_attn_mask_type="causal",
)
output = decoder_layer(
    hidden_states,
    encoder_output=encoder_hidden_states,  # encoder stack output, shape [s_enc, b, h]
)

With tensor parallelism and communication overlap:

import transformer_engine.pytorch as te

layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    set_parallel_mode=True,
    tp_group=tp_group,  # an initialized torch.distributed process group
    tp_size=8,
    sequence_parallel=True,
    ub_tp_comm_overlap=True,
)
output = layer(hidden_states)
