

Implementation:NVIDIA TransformerEngine PyTorch Baseline Model

From Leeroopedia


Overview

Reference pure-PyTorch Transformer implementation used as a performance baseline in TransformerEngine examples.

Description

PyTorchTransformerLayer and PyTorchMLP are plain PyTorch implementations of a Transformer decoder layer. They serve as the "Step 0" baseline in TransformerEngine's getting-started tutorial, demonstrating the unoptimized starting point before TE's progressive optimization path is applied.

This is a Pattern Doc: it documents user-defined reference code from the TE examples, not a library API. In later optimization steps, this code is replaced by TE modules.

The baseline uses only standard PyTorch modules:

  • torch.nn.LayerNorm for normalization.
  • torch.nn.Linear for all projections (QKV, output, FC1, FC2).
  • torch.nn.functional.gelu with approximate="tanh" for the MLP activation.
  • torch.nn.Dropout for regularization.
  • Manual scaled dot-product attention (with dropout) through the tutorial's own DotProductAttention helper.
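DotProductAttention is not part of torch.nn, and this page does not reproduce its source. The sketch below is a hypothetical plain-PyTorch reimplementation. It assumes the interface implied by how PyTorchTransformerLayer calls it: the constructor takes num_attention_heads, kv_channels and attention_dropout; q, k and v have shape [seq, batch, heads, kv_channels]; and the output is flattened to [seq, batch, heads * kv_channels] (that is, hidden_size) so that it can feed self.projection. The boolean-mask convention (True marks a position to exclude) is also an assumption.

```python
import math
from typing import Optional

import torch


class DotProductAttention(torch.nn.Module):
    """Hypothetical plain-PyTorch scaled dot-product attention (sketch only).

    q, k, v: [seq, batch, heads, head_dim] -> output: [seq, batch, heads * head_dim].
    Assumed mask convention: boolean, True = excluded, broadcastable to
    [batch, heads, seq_q, seq_k].
    """

    def __init__(self, num_attention_heads: int, kv_channels: int, attention_dropout: float) -> None:
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.kv_channels = kv_channels
        self.norm_factor = math.sqrt(kv_channels)  # scale scores by sqrt(head_dim)
        self.dropout = torch.nn.Dropout(attention_dropout)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        sq, b, n, d = query.shape
        # [s, b, n, d] -> [b, n, s, d]: batch and heads become leading batch dims
        q = query.permute(1, 2, 0, 3)
        k = key.permute(1, 2, 0, 3)
        v = value.permute(1, 2, 0, 3)
        # Scaled scores: [b, n, sq, sk]
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.norm_factor
        if attention_mask is not None:
            # Excluded positions get the most negative finite value, so softmax gives them ~0
            scores = scores.masked_fill(attention_mask, torch.finfo(scores.dtype).min)
        probs = self.dropout(torch.softmax(scores, dim=-1))
        context = torch.matmul(probs, v)  # [b, n, sq, d]
        # [b, n, sq, d] -> [sq, b, n, d] -> [sq, b, n * d]
        return context.permute(2, 0, 1, 3).reshape(sq, b, n * d)
```

Filling masked scores with the dtype's minimum instead of -inf keeps a fully masked row from producing NaNs after the softmax.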

Source

  • File: docs/getting_started/getting_started_pytorch.py
  • Class: PyTorchMLP at lines L39-58
  • Class: PyTorchTransformerLayer at lines L65-116

This is a Pattern Doc

This implementation is not a library API. It is user-defined reference code from the TransformerEngine getting-started tutorial. It demonstrates how a standard Transformer layer looks in pure PyTorch before any TE optimizations are applied.

Signature

PyTorchMLP:

import torch

class PyTorchMLP(torch.nn.Module):
    """Feed-forward network in Transformer layer.
    Built with plain PyTorch modules.
    """

    hidden_size: int
    ffn_hidden_size: int

    def __init__(self, hidden_size: int, ffn_hidden_size: int) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.linear1 = torch.nn.Linear(hidden_size, ffn_hidden_size, bias=True)
        self.linear2 = torch.nn.Linear(ffn_hidden_size, hidden_size, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear1(x)
        x = torch.nn.functional.gelu(x, approximate="tanh")
        x = self.linear2(x)
        return x
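Passing approximate="tanh" selects PyTorch's tanh approximation of GELU instead of the exact erf-based form. The short check below writes out the standard tanh formula by hand (gelu_tanh is a hypothetical helper name) and compares it against F.gelu. Any implementation whose outputs are compared numerically against this baseline needs to use the same GELU variant.

```python
import math

import torch
import torch.nn.functional as F


def gelu_tanh(x: torch.Tensor) -> torch.Tensor:
    # Tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))))


x = torch.linspace(-5.0, 5.0, steps=101)
# Difference from PyTorch's tanh-approximate GELU (should be at float32 rounding level)
print((gelu_tanh(x) - F.gelu(x, approximate="tanh")).abs().max().item())
# Difference from the exact (erf-based) GELU, which the baseline does not use
print((gelu_tanh(x) - F.gelu(x)).abs().max().item())
```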

PyTorchTransformerLayer:

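import torch

# DotProductAttention is the tutorial's own attention helper (not part of torch.nn);
# it must be defined or imported before this class is used.
class PyTorchTransformerLayer(torch.nn.Module):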
    """Basic Transformer layer using plain PyTorch modules."""

    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        num_attention_heads: int,
        layernorm_eps: float = 1e-5,
        attention_dropout: float = 0.1,
        hidden_dropout: float = 0.1,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.kv_channels = hidden_size // num_attention_heads
        self.ln1 = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps)
        self.qkv_projection = torch.nn.Linear(hidden_size, 3 * hidden_size, bias=True)
        self.attention = DotProductAttention(
            num_attention_heads=num_attention_heads,
            kv_channels=self.kv_channels,
            attention_dropout=attention_dropout,
        )
        self.projection = torch.nn.Linear(hidden_size, hidden_size, bias=True)
        self.dropout = torch.nn.Dropout(hidden_dropout)
        self.ln2 = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps)
        self.mlp = PyTorchMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size)

I/O

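  • Input: x: torch.Tensor of shape [seq_length, batch_size, hidden_size] (sequence-first layout), plus an optional attention_mask: torch.Tensor (default None) that is passed unchanged to DotProductAttention.
  • Output: torch.Tensor of the same shape, [seq_length, batch_size, hidden_size].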
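This page does not specify the layout or polarity of attention_mask. As a sketch, assuming a boolean mask in which True marks key positions a query must not attend to, broadcastable to attention scores of shape [batch, heads, seq_q, seq_k], a causal mask could be built as follows (causal_mask is a hypothetical helper). Conventions differ between libraries: for example, a boolean attn_mask passed to torch.nn.functional.scaled_dot_product_attention uses the opposite polarity (True means "attend"). Check which convention the helper expects before reusing this.

```python
from typing import Optional

import torch


def causal_mask(seq_length: int, device: Optional[torch.device] = None) -> torch.Tensor:
    # Strictly upper-triangular True entries: query i may not see key j when j > i.
    mask = torch.triu(
        torch.ones(seq_length, seq_length, dtype=torch.bool, device=device), diagonal=1
    )
    # [1, 1, sq, sk] broadcasts over the batch and head dimensions of the scores
    return mask.view(1, 1, seq_length, seq_length)
```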

Key Parameters

PyTorchTransformerLayer Parameters
Parameter | Type | Default | Description
hidden_size | int | (required) | Size of hidden representations (model dimension).
ffn_hidden_size | int | (required) | Intermediate size of the MLP.
num_attention_heads | int | (required) | Number of attention heads; must divide hidden_size evenly.
layernorm_eps | float | 1e-5 | Epsilon for both LayerNorm modules.
attention_dropout | float | 0.1 | Dropout probability inside attention.
hidden_dropout | float | 0.1 | Dropout probability after the attention output projection.

PyTorchMLP Parameters
Parameter | Type | Default | Description
hidden_size | int | (required) | Input and output size (model dimension).
ffn_hidden_size | int | (required) | Intermediate size (typically 4x hidden_size).

Forward Pass Details

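The PyTorchTransformerLayer.forward method implements a standard pre-norm Transformer layer: LayerNorm is applied to the input of each sublayer (attention first, then the MLP), and each sublayer's output is added back to its unnormalized input through a residual connection: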

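from typing import Optional

import torch

def forward(
    self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    # First residual connection: pre-norm attention block
    res = x
    x = self.ln1(x)

    # Fused QKV projection: [s, b, h] -> [s, b, 3h] -> [s, b, heads, 3 * kv_channels]
    qkv = self.qkv_projection(x)
    qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)
    # Split the last dim into q, k, v, each [s, b, heads, kv_channels]
    q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)

    x = self.attention(q, k, v, attention_mask)  # [s, b, h]
    x = self.projection(x)
    x = self.dropout(x)
    x = res + x

    # Second residual connection: pre-norm MLP block (no dropout on the MLP output)
    res = x
    x = self.ln2(x)
    x = self.mlp(x)

    return x + res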
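The view-then-split in the forward pass implies a per-head interleaved layout for the fused QKV output: for head i, features [3·d·i, 3·d·(i+1)) hold that head's query, key and value in order, where d is kv_channels. Splitting the 3 * hidden_size features into thirds first would instead give a blocked q/k/v layout. Both work when training from scratch, because the weights are learned, but the layout matters when loading or comparing weights across implementations. The toy example below, with arbitrary sizes, tracks feature indices through both paths.

```python
import torch

s, b, n, d = 2, 1, 2, 3        # seq, batch, heads, kv_channels (toy sizes)
h = n * d                       # hidden_size
# Each feature's value equals its index, so we can see where it ends up
qkv = torch.arange(3 * h, dtype=torch.float32).repeat(s, b, 1)  # [s, b, 3h]

# Baseline layout: view per head, then split the last dim into thirds
per_head = qkv.view(s, b, n, 3 * d)
q, k, v = torch.split(per_head, per_head.size(3) // 3, dim=3)

# Alternative layout: split the 3h features into thirds, then view per head
q_alt, k_alt, v_alt = (t.reshape(s, b, n, d) for t in torch.chunk(qkv, 3, dim=-1))

print(q[0, 0].tolist())      # [[0.0, 1.0, 2.0], [9.0, 10.0, 11.0]]
print(q_alt[0, 0].tolist())  # [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
```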

Example

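Instantiating the baseline and running a forward pass (the sizes below are illustrative):

import torch

# Illustrative sizes; substitute your own model configuration
hidden_size = 4096
ffn_hidden_size = 16384
num_attention_heads = 32
sequence_length = 2048
batch_size = 4
dtype = torch.float16

baseline = (
    PyTorchTransformerLayer(
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        num_attention_heads=num_attention_heads,
    )
    .to(dtype=dtype)
    .cuda()
)

# Forward pass on a [sequence_length, batch_size, hidden_size] input
x = torch.rand(sequence_length, batch_size, hidden_size, device="cuda", dtype=dtype)
output = baseline(x)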
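This page does not show how the baseline is benchmarked. A minimal stand-in, assuming you want the average wall-clock time of a forward plus backward step, is sketched below. The helper time_step and its warmup/iters defaults are hypothetical. CUDA kernel launches are asynchronous, so the sketch synchronizes before reading the clock.

```python
import time

import torch


def time_step(module: torch.nn.Module, x: torch.Tensor, dy: torch.Tensor,
              warmup: int = 3, iters: int = 10) -> float:
    """Average milliseconds per forward+backward step (hypothetical helper)."""
    def step() -> None:
        out = module(x)
        out.backward(dy)  # gradients accumulate across steps; fine for timing

    for _ in range(warmup):       # warm up kernels and the memory allocator
        step()
    if x.is_cuda:
        torch.cuda.synchronize()  # wait for queued GPU work before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        step()
    if x.is_cuda:
        torch.cuda.synchronize()  # make sure all timed work has finished
    return (time.perf_counter() - start) * 1000.0 / iters
```

With the example above, a call might look like time_step(baseline, x, torch.rand_like(output)).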

Internal Module Composition

Modules within PyTorchTransformerLayer
Attribute | Module type | Purpose
self.ln1 | torch.nn.LayerNorm | Pre-attention normalization
self.qkv_projection | torch.nn.Linear | Fused projection of the input to Q, K, V (output size 3 * hidden_size)
self.attention | DotProductAttention | Scaled dot-product attention with dropout
self.projection | torch.nn.Linear | Output projection after attention
self.dropout | torch.nn.Dropout | Dropout after the output projection
self.ln2 | torch.nn.LayerNorm | Pre-MLP normalization
self.mlp | PyTorchMLP | Feed-forward network (FC1, GELU, FC2)
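From the module list above, the layer's learnable parameter count can be worked out in closed form. This assumes that DotProductAttention holds no learnable parameters (only dropout) and that every Linear has a bias, as in the code above. With h = hidden_size and f = ffn_hidden_size, the total is 4h² + 2hf + 9h + f. The sketch below (baseline_param_count is a hypothetical helper) spells out each term.

```python
import torch


def baseline_param_count(hidden_size: int, ffn_hidden_size: int) -> int:
    h, f = hidden_size, ffn_hidden_size
    layernorms = 2 * (2 * h)            # ln1 and ln2: weight + bias each
    qkv = h * (3 * h) + 3 * h           # qkv_projection
    proj = h * h + h                     # projection
    mlp = (h * f + f) + (f * h + h)      # mlp.linear1 + mlp.linear2
    return layernorms + qkv + proj + mlp  # = 4h^2 + 2hf + 9h + f


# Cross-check against the same torch modules for small sizes
h, f = 64, 256
modules = torch.nn.ModuleList([
    torch.nn.LayerNorm(h), torch.nn.Linear(h, 3 * h), torch.nn.Linear(h, h),
    torch.nn.LayerNorm(h), torch.nn.Linear(h, f), torch.nn.Linear(f, h),
])
print(sum(p.numel() for p in modules.parameters()), baseline_param_count(h, f))
```

For example, hidden_size=4096 with ffn_hidden_size=16384 gives 201,379,840 parameters per layer, about two thirds of them in the MLP.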
