Implementation: NVIDIA TransformerEngine te.TransformerLayer
Overview
Concrete tool: a complete, optimized Transformer layer provided by TransformerEngine.
Description
te.TransformerLayer is the highest-level module in TransformerEngine, combining MultiheadAttention + LayerNormMLP + LayerNorm + residual connections + dropout into a single, fully optimized Transformer layer. It supports both encoder and decoder modes, FP8 quantization, tensor parallelism, sequence parallelism, context parallelism, and communication-GEMM overlap.
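The FP8 support mentioned above builds on the two OCP 8-bit floating-point formats: E4M3 (more mantissa bits, typically used for activations and weights) and E5M2 (more exponent bits, typically used for gradients). As a quick pure-Python sanity check of their dynamic ranges (this is background arithmetic, not TE code):

```python
def max_e4m3():
    # 4 exponent bits, bias 7; the all-ones exponent with mantissa 111 is NaN,
    # so the largest finite value is 1.75 * 2^(15 - 7) = 448.
    return (1 + 0.75) * 2 ** (15 - 7)

def max_e5m2():
    # 5 exponent bits, bias 15; the all-ones exponent is reserved for inf/NaN,
    # so the largest finite value is 1.75 * 2^(30 - 15) = 57344.
    return (1 + 0.75) * 2 ** (30 - 15)

print(max_e4m3())  # 448.0
print(max_e5m2())  # 57344.0
```

The much smaller range of E4M3 is why FP8 training needs the per-tensor scaling that TE manages internally.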
This module internally instantiates:
- A `MultiheadAttention` module for self-attention (with fused QKV projection).
- Optionally, a second `MultiheadAttention` module for cross-attention (decoder mode).
- A `LayerNormMLP` module for the feed-forward network.
- LayerNorm layers for normalization.
- Dropout layers and residual-connection logic.
All sub-modules are configured consistently and dispatch to optimized CUDA kernels.
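The wiring of these sub-modules in the default configuration (pre-LayerNorm with residual connections, i.e. `output_layernorm=False` and `apply_residual_connection_post_layernorm=False`) can be sketched in plain Python; `norm1`, `attn`, `norm2`, and `mlp` below are hypothetical stand-in callables, not TE APIs:

```python
def encoder_layer(x, norm1, attn, norm2, mlp):
    """Pre-LN encoder block: LayerNorm -> sub-module -> residual add."""
    h = x + attn(norm1(x))    # self-attention branch with residual
    return h + mlp(norm2(h))  # feed-forward branch with residual

# Toy check with scalar stand-ins: identity norms, simple affine sub-modules.
out = encoder_layer(1.0, norm1=lambda v: v, attn=lambda v: 2 * v,
                    norm2=lambda v: v, mlp=lambda v: v + 1)
print(out)  # h = 1 + 2 = 3; output = 3 + (3 + 1) = 7.0
```

With `parallel_attention_mlp=True`, the two branches would instead both read from the same input and their outputs would be summed (the Falcon-style variant).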
Source
- File: `transformer_engine/pytorch/transformer.py`
- Class: `TransformerLayer`, lines 70-945
- Constructor: `__init__`, lines 299-353
Import
```python
from transformer_engine.pytorch import TransformerLayer
```
or equivalently:
```python
import transformer_engine.pytorch as te
te.TransformerLayer
```
Signature
```python
class TransformerLayer(torch.nn.Module):
    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        num_attention_heads: int,
        num_gqa_groups: Optional[int] = None,
        layernorm_epsilon: float = 1e-5,
        hidden_dropout: float = 0.1,
        attention_dropout: float = 0.1,
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
        layer_number: Optional[int] = None,
        kv_channels: Optional[int] = None,
        self_attn_mask_type: str = "causal",
        window_size: Optional[Tuple[int, int]] = None,
        bottom_right_diagonal: Optional[bool] = None,
        enc_dec_attn_mask_type: str = "no_mask",
        enc_dec_bottom_right_diagonal: Optional[bool] = None,
        enc_dec_window_size: Optional[Tuple[int, int]] = None,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        params_dtype: Optional[torch.dtype] = None,
        get_rng_state_tracker: Optional[Callable] = None,
        fuse_wgrad_accumulation: bool = False,
        seq_length: Optional[int] = None,
        micro_batch_size: Optional[int] = None,
        sequence_parallel: bool = False,
        apply_residual_connection_post_layernorm: bool = False,
        output_layernorm: bool = False,
        parallel_attention_mlp: bool = False,
        layer_type: str = "encoder",
        drop_path_rate: float = 0.0,
        set_parallel_mode: bool = False,
        fuse_qkv_params: bool = False,
        rotary_pos_interleaved: bool = False,
        zero_centered_gamma: bool = False,
        qkv_weight_interleaved: bool = True,
        ub_tp_comm_overlap: bool = False,
        ub_overlap_ag: bool = True,
        ub_overlap_rs: bool = True,
        ub_overlap_rs_dgrad: bool = False,
        ub_bulk_dgrad: bool = True,
        ub_bulk_wgrad: bool = True,
        bias: bool = True,
        activation: str = "gelu",
        activation_params: Optional[dict] = None,
        normalization: str = "LayerNorm",
        device: Union[torch.device, str] = "cuda",
        attn_input_format: str = "sbhd",
        name: str = None,
        qk_norm_type: Optional[str] = None,
        qk_norm_eps: float = 1e-6,
        qk_norm_before_rope: bool = False,
        softmax_type: str = "vanilla",
    ) -> None:
```
I/O
- Input:
  - `hidden_states`: `torch.Tensor` of shape `[s, b, h]` (sequence, batch, hidden_size) -- the primary input.
  - `attention_mask`: `Optional[torch.Tensor]` -- attention mask (used with `"padding"` or `"arbitrary"` mask types).
  - `encoder_output`: `Optional[torch.Tensor]` -- encoder hidden states for cross-attention (decoder mode only).
  - `rotary_pos_emb`: `Optional[torch.Tensor]` -- rotary positional embedding tensors.
  - `inference_params`: `Optional` -- parameters for inference-time KV caching.
- Output: `torch.Tensor` of shape `[s, b, h]` -- the Transformer layer output.
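The `[s, b, h]` shape above corresponds to `attn_input_format="sbhd"`; with `"bshd"` the leading sequence and batch axes swap. A small pure-Python illustration of the two layouts (no torch dependency, nested lists standing in for tensors):

```python
def sbhd_to_bshd(x):
    """Swap the leading (sequence, batch) axes of a nested-list tensor,
    i.e. [s][b][h] -> [b][s][h]."""
    s, b = len(x), len(x[0])
    return [[x[i][j] for i in range(s)] for j in range(b)]

# 3 sequence positions, 2 batch elements, hidden vectors of length 2.
x_sbhd = [[[i, j] for j in range(2)] for i in range(3)]
x_bshd = sbhd_to_bshd(x_sbhd)
print(len(x_sbhd), len(x_sbhd[0]))  # 3 2  (s, b)
print(len(x_bshd), len(x_bshd[0]))  # 2 3  (b, s)
```

Whichever format is chosen, the same value must be used consistently for the layer's inputs and outputs.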
Key Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| `hidden_size` | `int` | (required) | Size of hidden representations (model dimension). |
| `ffn_hidden_size` | `int` | (required) | Intermediate size of the MLP (typically 4x `hidden_size`). |
| `num_attention_heads` | `int` | (required) | Number of attention heads. |
| `num_gqa_groups` | `Optional[int]` | `None` | Number of GQA groups. `None` defaults to MHA; set to 1 for MQA. |
| `layernorm_epsilon` | `float` | `1e-5` | Epsilon for LayerNorm numerical stability. |
| `hidden_dropout` | `float` | `0.1` | Dropout probability after FC2 and the output projection. |
| `attention_dropout` | `float` | `0.1` | Dropout probability in the attention softmax. |
| `self_attn_mask_type` | `str` | `"causal"` | Attention mask type: `"no_mask"`, `"padding"`, `"causal"`, `"padding_causal"`, `"causal_bottom_right"`, `"arbitrary"`. |
| `activation` | `str` | `"gelu"` | MLP activation: `"gelu"`, `"geglu"`, `"silu"`, `"swiglu"`, `"relu"`, `"srelu"`, `"qgelu"`. |
| `normalization` | `str` | `"LayerNorm"` | Normalization type: `"LayerNorm"` or `"RMSNorm"`. |
| `layer_type` | `str` | `"encoder"` | `"encoder"` (self-attention only) or `"decoder"` (self-attention + cross-attention). |
| `fuse_qkv_params` | `bool` | `False` | Whether to fuse the Q, K, V weight matrices into a single parameter. |
| `parallel_attention_mlp` | `bool` | `False` | If `True`, attention and MLP run in parallel (Falcon architecture). |
| `apply_residual_connection_post_layernorm` | `bool` | `False` | If `True`, the residual is taken from the LayerNorm output rather than the input. |
| `output_layernorm` | `bool` | `False` | If `True`, apply LayerNorm on the output side instead of the input side. |
| `zero_centered_gamma` | `bool` | `False` | If `True`, the LayerNorm gamma is centered at zero. |
| `attn_input_format` | `str` | `"sbhd"` | Attention input format: `"sbhd"` or `"bshd"`. |
| `set_parallel_mode` | `bool` | `False` | If `True`, enables tensor parallelism for all sub-modules. |
| `sequence_parallel` | `bool` | `False` | Whether to enable sequence parallelism. |
| `tp_group` | `Optional[dist_group_type]` | `None` | Tensor-parallel process group. |
| `tp_size` | `int` | `1` | Tensor-parallel world size. |
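As an illustration of `num_gqa_groups`: in grouped-query attention, each group of query heads shares a single K/V head, shrinking the KV cache proportionally. A sketch of that arithmetic (plain Python, not TE code):

```python
def gqa_layout(num_attention_heads, num_gqa_groups):
    """Query heads per KV group and the resulting KV-cache reduction
    versus standard multi-head attention."""
    assert num_attention_heads % num_gqa_groups == 0
    heads_per_group = num_attention_heads // num_gqa_groups
    kv_cache_reduction = num_attention_heads / num_gqa_groups
    return heads_per_group, kv_cache_reduction

# 32 query heads sharing 8 KV heads, as in the LLaMA-style example below:
print(gqa_layout(32, 8))  # (4, 4.0): 4 query heads per KV head, 4x smaller KV cache
# num_gqa_groups=1 is MQA: all query heads share one KV head.
print(gqa_layout(32, 1))  # (32, 32.0)
```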
Communication-GEMM overlap parameters (require `ub_tp_comm_overlap=True`):
| Parameter | Default | Description |
|---|---|---|
| `ub_tp_comm_overlap` | `False` | Master switch to enable communication-GEMM overlap. |
| `ub_overlap_ag` | `True` | Overlap all-gather with GEMM in the forward pass. |
| `ub_overlap_rs` | `True` | Overlap reduce-scatter with GEMM in the forward pass. |
| `ub_overlap_rs_dgrad` | `False` | Overlap reduce-scatter with the DGRAD GEMM in the backward pass. |
| `ub_bulk_dgrad` | `True` | Bulk overlap for DGRAD communication. |
| `ub_bulk_wgrad` | `True` | Bulk overlap for WGRAD communication. |
Example
Basic causal decoder layer (GPT-style):
```python
import transformer_engine.pytorch as te

layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    self_attn_mask_type="causal",
    activation="gelu",
    normalization="LayerNorm",
)
output = layer(hidden_states)
```
With grouped query attention (GQA) and SwiGLU (LLaMA-style):
```python
import transformer_engine.pytorch as te

layer = te.TransformerLayer(
    hidden_size=4096,
    ffn_hidden_size=11008,
    num_attention_heads=32,
    num_gqa_groups=8,
    self_attn_mask_type="causal",
    activation="swiglu",
    normalization="RMSNorm",
    bias=False,
)
output = layer(hidden_states)
```
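The `ffn_hidden_size=11008` above is not arbitrary: LLaMA-style models size the SwiGLU MLP at roughly two thirds of the usual 4x expansion (SwiGLU has two input projections, so this keeps the parameter count close to a plain MLP), rounded up to a hardware-friendly multiple. A sketch of that calculation (the multiple of 256 is the choice LLaMA uses; treat the exact rounding rule as an assumption here):

```python
import math

def swiglu_ffn_hidden_size(hidden_size, multiple_of=256):
    """LLaMA-style FFN sizing: two thirds of the 4x expansion,
    rounded up to a multiple of `multiple_of`."""
    raw = int(2 * (4 * hidden_size) / 3)
    return multiple_of * math.ceil(raw / multiple_of)

print(swiglu_ffn_hidden_size(4096))  # 11008, matching the example above
```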
Encoder-decoder (T5-style) with cross-attention:
```python
import transformer_engine.pytorch as te

decoder_layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    layer_type="decoder",
    self_attn_mask_type="causal",
)
output = decoder_layer(
    hidden_states,
    encoder_output=encoder_hidden_states,
)
```
With tensor parallelism and communication overlap:
```python
import transformer_engine.pytorch as te

layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    set_parallel_mode=True,
    tp_group=tp_group,
    tp_size=8,
    sequence_parallel=True,
    ub_tp_comm_overlap=True,
)
output = layer(hidden_states)
```
Related
- Principle: NVIDIA_TransformerEngine_Complete_Transformer_Layer
- Environment: NVIDIA_TransformerEngine_CUDA_Toolkit_Requirements
- Environment: NVIDIA_TransformerEngine_Python_PyTorch_Requirements
- Environment: NVIDIA_TransformerEngine_GPU_Compute_Capability
- Heuristic: NVIDIA_TransformerEngine_Build_Optimization_Tips