Implementation: NVIDIA TransformerEngine Initialize UB
Overview
Concrete tool provided by TransformerEngine for initializing Userbuffer-based communication-GEMM overlap.
Description
initialize_ub(shape, tp_size, ...) allocates shared memory buffers for overlapping NCCL communication with GEMM operations. It creates CommOverlap and CommOverlapP2P objects stored in a global registry, retrieved by TE's Linear/LayerNormLinear during forward/backward.
The function performs the following steps:
- Allocates IPC-shared memory buffers of the specified shape for use by both NCCL and cuBLAS.
- Creates CommOverlap objects for collective operations (all-gather, reduce-scatter) and CommOverlapP2P objects for point-to-point ring-based communication.
- Registers these objects in the global _ub_communicators dictionary, keyed by operation name (e.g., "qkv_fprop", "proj_dgrad").
- Configures quantization modes for FP8 or BF16 communication as specified.
Once initialized, TE's linear modules automatically detect and use these communicators during forward and backward passes, transparently overlapping communication with GEMM computation.
Source
transformer_engine/pytorch/module/base.py, function initialize_ub at L97-446
Import
from transformer_engine.pytorch.module.base import initialize_ub
Signature
def initialize_ub(
shape: list,
tp_size: int,
use_fp8: bool = False,
quantization_modes: List[UserBufferQuantizationMode] = None,
dtype: torch.dtype = torch.bfloat16,
ub_cfgs: Optional[Union[dict, List[dict]]] = None,
bootstrap_backend: Union[str, torch.distributed.Backend] = None,
) -> None:
I/O
| Direction | Description |
|---|---|
| Input | shape (list): Buffer shape, typically [seq_len * batch_size, hidden_size]. tp_size (int): Tensor-parallel group size. Additional configuration for quantization and communication backend. |
| Output | None. Populates the global _ub_communicators dictionary with CommOverlap and CommOverlapP2P objects. |
Key Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| shape | list | required | Shape of the communication buffers, typically [seq_len * batch_size, hidden_size]. Must match the activation tensor dimensions used in training. |
| tp_size | int | required | Size of the tensor-parallel group. Determines how buffers are partitioned for scatter/gather operations. |
| use_fp8 | bool | False | Whether to enable FP8 quantization for communication buffers. |
| quantization_modes | List[UserBufferQuantizationMode] | None | Specifies which operations use FP8 vs. BF16 communication. Provides fine-grained control over per-operation quantization. |
| dtype | torch.dtype | torch.bfloat16 | Data type for communication buffers when not using FP8. |
| ub_cfgs | Optional[Union[dict, List[dict]]] | None | Advanced configuration for userbuffer communicators, allowing per-operation tuning. |
| bootstrap_backend | Union[str, torch.distributed.Backend] | None | Backend used for bootstrapping the userbuffer communicators (e.g., "nccl" or "mpi"). |
Example Usage
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.module.base import initialize_ub
# Initialize userbuffers for comm-GEMM overlap
# Buffer shape matches activation dimensions: [seq_len * batch_size, hidden_size]
initialize_ub(
shape=[2048 * 4, 4096], # seq_len=2048, batch_size=4, hidden_size=4096
tp_size=8,
use_fp8=True,
dtype=torch.bfloat16,
bootstrap_backend="nccl",
)
# After initialization, TE Linear modules automatically use
# the userbuffer communicators for overlapped communication.
# tp_group below is the tensor-parallel process group, e.g. created
# with torch.distributed.new_group(...).
layer = te.TransformerLayer(
hidden_size=4096,
ffn_hidden_size=11008,
num_attention_heads=32,
tp_group=tp_group,
tp_size=8,
ub_overlap_ag=True, # Enable all-gather overlap
ub_overlap_rs=True, # Enable reduce-scatter overlap
)
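The buffer shape passed to initialize_ub must match the flattened activation dimensions. A small helper makes that relationship explicit (the helper name is ours, for illustration; it is not part of TE's API):

```python
def ub_buffer_shape(seq_len: int, batch_size: int, hidden_size: int) -> list:
    # Userbuffers are sized for the flattened activation tensor:
    # [seq_len * batch_size, hidden_size]
    return [seq_len * batch_size, hidden_size]

# Matches the shape=[2048 * 4, 4096] used in the example above.
shape = ub_buffer_shape(seq_len=2048, batch_size=4, hidden_size=4096)
```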