
Implementation:NVIDIA TransformerEngine TE LayerNormMLP



Overview

Concrete tool provided by TransformerEngine: a fused LayerNorm + MLP module.

Description

te.LayerNormMLP fuses LayerNorm + FC1 + Activation + FC2 into a single module. It supports gated activations (SwiGLU, GeGLU), FP8 quantization, tensor parallelism, and sequence parallelism. This is the highest-impact single fused module in TransformerEngine, as the MLP sub-layer contains the largest intermediate activations and the most FLOPs in a standard Transformer layer.

The module internally manages two weight matrices (FC1 and FC2), LayerNorm parameters (gamma, beta), and dispatches to optimized CUDA kernels that minimize global memory traffic between the sub-operations.
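
A quick way to see the parameters the module manages is to print its named parameters. This is a minimal sketch; it needs a CUDA-capable machine (parameters are allocated on the GPU by default), and the parameter names and shapes shown in the comments are what a recent TransformerEngine build prints, so treat them as indicative rather than guaranteed across versions.

import transformer_engine.pytorch as te

mlp = te.LayerNormMLP(hidden_size=1024, ffn_hidden_size=4096, activation="gelu")
for name, param in mlp.named_parameters():
    print(name, tuple(param.shape))
# Indicative output (names may vary by TE version):
#   layer_norm_weight (1024,)
#   layer_norm_bias   (1024,)
#   fc1_weight        (4096, 1024)
#   fc1_bias          (4096,)
#   fc2_weight        (1024, 4096)
#   fc2_bias          (1024,)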

When using gated activations ("swiglu" or "geglu"), the FC1 weight matrix automatically doubles in output dimension to produce both the gate and value tensors, which are then combined element-wise before FC2.
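
For intuition, here is a minimal unfused sketch of the computation for the SwiGLU case, written in plain PyTorch. The weight names, the gate/value split order, and the use of matrix multiplies instead of fused kernels are assumptions for illustration, not the module's internal layout.

import torch
import torch.nn.functional as F

def reference_swiglu_mlp(x, ln_weight, ln_bias, fc1_weight, fc2_weight, eps=1e-5):
    # LayerNorm over the hidden dimension.
    y = F.layer_norm(x, (x.shape[-1],), ln_weight, ln_bias, eps)
    # FC1 output is 2 * ffn_hidden_size wide for gated activations.
    h = y @ fc1_weight.t()
    gate, value = h.chunk(2, dim=-1)      # assumed split order
    h = F.silu(gate) * value              # SwiGLU: SiLU(gate) combined elementwise with value
    # FC2 projects back to hidden_size.
    return h @ fc2_weight.t()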

Source

  • File: transformer_engine/pytorch/module/layernorm_mlp.py
  • Class: LayerNormMLP
  • Constructor: __init__ at lines 1764-1798

Import

from transformer_engine.pytorch import LayerNormMLP

or equivalently:

import transformer_engine.pytorch as te
te.LayerNormMLP

Signature

class LayerNormMLP(TransformerEngineBaseModule):
    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        eps: float = 1e-5,
        sequence_parallel: bool = False,
        return_bias: bool = False,
        get_rng_state_tracker: Optional[Callable] = None,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        init_method: Optional[Callable] = None,
        bias: bool = True,
        normalization: str = "LayerNorm",
        activation: str = "gelu",
        activation_params: Optional[dict] = None,
        output_layer_init_method: Optional[Callable] = None,
        fuse_wgrad_accumulation: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        return_layernorm_output: bool = False,
        return_layernorm_output_gathered: bool = False,
        seq_length: Optional[int] = None,
        micro_batch_size: Optional[int] = None,
        set_parallel_mode: bool = False,
        zero_centered_gamma: bool = False,
        device: Union[torch.device, str] = "cuda",
        ub_overlap_ag: bool = False,
        name: Optional[str] = None,
        ub_overlap_rs: bool = False,
        ub_overlap_rs_dgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_bulk_wgrad: bool = False,
        delay_wgrad_compute: bool = False,
        symmetric_ar_type: Optional[str] = None,
        checkpoint: bool = False,
    ) -> None:

I/O

  • Input: inp: torch.Tensor of shape [seq_length, batch_size, hidden_size].
  • Output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] of shape [seq_length, batch_size, hidden_size]. When return_layernorm_output=True, returns a tuple containing both the MLP output and the intermediate LayerNorm output. When return_bias=True, the bias is returned separately.
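
As a quick sanity check on shapes, a minimal sketch with default settings (the sequence and batch sizes below are arbitrary, and the output is a plain tensor only because return_bias and return_layernorm_output are left at their defaults):

import torch
import transformer_engine.pytorch as te

mlp = te.LayerNormMLP(hidden_size=1024, ffn_hidden_size=4096)
inp = torch.randn(128, 2, 1024, device="cuda")   # [seq_length, batch_size, hidden_size]
out = mlp(inp)
assert out.shape == inp.shape                    # output keeps the hidden dimension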

Key Parameters

  • hidden_size (int, required): Size of the input and output (model hidden dimension).
  • ffn_hidden_size (int, required): Intermediate size of the MLP (typically 4x hidden_size).
  • eps (float, default 1e-5): Epsilon for numerical stability in LayerNorm.
  • bias (bool, default True): Whether to add learnable biases to FC1 and FC2.
  • normalization (str, default "LayerNorm"): Type of normalization: "LayerNorm" or "RMSNorm".
  • activation (str, default "gelu"): Activation function. Options: "gelu", "geglu", "silu", "swiglu", "relu", "srelu", "qgelu".
  • activation_params (Optional[dict], default None): Additional parameters for the activation function.
  • set_parallel_mode (bool, default False): If True, FC1 uses column parallelism and FC2 uses row parallelism (see the tensor-parallel sketch after this list).
  • return_layernorm_output (bool, default False): If True, also returns the intermediate LayerNorm output.
  • return_layernorm_output_gathered (bool, default False): If True (with sequence parallelism), returns the gathered LayerNorm output.
  • zero_centered_gamma (bool, default False): If True, the LayerNorm gamma parameter is centered at zero.
  • sequence_parallel (bool, default False): Whether to enable sequence parallelism.
  • tp_group (Optional[dist_group_type], default None): Tensor parallel process group.
  • tp_size (int, default 1): Tensor parallel world size.
  • fuse_wgrad_accumulation (bool, default False): Whether to fuse weight gradient accumulation into the WGRAD GEMM.
  • delay_wgrad_compute (bool, default False): Whether to delay weight gradient computation.
  • checkpoint (bool, default False): Whether to use selective activation checkpointing (recompute activations in backward, skipping FC2).
  • symmetric_ar_type (Optional[str], default None): Type of symmetric memory all-reduce for the forward pass.
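
The following is a rough single-node sketch of how set_parallel_mode, tp_group, tp_size, and sequence_parallel fit together. The process-group setup and launch method (torchrun with one process per GPU, NCCL backend) are assumptions of the sketch, not requirements imposed by the module itself.

import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

# Assumes torchrun launched one process per GPU and NCCL is available.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())
tp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

mlp = te.LayerNormMLP(
    hidden_size=1024,
    ffn_hidden_size=4096,
    set_parallel_mode=True,        # FC1 column-parallel, FC2 row-parallel
    tp_group=tp_group,
    tp_size=dist.get_world_size(),
    sequence_parallel=True,        # forward input is then expected to be already
                                   # sharded along the sequence dimension
)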

Example

Basic usage with GELU activation:

import torch
import transformer_engine.pytorch as te

mlp = te.LayerNormMLP(
    hidden_size=1024,
    ffn_hidden_size=4096,
    activation="gelu",
    normalization="LayerNorm",
)

# Input layout is [seq_length, batch_size, hidden_size]; TE modules run on CUDA.
hidden_states = torch.randn(128, 2, 1024, device="cuda")
output = mlp(hidden_states)

With SwiGLU gated activation (as used in LLaMA):

import transformer_engine.pytorch as te

mlp = te.LayerNormMLP(
    hidden_size=1024,
    ffn_hidden_size=4096,
    activation="swiglu",
    normalization="RMSNorm",
)
output = mlp(hidden_states)

With LayerNorm output returned for residual connections:

import transformer_engine.pytorch as te

mlp = te.LayerNormMLP(
    hidden_size=1024,
    ffn_hidden_size=4096,
    activation="gelu",
    return_layernorm_output=True,
)
mlp_output, ln_output = mlp(hidden_states)
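
Because the description mentions FP8 quantization, here is a hedged sketch of running the module under TransformerEngine's fp8_autocast context. It assumes an FP8-capable GPU (e.g. Hopper), and the DelayedScaling recipe values are illustrative rather than recommended settings.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

mlp = te.LayerNormMLP(hidden_size=1024, ffn_hidden_size=4096, activation="swiglu")
hidden_states = torch.randn(128, 2, 1024, device="cuda")

# Illustrative delayed-scaling recipe; values are not tuned recommendations.
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    output = mlp(hidden_states)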
