Implementation:NVIDIA TransformerEngine Ops Fused Forward Linear Bias Activation
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Optimization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Fused forward-pass operation that combines a GEMM and bias addition (with activation fusion planned) into a single cuBLAS call. It applies only when the output is FP16 or BF16.
Description
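ForwardLinearBiasActivation is a FusedOperation that fuses the forward passes of BasicLinear, Bias, and (in the future) an activation into a single cuBLAS GEMM+bias call. Instead of running Bias as a separate step, the bias is passed directly to BasicLinear._functional_forward, so the addition happens inside the GEMM call. Activation fusion is planned but not yet implemented. Fusion requires that:
- the linear operation does not use row tensor parallelism. With row tensor parallelism, each rank's GEMM output is a partial sum that must be reduced across ranks (communication after the GEMM) before the bias can be added.
- the weight dtype is FP16 or BF16, because cuBLAS supports the fused GEMM+bias path only for those output types.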
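The arithmetic being fused is the ordinary linear layer, y = x Wᵀ + b. The pure-Python sketch below is not TransformerEngine code, and its function names are hypothetical. It contrasts an unfused GEMM followed by a separate bias pass with a version that adds the bias as each output element is produced. That second form is roughly what a fused bias epilogue buys: no second pass over the output and no separate kernel. The sketch assumes PyTorch's weight layout of (out_features, in_features).

```python
from typing import List

Matrix = List[List[float]]


def gemm(x: Matrix, w: Matrix) -> Matrix:
    """Compute x @ w.T, with w stored as (out_features, in_features)."""
    return [
        [sum(xi * wi for xi, wi in zip(row, w_row)) for w_row in w]
        for row in x
    ]


def linear_then_bias(x: Matrix, w: Matrix, b: List[float]) -> Matrix:
    """Unfused: materialize the GEMM result, then make a second pass to add bias."""
    y = gemm(x, w)
    return [[yij + bj for yij, bj in zip(row, b)] for row in y]


def linear_bias_fused(x: Matrix, w: Matrix, b: List[float]) -> Matrix:
    """Fused (epilogue-style): add the bias while each output element is produced."""
    return [
        [sum(xi * wi for xi, wi in zip(row, w_row)) + bj for w_row, bj in zip(w, b)]
        for row in x
    ]


x = [[1.0, 2.0], [3.0, 4.0]]              # 2 tokens, in_features = 2
w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # out_features = 3
b = [0.5, -0.5, 1.0]
print(linear_bias_fused(x, w, b))  # [[1.5, 1.5, 4.0], [3.5, 3.5, 8.0]]
```

Both functions return the same values. The fused form differs only in when the bias is applied, which is why the fuser can substitute it without changing results.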
Usage
Applied automatically by the operation fuser when it finds a BasicLinear immediately followed by a Bias in the forward pass, provided the linear operation meets the tensor-parallelism and dtype requirements described above.
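A rough sketch of this pattern match follows, in plain Python. It uses hypothetical stand-in classes (`LinearStub`, `BiasStub`, `FusedStub`) in place of the real `BasicLinear`, `Bias`, and `ForwardLinearBiasActivation`, and represents dtypes as strings. It scans the op list left to right and replaces each eligible adjacent linear/bias pair with one fused op. It mirrors only the eligibility rules stated in the description, not the library's actual `fuse_forward_ops`.

```python
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class LinearStub:
    """Hypothetical stand-in for BasicLinear."""
    tensor_parallel_mode: Optional[str]  # None, "column", or "row"
    weight_dtype: str                    # e.g. "float16", "bfloat16", "float32"


@dataclass
class BiasStub:
    """Hypothetical stand-in for Bias."""


@dataclass
class FusedStub:
    """Hypothetical stand-in for ForwardLinearBiasActivation."""
    linear: LinearStub
    bias: Optional[BiasStub]


Op = Union[LinearStub, BiasStub, FusedStub]
FUSIBLE_DTYPES = {"float16", "bfloat16"}


def can_fuse(linear: LinearStub) -> bool:
    # Row TP needs a cross-rank reduction after the GEMM, so the bias
    # cannot be applied inside the GEMM call.
    if linear.tensor_parallel_mode == "row":
        return False
    # The fused GEMM+bias path is only used for FP16/BF16.
    return linear.weight_dtype in FUSIBLE_DTYPES


def fuse_forward_ops(ops: List[Op]) -> List[Op]:
    """Replace each eligible [linear, bias] pair with a single fused op."""
    out: List[Op] = []
    i = 0
    while i < len(ops):
        op = ops[i]
        nxt = ops[i + 1] if i + 1 < len(ops) else None
        if isinstance(op, LinearStub) and isinstance(nxt, BiasStub) and can_fuse(op):
            out.append(FusedStub(linear=op, bias=nxt))
            i += 2  # both ops are consumed by the fused op
        else:
            out.append(op)
            i += 1
    return out


ops = [LinearStub(None, "bfloat16"), BiasStub(), LinearStub("row", "bfloat16"), BiasStub()]
print([type(op).__name__ for op in fuse_forward_ops(ops)])
# ['FusedStub', 'LinearStub', 'BiasStub']
```

The second pair is left unfused because its linear uses row tensor parallelism. An FP32 weight would be skipped the same way.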
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py
- Lines: 1-196
Signature
class ForwardLinearBiasActivation(FusedOperation):
    def __init__(self, *, linear: BasicLinear, bias: Optional[Bias], activation: None): ...

    def fuser_forward(
        self,
        basic_op_ctxs,
        input_,
        *,
        basic_op_extra_inputs,
        prev_op_grad_output_quantizer,
        next_op_input_quantizer,
        basic_op_kwargs,
    ) -> Tuple: ...

    @staticmethod
    def fuse_forward_ops(ops, **unused) -> list[FusibleOperation]: ...
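The constructor takes an optional bias and an activation that must currently be None, so the fused op wraps either one or two basic operations. The sketch below shows one way such an op could track its optional parts and route the bias into a single call. `LinearBiasFusionSketch` and the `gemm_bias` callable are hypothetical. The real `fuser_forward` also takes the context and quantizer arguments shown in the signature.

```python
from typing import Any, Callable, Dict, List, Optional


class LinearBiasFusionSketch:
    """Hypothetical fused op wrapping a linear, an optional bias, and an optional activation."""

    def __init__(self, *, linear: Any, bias: Optional[Any], activation: None) -> None:
        # Track where each role sits in the list of wrapped basic ops.
        self.basic_ops: List[Any] = [linear]
        self.op_idxs: Dict[str, Optional[int]] = {"linear": 0, "bias": None, "activation": None}
        if bias is not None:
            self.op_idxs["bias"] = len(self.basic_ops)
            self.basic_ops.append(bias)
        if activation is not None:
            self.op_idxs["activation"] = len(self.basic_ops)
            self.basic_ops.append(activation)

    def forward(self, x: Any, gemm_bias: Callable[[Any, Any, Optional[Any]], Any]) -> Any:
        # Activation fusion is documented as planned but not implemented.
        if self.op_idxs["activation"] is not None:
            raise NotImplementedError("Activations are not yet supported")
        idx = self.op_idxs["bias"]
        bias = None if idx is None else self.basic_ops[idx]
        # A single call performs the GEMM and applies the bias, if any.
        return gemm_bias(x, self.basic_ops[0], bias)


def toy_gemm_bias(x: float, w: float, b: Optional[float]) -> float:
    """Toy scalar stand-in for a GEMM with a bias epilogue: y = x * w (+ b)."""
    return x * w + (0.0 if b is None else b)


op = LinearBiasFusionSketch(linear=3.0, bias=0.5, activation=None)
print(op.forward(2.0, toy_gemm_bias))  # 6.5
```

Passing `bias=None` yields a fused op that wraps only the linear operation, which matches the `Optional[Bias]` annotation in the signature.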
Import
from transformer_engine.pytorch.ops.fused.forward_linear_bias_activation import ForwardLinearBiasActivation
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
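| linear | BasicLinear | Yes | Linear operation; must not use row tensor parallelism, and its weight must be FP16 or BF16 |
| bias | Optional[Bias] | Yes (may be None) | Bias operation to fuse into the GEMM, or None |
| activation | None | Yes (must be None) | Placeholder for a future activation; activation fusion is not yet supported |
| input_ | torch.Tensor | Yes | Input tensor passed to fuser_forward |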
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Result of GEMM + bias (+ activation when supported) |
Usage Examples
# Fused automatically by the operation fuser when it detects the pattern
# [BasicLinear, Bias] in the forward pass with an FP16/BF16 weight
# No manual usage required