Implementation:NVIDIA TransformerEngine Ops Fused Backward Linear Add
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Optimization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Fused backward-pass operation that combines the dgrad GEMM of a linear layer with residual gradient accumulation. The GEMM writes its output directly into the residual gradient buffer instead of producing a separate result that must be added afterwards.
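The arithmetic being fused can be illustrated in plain PyTorch. This is a minimal sketch under stated assumptions, not the TransformerEngine implementation. `dy`, `w`, and `grad_residual` are hypothetical stand-ins for the upstream gradient, the linear weight, and the gradient already held by the residual branch. The unfused path computes the dgrad GEMM into a fresh tensor and then runs a separate add. The fused path uses `Tensor.addmm_` so that the matrix product is accumulated into the existing buffer.

```python
import torch

torch.manual_seed(0)
batch, in_features, out_features = 4, 8, 16

# Hypothetical stand-ins for tensors seen in the backward pass
dy = torch.randn(batch, out_features)            # grad w.r.t. the linear output
w = torch.randn(out_features, in_features)       # linear weight, forward is y = x @ w.T
grad_residual = torch.randn(batch, in_features)  # grad arriving via the residual branch

# Unfused: dgrad GEMM into a new tensor, then a separate addition kernel
dx = dy @ w
grad_input_unfused = grad_residual + dx

# Fused: accumulate into the residual gradient buffer in place,
# i.e. buf = buf + dy @ w, with no intermediate dx tensor
grad_input_fused = grad_residual.clone()  # clone only so the unfused inputs stay intact
grad_input_fused.addmm_(dy, w)
```

Both paths give the same result. The saving comes from skipping the intermediate `dx` allocation and the extra elementwise kernel. On GPU, an accumulating GEMM of this kind is typically expressed as a single GEMM call with beta = 1.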
Description
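BackwardLinearAdd is a FusedOperation that fuses the backward passes of MakeExtraOutput (which produces the residual branch) and BasicLinear. It accumulates the dgrad GEMM output directly into the residual gradient buffer, which eliminates a separate addition kernel. The fusion requires an in-place MakeExtraOutput and a BasicLinear that does not use column tensor parallelism. With column tensor parallelism, each rank's dgrad is only a partial sum that must be reduced across ranks after the GEMM. Accumulating it into the residual gradient before that reduction would add the residual gradient once per rank. The fused operation also supports Megatron-LM wgrad fusion via accumulate_into_main_grad.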
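The column tensor parallelism restriction can be checked numerically. The sketch below simulates two ranks in one process, with no real communication. `world_size` is an assumption for illustration, and a Python `sum` stands in for the cross-rank all-reduce. Each simulated rank owns a slice of the weight's output rows. That means its dgrad is a partial sum, and folding the residual gradient in before the reduction over-counts it.

```python
import torch

torch.manual_seed(0)
world_size = 2  # hypothetical number of tensor-parallel ranks
batch, in_features, out_features = 4, 8, 16

grad_residual = torch.randn(batch, in_features)  # replicated on every rank
dy = torch.randn(batch, out_features)            # grad w.r.t. the full linear output
w = torch.randn(out_features, in_features)       # full weight, y = x @ w.T

# Reference: single-device dgrad plus residual gradient
reference = grad_residual + dy @ w

# Column parallelism: each rank owns a slice of the output features (rows of w)
dy_shards = dy.chunk(world_size, dim=1)
w_shards = w.chunk(world_size, dim=0)
partial_dgrads = [d @ ws for d, ws in zip(dy_shards, w_shards)]

# Correct order: reduce the partial dgrads across ranks, then add the residual
correct = sum(partial_dgrads) + grad_residual

# Fusing too early: each rank accumulates into its own residual buffer and the
# reduction then sums those buffers, counting grad_residual world_size times
too_early = sum(grad_residual + p for p in partial_dgrads)
```

`too_early` differs from `reference` by `(world_size - 1) * grad_residual`. Row-parallel and non-parallel linears do not have this problem because their dgrad is already complete on each rank.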
Usage
Applied automatically by the operation fuser when it matches this pattern during the backward pass. Column tensor parallelism is not supported.
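The eligibility rules can be sketched as a simple scan over a list of ops. This is a simplified illustration, not TransformerEngine's `fuse_backward_ops`. The dataclasses below are hypothetical stand-ins for the real `MakeExtraOutput`, `BasicLinear`, and `BackwardLinearAdd`. The fields `in_place` and `tensor_parallel_mode` are assumptions modelled on the conditions stated above. The real fuser also tracks op indices and may traverse ops in a different order.

```python
from dataclasses import dataclass
from typing import Optional, Union


# Hypothetical stand-ins for the TransformerEngine ops (not the real classes)
@dataclass
class MakeExtraOutput:
    in_place: bool = False


@dataclass
class BasicLinear:
    tensor_parallel_mode: Optional[str] = None  # e.g. None, "row", "column"


@dataclass
class BackwardLinearAdd:
    backward_add: MakeExtraOutput
    linear: BasicLinear


Op = Union[MakeExtraOutput, BasicLinear, BackwardLinearAdd]


def fuse_backward_linear_add(ops: list[Op]) -> list[Op]:
    """Replace each eligible [MakeExtraOutput, BasicLinear] pair with one fused op."""
    out: list[Op] = []
    i = 0
    while i < len(ops):
        op = ops[i]
        nxt = ops[i + 1] if i + 1 < len(ops) else None
        eligible = (
            isinstance(op, MakeExtraOutput)
            and op.in_place                           # must be in-place
            and isinstance(nxt, BasicLinear)
            and nxt.tensor_parallel_mode != "column"  # column TP is excluded
        )
        if eligible:
            out.append(BackwardLinearAdd(backward_add=op, linear=nxt))
            i += 2  # both ops are consumed by the fused op
        else:
            out.append(op)
            i += 1
    return out


ops = [
    MakeExtraOutput(in_place=True), BasicLinear(),          # fused
    MakeExtraOutput(in_place=False), BasicLinear(),         # not in-place: kept
    MakeExtraOutput(in_place=True), BasicLinear("column"),  # column TP: kept
]
fused = fuse_backward_linear_add(ops)
```

Only the first pair is replaced. The other two pairs each violate one of the conditions and are passed through unchanged.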
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/ops/fused/backward_linear_add.py
- Lines: 1-168
Signature
class BackwardLinearAdd(FusedOperation):
    def __init__(self, *, backward_add: MakeExtraOutput, linear: BasicLinear): ...
    def fuser_backward(self, basic_op_ctxs, grad_output, *, basic_op_grad_extra_outputs) -> Tuple: ...
    @staticmethod
    def fuse_backward_ops(ops, **unused) -> list[FusibleOperation]: ...
Import
from transformer_engine.pytorch.ops.fused.backward_linear_add import BackwardLinearAdd
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| backward_add | MakeExtraOutput | Yes | In-place MakeExtraOutput operation |
| linear | BasicLinear | Yes | BasicLinear operation (not column TP) |
| grad_output | torch.Tensor | Yes | Upstream gradient (with respect to the linear layer's output) |
Outputs
| Name | Type | Description |
|---|---|---|
| grad_input | torch.Tensor | Input gradient: the dgrad GEMM output accumulated into the residual gradient buffer |
| grad_weight | torch.Tensor | Weight gradient (accumulated into main_grad when accumulate_into_main_grad is enabled) |
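The main_grad path for grad_weight can be sketched as follows. This assumes the Megatron-LM convention of a persistent FP32 `main_grad` buffer attached to the weight parameter. The helper `wgrad` is hypothetical and only mirrors the behaviour described in the table. It returns `None` when the gradient has been accumulated into the buffer. This sketch does not show how the real op reports that case.

```python
import torch

torch.manual_seed(0)
batch, in_features, out_features = 4, 8, 16

x = torch.randn(batch, in_features)    # input saved from the forward pass
dy = torch.randn(batch, out_features)  # grad w.r.t. the linear output
weight = torch.nn.Parameter(torch.randn(out_features, in_features))

# Assumed Megatron-LM convention: an FP32 gradient buffer attached to the parameter
weight.main_grad = torch.zeros(out_features, in_features, dtype=torch.float32)


def wgrad(x: torch.Tensor, dy: torch.Tensor, weight: torch.nn.Parameter,
          accumulate_into_main_grad: bool):
    """Hypothetical helper: weight gradient of y = x @ weight.T."""
    if accumulate_into_main_grad:
        # Accumulate dy^T @ x directly into the persistent buffer (no new tensor)
        weight.main_grad.addmm_(dy.t(), x)
        return None  # in this sketch, the gradient now lives only in main_grad
    # Otherwise allocate and return a fresh weight gradient
    return dy.t() @ x


# Two micro-batches accumulate into the same buffer (gradient accumulation)
for _ in range(2):
    assert wgrad(x, dy, weight, accumulate_into_main_grad=True) is None
```

Accumulating in place lets gradient accumulation across micro-batches reuse one buffer instead of allocating and summing a new weight gradient each step.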
Usage Examples
# Fused automatically by the operation fuser when it detects the pattern
# [MakeExtraOutput(in_place=True), BasicLinear] in the backward pass.
# No manual usage is required.