Implementation:NVIDIA TransformerEngine PyTorch Baseline Model
Overview
Reference pure-PyTorch Transformer implementation used as a performance baseline in TransformerEngine examples.
Description
PyTorchTransformerLayer is a plain PyTorch implementation of a Transformer decoder layer, and PyTorchMLP is the feed-forward block it uses. Together they serve as the "Step 0" baseline in TransformerEngine's getting-started tutorial: the unoptimized starting point before TE's progressive optimization path is applied.
This is a Pattern Doc -- it documents user-defined reference code from the TE examples, not a library API. The code is intended to be replaced by TE modules in subsequent optimization steps.
The baseline uses only standard PyTorch modules:
- torch.nn.LayerNorm for normalization.
- torch.nn.Linear for all projections (QKV, output, FC1, FC2).
- torch.nn.functional.gelu with approximate="tanh" for the MLP activation.
- torch.nn.Dropout for regularization.
- Manual dot-product attention with a DotProductAttention helper.
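The approximate="tanh" setting selects the tanh approximation of GELU rather than the exact erf-based form. As a minimal sketch, using only the Python standard library and hypothetical helper names (gelu_exact and gelu_tanh are not part of the tutorial), the two forms can be compared on scalars:

```python
import math


def gelu_exact(x: float) -> float:
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x: float) -> float:
    """Tanh approximation of GELU, the form selected by approximate="tanh"."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1.0 + math.tanh(inner))


# Compare both forms at a few sample points.
for value in (-1.0, 0.0, 1.0, 2.0):
    print(value, gelu_exact(value), gelu_tanh(value))
```

The two curves stay close over typical activation ranges, so the choice mainly matters when matching another implementation's numerics.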
Source
- File: docs/getting_started/getting_started_pytorch.py
- Class: PyTorchMLP at lines L39-58
- Class: PyTorchTransformerLayer at lines L65-116
This is a Pattern Doc
This implementation is not a library API. It is user-defined reference code from the TransformerEngine getting-started tutorial. It demonstrates how a standard Transformer layer looks in pure PyTorch before any TE optimizations are applied.
Signature
PyTorchMLP:
class PyTorchMLP(torch.nn.Module):
    """Feed-forward network in Transformer layer.

    Built with plain PyTorch modules.
    """

    hidden_size: int
    ffn_hidden_size: int

    def __init__(self, hidden_size: int, ffn_hidden_size: int) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.linear1 = torch.nn.Linear(hidden_size, ffn_hidden_size, bias=True)
        self.linear2 = torch.nn.Linear(ffn_hidden_size, hidden_size, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear1(x)
        x = torch.nn.functional.gelu(x, approximate="tanh")
        x = self.linear2(x)
        return x
PyTorchTransformerLayer:
class PyTorchTransformerLayer(torch.nn.Module):
    """Basic Transformer layer using plain PyTorch modules."""

    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        num_attention_heads: int,
        layernorm_eps: float = 1e-5,
        attention_dropout: float = 0.1,
        hidden_dropout: float = 0.1,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.kv_channels = hidden_size // num_attention_heads
        self.ln1 = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps)
        self.qkv_projection = torch.nn.Linear(hidden_size, 3 * hidden_size, bias=True)
        self.attention = DotProductAttention(
            num_attention_heads=num_attention_heads,
            kv_channels=self.kv_channels,
            attention_dropout=attention_dropout,
        )
        self.projection = torch.nn.Linear(hidden_size, hidden_size, bias=True)
        self.dropout = torch.nn.Dropout(hidden_dropout)
        self.ln2 = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps)
        self.mlp = PyTorchMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size)
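The DotProductAttention helper is referenced here but not reproduced on this page. The following is a hedged sketch of what such a helper might look like, not the tutorial's actual code: the class name SketchDotProductAttention is hypothetical, and the mask convention (True marks positions to exclude) is an assumption. It takes q, k, v of shape [seq_length, batch_size, num_heads, kv_channels] and returns [seq_length, batch_size, hidden_size].

```python
import math
from typing import Optional

import torch


class SketchDotProductAttention(torch.nn.Module):
    """Hypothetical stand-in for the tutorial's DotProductAttention helper."""

    def __init__(
        self, num_attention_heads: int, kv_channels: int, attention_dropout: float
    ) -> None:
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.kv_channels = kv_channels
        self.dropout = torch.nn.Dropout(attention_dropout)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # [seq, batch, heads, head_dim] -> [batch, heads, seq, head_dim]
        q, k, v = (t.permute(1, 2, 0, 3) for t in (q, k, v))
        # Scaled attention scores: [batch, heads, seq_q, seq_k]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.kv_channels)
        if attention_mask is not None:
            # Assumption: boolean mask, True means "do not attend".
            scores = scores.masked_fill(attention_mask, float("-inf"))
        probs = self.dropout(torch.softmax(scores, dim=-1))
        context = torch.matmul(probs, v)  # [batch, heads, seq_q, head_dim]
        # Back to [seq, batch, heads * head_dim]
        context = context.permute(2, 0, 1, 3)
        return context.reshape(context.size(0), context.size(1), -1)


attn = SketchDotProductAttention(num_attention_heads=2, kv_channels=4, attention_dropout=0.0)
q = k = v = torch.randn(5, 3, 2, 4)
print(attn(q, k, v).shape)
```

Materializing the full score matrix like this is what makes the baseline memory-hungry for long sequences, which is one reason it is a natural target for replacement by fused attention kernels.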
I/O
- Input: x: torch.Tensor of shape [seq_length, batch_size, hidden_size].
- Output: torch.Tensor of shape [seq_length, batch_size, hidden_size].
Key Parameters
PyTorchTransformerLayer:
| Parameter | Type | Default | Description |
|---|---|---|---|
| hidden_size | int | (required) | Size of hidden representations (model dimension). |
| ffn_hidden_size | int | (required) | Intermediate size of the MLP. |
| num_attention_heads | int | (required) | Number of attention heads. |
| layernorm_eps | float | 1e-5 | Epsilon for LayerNorm. |
| attention_dropout | float | 0.1 | Dropout probability in attention. |
| hidden_dropout | float | 0.1 | Dropout probability after output projection. |
PyTorchMLP:
| Parameter | Type | Default | Description |
|---|---|---|---|
| hidden_size | int | (required) | Input and output size (model dimension). |
| ffn_hidden_size | int | (required) | Intermediate size (typically 4x hidden_size). |
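These sizes determine the layer's learnable parameter count. The sketch below works through the arithmetic in plain Python, using hypothetical helper names and assuming that the DotProductAttention helper holds no learnable parameters and that every Linear and LayerNorm keeps its bias, as in the signatures above.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus bias of a torch.nn.Linear with bias=True."""
    return in_features * out_features + out_features


def layernorm_params(hidden_size: int) -> int:
    """Scale plus shift of a torch.nn.LayerNorm with default settings."""
    return 2 * hidden_size


def transformer_layer_params(hidden_size: int, ffn_hidden_size: int) -> int:
    """Parameter count of the baseline layer (attention helper assumed parameter-free)."""
    return (
        layernorm_params(hidden_size)                    # ln1
        + linear_params(hidden_size, 3 * hidden_size)    # qkv_projection
        + linear_params(hidden_size, hidden_size)        # projection
        + layernorm_params(hidden_size)                  # ln2
        + linear_params(hidden_size, ffn_hidden_size)    # mlp.linear1
        + linear_params(ffn_hidden_size, hidden_size)    # mlp.linear2
    )


# Small illustrative sizes with ffn_hidden_size = 4 * hidden_size.
print(transformer_layer_params(hidden_size=8, ffn_hidden_size=32))
```

Note that num_attention_heads does not appear: it changes how the QKV projection output is partitioned, not how many weights the layer has.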
Forward Pass Details
The PyTorchTransformerLayer.forward method implements the standard pre-norm Transformer layer:
def forward(
    self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    res = x
    x = self.ln1(x)

    # Fused QKV projection
    qkv = self.qkv_projection(x)
    qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)
    q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)

    x = self.attention(q, k, v, attention_mask)
    x = self.projection(x)
    x = self.dropout(x)
    x = res + x

    # Second residual connection
    res = x
    x = self.ln2(x)
    x = self.mlp(x)

    return x + res
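The view and split steps in the fused QKV projection imply a particular memory layout: each head owns a contiguous chunk of 3 * kv_channels features, holding its query, key, and value slices in that order. A small shape check, using illustrative sizes that are not taken from the tutorial, makes this concrete:

```python
import torch

# Illustrative sizes (assumptions for this sketch only).
seq_length, batch_size, num_heads, head_dim = 4, 2, 3, 5
hidden_size = num_heads * head_dim

# Stand-in for the output of the fused QKV projection.
flat = torch.randn(seq_length, batch_size, 3 * hidden_size)

# Same view/split as in the forward pass above.
qkv = flat.view(seq_length, batch_size, num_heads, 3 * head_dim)
q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)

print(q.shape, k.shape, v.shape)

# Head 0 occupies the first 3 * head_dim features: q first, then k, then v.
print(torch.equal(q[:, :, 0, :], flat[:, :, 0:head_dim]))
print(torch.equal(k[:, :, 0, :], flat[:, :, head_dim:2 * head_dim]))
```

This per-head interleaving differs from a layout where all queries come first, then all keys, then all values, which matters when loading weights from a checkpoint that uses the other convention.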
Example
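Instantiating the baseline and running a forward pass:
import torch

# Assumes PyTorchTransformerLayer and its DotProductAttention helper are already defined.
# Illustrative sizes, not prescribed by this page; pick values that fit your GPU.
hidden_size = 4096
ffn_hidden_size = 16384
num_attention_heads = 32
sequence_length = 2048
batch_size = 4
dtype = torch.float16

baseline = (
    PyTorchTransformerLayer(
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        num_attention_heads=num_attention_heads,
    )
    .to(dtype=dtype)
    .cuda()
)

# Forward pass on a random input of shape [sequence_length, batch_size, hidden_size]
x = torch.rand(sequence_length, batch_size, hidden_size).cuda().to(dtype=dtype)
output = baseline(x)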
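To use the layer as a performance baseline, its forward pass needs to be timed. The helper below is a minimal sketch of one way to do that and is not the tutorial's own benchmarking utility: time_forward is a hypothetical name, it measures the forward pass only (gradients disabled), and it synchronizes the GPU around the timed region when the input lives on a CUDA device.

```python
import time

import torch


def time_forward(module: torch.nn.Module, x: torch.Tensor, warmup: int = 3, iters: int = 10) -> float:
    """Return the mean forward-pass latency in milliseconds (hypothetical helper)."""
    use_cuda = x.is_cuda
    with torch.no_grad():
        # Warm-up iterations exclude one-time setup costs from the measurement.
        for _ in range(warmup):
            module(x)
        if use_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            module(x)
        if use_cuda:
            # CUDA kernels run asynchronously; wait for them before stopping the clock.
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
    return elapsed * 1000.0 / iters


# Small CPU demo; for the baseline one would call time_forward(baseline, x).
demo = torch.nn.Linear(8, 8)
print(f"{time_forward(demo, torch.randn(4, 8)):.4f} ms")
```

Running the same helper on the baseline and on each optimized variant gives comparable numbers, provided shapes, dtype, and device are held fixed.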
Internal Module Composition
| Attribute | Module Type | Purpose |
|---|---|---|
| self.ln1 | torch.nn.LayerNorm | Pre-attention normalization |
| self.qkv_projection | torch.nn.Linear | Projects input to Q, K, V (3x hidden_size) |
| self.attention | DotProductAttention | Scaled dot-product attention with dropout |
| self.projection | torch.nn.Linear | Output projection after attention |
| self.dropout | torch.nn.Dropout | Dropout after output projection |
| self.ln2 | torch.nn.LayerNorm | Pre-MLP normalization |
| self.mlp | PyTorchMLP | Feed-forward network (FC1 + GELU + FC2) |