Implementation: NVIDIA TransformerEngine te.LayerNormMLP
Overview
Concrete TransformerEngine module providing a fused LayerNorm + MLP block.
Description
te.LayerNormMLP fuses LayerNorm + FC1 + Activation + FC2 into a single module. It supports gated activations (SwiGLU, GeGLU), FP8 quantization, tensor parallelism, and sequence parallelism. This is the highest-impact single fused module in TransformerEngine, as the MLP sub-layer contains the largest intermediate activations and the most FLOPs in a standard Transformer layer.
The module internally manages two weight matrices (FC1 and FC2), LayerNorm parameters (gamma, beta), and dispatches to optimized CUDA kernels that minimize global memory traffic between the sub-operations.
When using gated activations ("swiglu" or "geglu"), the FC1 weight matrix automatically doubles in output dimension to produce both the gate and value tensors, which are then combined element-wise before FC2.
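A minimal pure-Python sketch of that gated combine step, for intuition only. The gate/value split order shown here is an illustrative assumption, not TE's exact kernel layout, and `swiglu_combine` is a hypothetical helper, not a TE API:

```python
import math

def silu(x: float) -> float:
    """SiLU (a.k.a. swish): x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))

def swiglu_combine(fc1_out: list[float]) -> list[float]:
    """Combine a gated FC1 output of width 2*ffn_hidden_size into
    ffn_hidden_size values: silu(gate) * value, element-wise.
    The gate-first ordering is an assumption for illustration."""
    assert len(fc1_out) % 2 == 0
    half = len(fc1_out) // 2
    gate, value = fc1_out[:half], fc1_out[half:]
    return [silu(g) * v for g, v in zip(gate, value)]

# FC1's output width is doubled: here ffn_hidden_size = 2, so FC1 emits 4 values,
# and the combine step reduces them to the 2 values that feed FC2.
combined = swiglu_combine([1.0, -1.0, 0.5, 2.0])
```

This is why a gated activation does not change FC2's input width: the doubling happens only inside FC1's output, and the element-wise combine halves it again.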
Source
- File: transformer_engine/pytorch/module/layernorm_mlp.py
- Class: LayerNormMLP
- Constructor: `__init__` at lines 1764-1798
Import
```python
from transformer_engine.pytorch import LayerNormMLP
```

or equivalently:

```python
import transformer_engine.pytorch as te

te.LayerNormMLP
```
Signature
```python
class LayerNormMLP(TransformerEngineBaseModule):
    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        eps: float = 1e-5,
        sequence_parallel: bool = False,
        return_bias: bool = False,
        get_rng_state_tracker: Optional[Callable] = None,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        init_method: Optional[Callable] = None,
        bias: bool = True,
        normalization: str = "LayerNorm",
        activation: str = "gelu",
        activation_params: Optional[dict] = None,
        output_layer_init_method: Optional[Callable] = None,
        fuse_wgrad_accumulation: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        return_layernorm_output: bool = False,
        return_layernorm_output_gathered: bool = False,
        seq_length: Optional[int] = None,
        micro_batch_size: Optional[int] = None,
        set_parallel_mode: bool = False,
        zero_centered_gamma: bool = False,
        device: Union[torch.device, str] = "cuda",
        ub_overlap_ag: bool = False,
        name: Optional[str] = None,
        ub_overlap_rs: bool = False,
        ub_overlap_rs_dgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_bulk_wgrad: bool = False,
        delay_wgrad_compute: bool = False,
        symmetric_ar_type: Optional[str] = None,
        checkpoint: bool = False,
    ) -> None:
```
I/O
- Input: `inp`: `torch.Tensor` of shape `[seq_length, batch_size, hidden_size]`.
- Output: `Union[torch.Tensor, Tuple[torch.Tensor, ...]]` of shape `[seq_length, batch_size, hidden_size]`. When `return_layernorm_output=True`, a tuple containing both the MLP output and the intermediate LayerNorm output is returned. When `return_bias=True`, the bias is returned separately.
Key Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| `hidden_size` | `int` | (required) | Size of the input and output (model hidden dimension). |
| `ffn_hidden_size` | `int` | (required) | Intermediate size of the MLP (typically 4x `hidden_size`). |
| `eps` | `float` | `1e-5` | Epsilon for numerical stability in LayerNorm. |
| `bias` | `bool` | `True` | Whether to add learnable biases to FC1 and FC2. |
| `normalization` | `str` | `"LayerNorm"` | Type of normalization: `"LayerNorm"` or `"RMSNorm"`. |
| `activation` | `str` | `"gelu"` | Activation function. Options: `"gelu"`, `"geglu"`, `"silu"`, `"swiglu"`, `"relu"`, `"srelu"`, `"qgelu"`. |
| `activation_params` | `Optional[dict]` | `None` | Additional parameters for the activation function. |
| `set_parallel_mode` | `bool` | `False` | If `True`, FC1 uses column parallelism and FC2 uses row parallelism. |
| `return_layernorm_output` | `bool` | `False` | If `True`, also returns the intermediate LayerNorm output. |
| `return_layernorm_output_gathered` | `bool` | `False` | If `True` (with sequence parallelism), returns the gathered LayerNorm output. |
| `zero_centered_gamma` | `bool` | `False` | If `True`, the LayerNorm gamma parameter is centered at zero. |
| `sequence_parallel` | `bool` | `False` | Whether to enable sequence parallelism. |
| `tp_group` | `Optional[dist_group_type]` | `None` | Tensor-parallel process group. |
| `tp_size` | `int` | `1` | Tensor-parallel world size. |
| `fuse_wgrad_accumulation` | `bool` | `False` | Whether to fuse weight-gradient accumulation into the WGRAD GEMM. |
| `delay_wgrad_compute` | `bool` | `False` | Whether to delay weight-gradient computation. |
| `checkpoint` | `bool` | `False` | Whether to use selective activation checkpointing (recompute activations in backward, skipping FC2). |
| `symmetric_ar_type` | `Optional[str]` | `None` | Type of symmetric-memory all-reduce for the forward pass. |
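The gated-activation and tensor-parallel options above interact when sizing the two weight matrices. The following pure-Python sketch (illustrative only; `local_weight_shapes` is a hypothetical helper, not part of TE) shows how the per-rank FC1/FC2 weight shapes come out under the column/row split that `set_parallel_mode` describes:

```python
def local_weight_shapes(hidden_size: int,
                        ffn_hidden_size: int,
                        activation: str = "gelu",
                        tp_size: int = 1) -> dict:
    """Illustrative sizing: gated activations ("swiglu", "geglu") double
    FC1's output width; under tensor parallelism FC1 is column-parallel
    (output dim sharded across ranks) and FC2 is row-parallel (input dim
    sharded), so each rank holds 1/tp_size of the intermediate width."""
    gate_mult = 2 if activation in ("swiglu", "geglu") else 1
    fc1_out = gate_mult * ffn_hidden_size
    return {
        "fc1_weight": (fc1_out // tp_size, hidden_size),          # column-parallel
        "fc2_weight": (hidden_size, ffn_hidden_size // tp_size),  # row-parallel
    }

# SwiGLU with tp_size=2: FC1's doubled width (8192) is split to 4096 per rank.
shapes = local_weight_shapes(1024, 4096, activation="swiglu", tp_size=2)
```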
Example
Basic usage with GELU activation:
```python
import transformer_engine.pytorch as te

mlp = te.LayerNormMLP(
    hidden_size=1024,
    ffn_hidden_size=4096,
    activation="gelu",
    normalization="LayerNorm",
)

# hidden_states: [seq_length, batch_size, hidden_size], on a CUDA device
output = mlp(hidden_states)
```
With SwiGLU gated activation (as used in LLaMA):
```python
import transformer_engine.pytorch as te

mlp = te.LayerNormMLP(
    hidden_size=1024,
    ffn_hidden_size=4096,
    activation="swiglu",
    normalization="RMSNorm",
)

output = mlp(hidden_states)
```
With LayerNorm output returned for residual connections:
```python
import transformer_engine.pytorch as te

mlp = te.LayerNormMLP(
    hidden_size=1024,
    ffn_hidden_size=4096,
    activation="gelu",
    return_layernorm_output=True,
)

mlp_output, ln_output = mlp(hidden_states)
```
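For intuition about what the fused module computes, here is an unfused plain-PyTorch reference of the default LayerNorm + FC1 + GELU + FC2 path. This is a sketch, not TE's kernels: it omits FP8, tensor/sequence parallelism, and TE's weight initialization, and `UnfusedLayerNormMLP` is a name chosen here for illustration:

```python
import torch
import torch.nn as nn

class UnfusedLayerNormMLP(nn.Module):
    """Reference for the math te.LayerNormMLP performs with activation="gelu"."""
    def __init__(self, hidden_size: int, ffn_hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.ln = nn.LayerNorm(hidden_size, eps=eps)
        self.fc1 = nn.Linear(hidden_size, ffn_hidden_size)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(ffn_hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [seq_length, batch_size, hidden_size]; shape is preserved end to end
        return self.fc2(self.act(self.fc1(self.ln(x))))

mlp = UnfusedLayerNormMLP(hidden_size=64, ffn_hidden_size=256)
out = mlp(torch.randn(8, 2, 64))
```

The fused module avoids the four separate global-memory round trips this version makes between `ln`, `fc1`, `act`, and `fc2`, which is where its speedup comes from.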
Related
- Principle:NVIDIA_TransformerEngine_Fused_LayerNorm_MLP
- Environment:NVIDIA_TransformerEngine_CUDA_Toolkit_Requirements
- Environment:NVIDIA_TransformerEngine_Python_PyTorch_Requirements
- Environment:NVIDIA_TransformerEngine_GPU_Compute_Capability
- Heuristic:NVIDIA_TransformerEngine_Sequence_Length_Alignment