Implementation:NVIDIA TransformerEngine Ops AddExtraInput
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Optimization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Fusible operation that adds an extra input tensor to the main input, supporting both in-place and out-of-place modes within the operation fuser framework.
Description
AddExtraInput is a BasicOperation that takes one extra tensor input through the operation fuser and returns the sum of the main input and the extra input. When in_place=True, the main input is added in place to the extra input tensor, so the extra input is overwritten with the result; this can enable fusion optimizations (e.g., ForwardLinearBiasAdd). This operation is the forward-pass counterpart to MakeExtraOutput, which provides similar functionality in the backward pass.
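The difference between the two modes can be illustrated with plain PyTorch. This is a minimal sketch of the semantics described above, not TransformerEngine's implementation; the helper name `add_extra_input_reference` is hypothetical.

```python
import torch


def add_extra_input_reference(input_, extra_input, in_place=False):
    """Hypothetical reference for the described semantics (not TransformerEngine code)."""
    if in_place:
        # Accumulate the main input into the extra input's storage and return it
        extra_input += input_
        return extra_input
    # Out-of-place: allocate a new tensor and leave both inputs untouched
    return extra_input + input_


x = torch.tensor([1.0, 2.0, 3.0])
r = torch.tensor([10.0, 20.0, 30.0])

# Out-of-place mode: r keeps its original values
out = add_extra_input_reference(x, r)

# In-place mode: the result shares storage with the extra input
r2 = r.clone()
out2 = add_extra_input_reference(x, r2, in_place=True)
```

In the in-place case the caller's extra tensor is mutated, so it should not be reused afterwards as if it still held its original values.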
Usage
Used within the operation fuser pipeline to implement residual (skip) connections. In-place mode is an advanced feature intended for enabling specific fused operation patterns.
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/ops/basic/add_extra_input.py
- Lines: 1-91
Signature
class AddExtraInput(BasicOperation):
    num_extra_inputs: int = 1
    def __init__(self, *, in_place: bool = False): ...
    def fuser_forward(self, basic_op_ctxs, input_, *, basic_op_extra_inputs, prev_op_grad_output_quantizer, next_op_input_quantizer, basic_op_kwargs) -> Tuple[torch.Tensor, Iterable]: ...
    def fuser_backward(self, basic_op_ctxs, grad_output, *, basic_op_grad_extra_outputs) -> Tuple[torch.Tensor, Iterable, Iterable]: ...
Import
from transformer_engine.pytorch.ops.basic.add_extra_input import AddExtraInput
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input_ | torch.Tensor | Yes | Main input tensor |
| extra_input | torch.Tensor | Yes | Extra tensor to add (provided through fuser) |
| in_place | bool | No | Constructor keyword argument (default False); whether to perform the addition in place on the extra input |
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Sum of main input and extra input |
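Because the output is a plain sum, the gradient with respect to each input equals the incoming output gradient. The sketch below checks this property with ordinary PyTorch autograd; it does not call TransformerEngine's fuser_backward, and the tensor shapes are arbitrary illustrative choices.

```python
import torch

# Illustrative shapes; any matching shapes behave the same way
x = torch.randn(4, 8, requires_grad=True)  # stands in for the main input
r = torch.randn(4, 8, requires_grad=True)  # stands in for the extra input

out = x + r

# Backpropagate an arbitrary upstream gradient through the sum
grad_output = torch.randn(4, 8)
out.backward(grad_output)

# Both inputs receive the upstream gradient unchanged
print(torch.equal(x.grad, grad_output), torch.equal(r.grad, grad_output))
```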
Usage Examples
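import torch
from transformer_engine.pytorch.ops import Linear, Sequential
from transformer_engine.pytorch.ops.basic.add_extra_input import AddExtraInput
# Illustrative shapes (assumed); TransformerEngine ops run on a CUDA device
linear_op = Linear(1024, 1024, device="cuda")
input_tensor = torch.randn(32, 1024, device="cuda")
residual_tensor = torch.randn(32, 1024, device="cuda")
# Residual connection with in-place addition
model = Sequential(
    linear_op,
    AddExtraInput(in_place=True),
)
# Extra inputs follow the main input; with in_place=True, residual_tensor is overwritten
output = model(input_tensor, residual_tensor)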