Implementation:NVIDIA TransformerEngine Ops UB Backward Linear
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Distributed, Optimization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Fused backward linear implementation that uses NVIDIA Userbuffers to overlap tensor-parallel communication with GEMM computation during the backward pass.
Description
UserbuffersBackwardLinear composes a BasicLinear operation with an optional Bias and an optional ReduceScatter into a single fused operation. Its _functional_backward method uses Userbuffers communicators, keyed by layer type (for example "qkv" or "proj"), to overlap an all-gather or reduce-scatter with the dgrad and wgrad GEMMs; the kind of overlap is selected through CommOverlapType. The method handles the full backward pass, including gradient quantization, Megatron-LM wgrad accumulation, and bias gradient computation. The fuse_backward_ops method scans the operation list for ReduceScatter + Bias + BasicLinear or Bias + BasicLinear patterns in which Userbuffers are configured and replaces each match with the fused operation.
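The pattern scan described above can be sketched in plain Python. This is a minimal illustration, not the TransformerEngine implementation: the classes below are hypothetical stand-ins that only borrow the names of the real operations, the `userbuffers_configured` flag is an assumed placeholder for however the real code detects Userbuffers, and the sketch assumes the bias and reduce-scatter are each optional, as the constructor signature suggests.

```python
from dataclasses import dataclass
from typing import Optional


# Hypothetical stand-ins: only the class names mirror the real operations.
@dataclass
class BasicLinear:
    name: str
    userbuffers_configured: bool = False  # assumed flag, not the real attribute


@dataclass
class Bias:
    name: str


@dataclass
class ReduceScatter:
    name: str


@dataclass
class FusedUBBackwardLinear:
    linear: BasicLinear
    bias: Optional[Bias] = None
    reduce_scatter: Optional[ReduceScatter] = None


def fuse_backward_ops(ops: list) -> list:
    """Replace each linear [+ bias] [+ reduce-scatter] run with one fused op."""
    fused = []
    i = 0
    while i < len(ops):
        op = ops[i]
        # Only linears with Userbuffers configured are candidates for fusion.
        if not (isinstance(op, BasicLinear) and op.userbuffers_configured):
            fused.append(op)
            i += 1
            continue
        i += 1
        bias = None
        reduce_scatter = None
        # Greedily absorb an optional bias, then an optional reduce-scatter.
        if i < len(ops) and isinstance(ops[i], Bias):
            bias = ops[i]
            i += 1
        if i < len(ops) and isinstance(ops[i], ReduceScatter):
            reduce_scatter = ops[i]
            i += 1
        fused.append(FusedUBBackwardLinear(op, bias, reduce_scatter))
    return fused
```

Operations that do not match the pattern pass through unchanged, so the scan can be applied to any operation list without affecting layers that have no Userbuffers configuration.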
Usage
Use this operation to overlap communication with computation in the backward pass of distributed training, so that tensor-parallel communication latency is hidden behind GEMM computation. It pairs with UserbuffersForwardLinear.
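Overlap of this kind is possible because a GEMM can be split by rows and each chunk reduced as soon as it is ready. The sketch below is a sequential, single-process simulation under that assumption: ranks are modeled as list entries, no real communication or concurrency takes place, and the function names are illustrative rather than taken from TransformerEngine. It only shows that the chunked schedule produces the same shards as a full GEMM followed by a reduce-scatter.

```python
def matmul(a, b):
    """Plain-Python matrix product of two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def add(a, b):
    """Elementwise sum of two equally shaped matrices."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def reduce_scatter_after_gemm(grad_outputs, weights, world_size):
    """Baseline: each rank finishes its whole GEMM, then results are reduce-scattered."""
    partials = [matmul(dy, w) for dy, w in zip(grad_outputs, weights)]
    total = partials[0]
    for partial in partials[1:]:
        total = add(total, partial)
    rows = len(total) // world_size
    # Rank r keeps rows [r * rows, (r + 1) * rows) of the summed result.
    return [total[r * rows:(r + 1) * rows] for r in range(world_size)]


def reduce_scatter_chunked(grad_outputs, weights, world_size):
    """Chunked schedule: the GEMM is split by rows, one chunk per destination rank.

    In a real overlapped implementation, the reduction of chunk k would run
    while the GEMM for chunk k + 1 is still computing. Here both run in order.
    """
    rows = len(grad_outputs[0]) // world_size
    shards = []
    for k in range(world_size):
        chunk_sum = None
        for dy, w in zip(grad_outputs, weights):
            part = matmul(dy[k * rows:(k + 1) * rows], w)
            chunk_sum = part if chunk_sum is None else add(chunk_sum, part)
        shards.append(chunk_sum)
    return shards
```

Because each output row of a matrix product depends only on the matching input row, the two schedules agree exactly, which is what allows the communication for finished chunks to start before the remaining chunks are computed.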
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py
- Lines: 1-661
Signature
class UserbuffersBackwardLinear(FusedOperation):
    def __init__(self, *, linear, bias=None, reduce_scatter=None): ...

    @staticmethod
    def _functional_backward(
        grad_output, input, weight, ...
    ): ...

    @staticmethod
    def fuse_backward_ops(ops: list[tuple[FusibleOperation, ...]]) -> list: ...
Import
from transformer_engine.pytorch.ops.fused import UserbuffersBackwardLinear
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| grad_output | torch.Tensor | Yes | Gradient from the next layer |
| input | torch.Tensor | Yes | Saved input from forward pass |
| weight | torch.Tensor | Yes | Weight parameter of the linear layer |
| linear | BasicLinear | Yes | The basic linear operation to fuse |
| bias | Bias | No | Optional bias operation |
| reduce_scatter | ReduceScatter | No | Optional reduce-scatter for TP |
Outputs
| Name | Type | Description |
|---|---|---|
| grad_input | torch.Tensor | Gradient w.r.t. input (data gradient) |
| grad_weight | torch.Tensor | Gradient w.r.t. weight |
| grad_bias | torch.Tensor | Gradient w.r.t. bias (if applicable) |
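The three outputs follow the standard linear-layer gradient formulas. The sketch below works them out in plain Python, assuming the usual PyTorch convention y = x @ weight^T + bias with weight shaped (out_features, in_features). The `main_grad` argument is a hypothetical illustration of Megatron-LM style accumulation, where the weight gradient is added into a persistent buffer; it is not the real function's parameter, and communication and quantization are left out.

```python
def matmul(a, b):
    """Plain-Python matrix product of two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    """Transpose a matrix given as a list of rows."""
    return [list(col) for col in zip(*m)]


def linear_backward(grad_output, x, weight, main_grad=None):
    """Gradients of y = x @ weight^T + bias.

    Shapes: grad_output (N, out), x (N, in), weight (out, in).
    """
    # dgrad: (N, out) @ (out, in) -> (N, in)
    grad_input = matmul(grad_output, weight)
    # wgrad: (out, N) @ (N, in) -> (out, in)
    grad_weight = matmul(transpose(grad_output), x)
    # bias gradient: sum of grad_output over the batch dimension -> (out,)
    grad_bias = [sum(col) for col in zip(*grad_output)]
    if main_grad is not None:
        # Assumed accumulation mode: add the new wgrad into the existing buffer in place.
        for acc_row, new_row in zip(main_grad, grad_weight):
            for j, value in enumerate(new_row):
                acc_row[j] += value
        grad_weight = main_grad
    return grad_input, grad_weight, grad_bias
```

The dgrad and wgrad products are independent of each other, which is why one of them can run while the communication needed by the other is still in flight.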
Usage Examples
# UserbuffersBackwardLinear is not typically instantiated directly. The
# OperationFuser discovers it automatically when Userbuffers are configured
# for a linear operation.
from transformer_engine.pytorch.ops import Sequential
from transformer_engine.pytorch.ops.basic import BasicLinear, Bias, ReduceScatter

# `linear`, `bias`, and `reduce_scatter` are assumed to be already-constructed
# BasicLinear, Bias, and ReduceScatter instances; their setup is not shown here.
# The fuser automatically detects and applies UB backward fusion.
pipeline = Sequential(linear, bias, reduce_scatter)