Implementation:NVIDIA TransformerEngine Initialize UB

From Leeroopedia


Overview

Concrete tool provided by TransformerEngine for initializing Userbuffers-based communication-GEMM overlap.

Description

initialize_ub(shape, tp_size, ...) allocates shared-memory buffers used to overlap tensor-parallel communication with GEMM operations. It creates CommOverlap and CommOverlapP2P objects and stores them in a global registry, from which TE's Linear and LayerNormLinear modules retrieve them during forward and backward passes.

The function performs the following steps:

  • Allocates IPC-shared memory buffers of the specified shape for use by both NCCL and cuBLAS.
  • Creates CommOverlap objects for collective operations (all-gather, reduce-scatter) and CommOverlapP2P objects for point-to-point ring-based communication.
  • Registers these objects in the global _ub_communicators dictionary, keyed by operation name (e.g., "qkv_fprop", "proj_dgrad").
  • Configures quantization modes for FP8 or BF16 communication as specified.

Once initialized, TE's linear modules automatically detect and use these communicators during forward and backward passes, transparently overlapping communication with GEMM computation.
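The registry mechanism above can be sketched in plain Python. All names in this sketch (_FakeCommOverlap, initialize_ub_sketch, get_ub_sketch) are illustrative stand-ins, not TransformerEngine's actual API; TE's real CommOverlap objects wrap CUDA shared-memory buffers rather than plain Python attributes.

```python
# Minimal sketch of the global-registry pattern described above.
# These names are hypothetical stand-ins, not TransformerEngine's API.

_ub_communicators = {}


class _FakeCommOverlap:
    """Stand-in for a CommOverlap object holding one shared buffer."""

    def __init__(self, name, shape):
        self.name = name
        self.shape = shape


def initialize_ub_sketch(shape, op_names):
    # One communicator per operation name, created once at startup.
    for name in op_names:
        _ub_communicators[name] = _FakeCommOverlap(name, shape)


def get_ub_sketch(name):
    # Linear/LayerNormLinear modules would look up their communicator
    # by operation name during forward/backward.
    if name not in _ub_communicators:
        raise KeyError(f"userbuffer communicator {name!r} not initialized")
    return _ub_communicators[name]


initialize_ub_sketch([2048 * 4, 4096], ["qkv_fprop", "proj_dgrad"])
print(get_ub_sketch("qkv_fprop").shape)  # [8192, 4096]
```

Because the registry is module-global, initialization must happen exactly once, before the first forward pass of any layer that uses the overlap.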

Source

transformer_engine/pytorch/module/base.py, function initialize_ub at L97-446

Import

from transformer_engine.pytorch.module.base import initialize_ub

Signature

def initialize_ub(
    shape: list,
    tp_size: int,
    use_fp8: bool = False,
    quantization_modes: List[UserBufferQuantizationMode] = None,
    dtype: torch.dtype = torch.bfloat16,
    ub_cfgs: Optional[Union[dict, List[dict]]] = None,
    bootstrap_backend: Union[str, torch.distributed.Backend] = None,
) -> None:

I/O

Direction Description
Input shape (list): Buffer shape, typically [seq_len * batch_size, hidden_size]. tp_size (int): Tensor-parallel group size. Additional configuration for quantization and communication backend.
Output None. Populates the global _ub_communicators dictionary with CommOverlap and CommOverlapP2P objects.

Key Parameters

Parameter Type Default Description
shape list required Shape of the communication buffers, typically [seq_len * batch_size, hidden_size]. Must match the activation tensor dimensions used in training.
tp_size int required Size of the tensor-parallel group. Determines how buffers are partitioned for scatter/gather operations.
use_fp8 bool False Whether to enable FP8 quantization for communication buffers.
quantization_modes List[UserBufferQuantizationMode] None Specifies which operations use FP8 vs. BF16 communication. Provides fine-grained control over per-operation quantization.
dtype torch.dtype torch.bfloat16 Data type for communication buffers when not using FP8.
ub_cfgs Optional[Union[dict, List[dict]]] None Advanced configuration for userbuffer communicators, allowing per-operation tuning.
bootstrap_backend Union[str, torch.distributed.Backend] None Backend used for bootstrapping the userbuffer communicators (e.g., "nccl" or "mpi").
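To illustrate the shape parameter (a sketch with a hypothetical helper, not a TE function): the buffer must cover the flattened activation tensor, so its first dimension is the product of sequence length and micro-batch size.

```python
def ub_buffer_shape(seq_len, batch_size, hidden_size):
    # Hypothetical helper: the userbuffer covers the flattened activation,
    # so its first dimension is seq_len * batch_size.
    return [seq_len * batch_size, hidden_size]


# seq_len=2048, batch_size=4, hidden_size=4096
print(ub_buffer_shape(2048, 4, 4096))  # [8192, 4096]
```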

Example Usage

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.module.base import initialize_ub

# Initialize userbuffers for comm-GEMM overlap
# Buffer shape matches activation dimensions: [seq_len * batch_size, hidden_size]
initialize_ub(
    shape=[2048 * 4, 4096],  # seq_len=2048, batch_size=4, hidden_size=4096
    tp_size=8,
    use_fp8=True,
    dtype=torch.bfloat16,
    bootstrap_backend="nccl",
)

# After initialization, TE Linear modules automatically use
# the userbuffer communicators for overlapped communication.
# tp_group is the tensor-parallel process group, created beforehand
# via torch.distributed (e.g., torch.distributed.new_group(...)).
layer = te.TransformerLayer(
    hidden_size=4096,
    ffn_hidden_size=11008,
    num_attention_heads=32,
    tp_group=tp_group,
    tp_size=8,
    ub_overlap_ag=True,   # Enable all-gather overlap
    ub_overlap_rs=True,   # Enable reduce-scatter overlap
)
