Implementation:NVIDIA TransformerEngine Ops Base
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Defines the base class hierarchy for all fusible operations (the abstract base class, basic operations, and fused operations), together with the operation context that carries state from the forward pass to the backward pass.
Description
The file defines the following classes:
- OperationContext: a dataclass that stores forward-pass state (saved tensors, gradient requirements) for use in the backward pass, exposed through a save_for_backward API.
- FusibleOperation: the abstract base class, inheriting from torch.nn.Module. It defines the fuser_forward/fuser_backward interface and the hooks for quantizer access (get_input_quantizer, get_grad_output_quantizer).
- BasicOperation: extends FusibleOperation with op_forward/op_backward methods, recipe state management, and parameter handling, plus a default fuser integration that delegates to op_forward/op_backward.
- FusedOperation: combines multiple BasicOperations into one fusible unit, exposing their parameters and the constituent operations through the basic_ops list.
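To make the layering and the delegation concrete, here is a minimal pure-Python sketch of the same shape. It is a simplified model under stated assumptions, not TransformerEngine's code: the names Ctx, ToyFusibleOp, ToyBasicOp, ToyFusedOp, Scale and Square are hypothetical, Python floats stand in for tensors, contexts are passed as a list with one entry per basic operation, and the toy fused op simply chains its basic ops instead of running a fused kernel.

```python
import abc
from dataclasses import dataclass


@dataclass
class Ctx:
    """Hypothetical stand-in for OperationContext."""
    saved_tensors: tuple = ()

    def save_for_backward(self, *tensors):
        self.saved_tensors = tensors


class ToyFusibleOp(abc.ABC):
    """Hypothetical abstract base: the interface a fuser would call."""

    @abc.abstractmethod
    def fuser_forward(self, ctxs, x): ...

    @abc.abstractmethod
    def fuser_backward(self, ctxs, dy): ...


class ToyBasicOp(ToyFusibleOp):
    """Subclasses implement op_forward/op_backward; the fuser hooks delegate to them."""

    @abc.abstractmethod
    def op_forward(self, ctx, x): ...

    @abc.abstractmethod
    def op_backward(self, ctx, dy): ...

    def fuser_forward(self, ctxs, x):
        return self.op_forward(ctxs[0], x)  # a basic op owns exactly one context

    def fuser_backward(self, ctxs, dy):
        return self.op_backward(ctxs[0], dy)


class ToyFusedOp(ToyFusibleOp):
    """Groups basic ops; this sketch chains them rather than fusing kernels."""

    def __init__(self, *, basic_ops):
        self._basic_ops = list(basic_ops)

    @property
    def basic_ops(self):
        return self._basic_ops

    def fuser_forward(self, ctxs, x):
        for op, ctx in zip(self.basic_ops, ctxs):
            x = op.op_forward(ctx, x)
        return x

    def fuser_backward(self, ctxs, dy):
        # Gradients flow through the basic ops in reverse order.
        for op, ctx in zip(reversed(self.basic_ops), reversed(ctxs)):
            dy = op.op_backward(ctx, dy)
        return dy


class Scale(ToyBasicOp):
    def __init__(self, factor):
        self.factor = factor

    def op_forward(self, ctx, x):
        return x * self.factor

    def op_backward(self, ctx, dy):
        return dy * self.factor


class Square(ToyBasicOp):
    def op_forward(self, ctx, x):
        ctx.save_for_backward(x)  # backward needs the input
        return x * x

    def op_backward(self, ctx, dy):
        (x,) = ctx.saved_tensors
        return 2 * x * dy
```

For example, ToyFusedOp(basic_ops=[Scale(3.0), Square()]) computes (3x)^2; with two fresh contexts, a forward pass at x = 2.0 gives 36.0 and a backward pass with a unit upstream gradient gives 18x = 36.0.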
Usage
These classes form the foundational type system of the ops framework: every operation, basic or fused, inherits from them, and together they define the contract between operations and the fuser.
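The fuser side of that contract can be illustrated with a hypothetical mini-fuser. In this sketch, fuse, run, Single and FusedScales are illustrative names, not TransformerEngine API, and matching runs of adjacent ops against a pattern is an assumed fusion strategy chosen for brevity. What it shows is the bookkeeping: contexts are allocated per basic op, each fusible unit's basic_ops list tells the driver which contexts it owns, the forward pass runs in order and the backward pass runs in reverse.

```python
from dataclasses import dataclass


@dataclass
class Ctx:
    """Hypothetical per-basic-op context (stand-in for OperationContext)."""
    saved_tensors: tuple = ()

    def save_for_backward(self, *tensors):
        self.saved_tensors = tensors


class Scale:
    """Hypothetical basic op: y = factor * x."""
    def __init__(self, factor):
        self.factor = factor

    def op_forward(self, ctx, x):
        return x * self.factor

    def op_backward(self, ctx, dy):
        return dy * self.factor


class Relu:
    """Hypothetical basic op that needs saved state in backward."""
    def op_forward(self, ctx, x):
        ctx.save_for_backward(x)
        return max(x, 0.0)

    def op_backward(self, ctx, dy):
        (x,) = ctx.saved_tensors
        return dy if x > 0 else 0.0


class Single:
    """Wraps one basic op so every unit the fuser runs has the same interface."""
    def __init__(self, op):
        self.basic_ops = [op]

    def fuser_forward(self, ctxs, x):
        return self.basic_ops[0].op_forward(ctxs[0], x)

    def fuser_backward(self, ctxs, dy):
        return self.basic_ops[0].op_backward(ctxs[0], dy)


class FusedScales:
    """Hypothetical fused op: a run of adjacent Scale ops becomes one multiply."""
    def __init__(self, *, basic_ops):
        self.basic_ops = list(basic_ops)
        self.factor = 1.0
        for op in self.basic_ops:
            self.factor *= op.factor

    def fuser_forward(self, ctxs, x):
        return x * self.factor

    def fuser_backward(self, ctxs, dy):
        return dy * self.factor


def fuse(basic_ops):
    """Greedily replace runs of two or more adjacent Scale ops with FusedScales."""
    units, i = [], 0
    while i < len(basic_ops):
        j = i
        while j < len(basic_ops) and isinstance(basic_ops[j], Scale):
            j += 1
        if j - i >= 2:
            units.append(FusedScales(basic_ops=basic_ops[i:j]))
            i = j
        else:
            units.append(Single(basic_ops[i]))
            i += 1
    return units


def run(basic_ops, x, dy):
    """Forward in order, backward in reverse; returns (output, grad_input)."""
    ctxs = [Ctx() for _ in basic_ops]  # one context per basic op
    units = fuse(basic_ops)
    spans, start = [], 0
    for unit in units:
        stop = start + len(unit.basic_ops)  # contexts owned by this unit
        x = unit.fuser_forward(ctxs[start:stop], x)
        spans.append((start, stop))
        start = stop
    for unit, (start, stop) in zip(reversed(units), reversed(spans)):
        dy = unit.fuser_backward(ctxs[start:stop], dy)
    return x, dy
```

Because contexts are keyed by basic op rather than by fused unit, the same state layout works whether or not a fusion was applied.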
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/ops/op.py
- Lines: 1-739
Signature
```python
# Abbreviated signatures; "..." inside a parameter list marks elided parameters.
@dataclass
class OperationContext:
    saved_tensors: tuple = ()

    def save_for_backward(self, *tensors): ...


class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
    def fuser_forward(self, ctx, input, ...): ...
    def fuser_backward(self, ctx, grad_output): ...
    def get_input_quantizer(self): ...
    def get_grad_output_quantizer(self): ...


class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
    def op_forward(self, ctx, input, ...): ...
    def op_backward(self, ctx, grad_output): ...


class FusedOperation(FusibleOperation):
    def __init__(self, *, basic_ops: Iterable[BasicOperation]): ...

    @property
    def basic_ops(self) -> list[BasicOperation]: ...
```
Import
```python
from transformer_engine.pytorch.ops.op import (
    OperationContext,
    FusibleOperation,
    BasicOperation,
    FusedOperation,
)
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input | torch.Tensor | Yes | Input tensor to the operation |
| ctx | OperationContext | Yes | Context for saving state between the forward and backward passes |
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Result of the operation |
| ctx | OperationContext | The context passed in, populated in place (via save_for_backward) with the tensors needed by the backward pass; it is not returned |
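One practical way to confirm that a custom operation honors this contract is to compare op_backward against a central finite-difference estimate built from op_forward. The sketch below does this on Python floats; Ctx, Cube and check_op are hypothetical names, and a check on real tensors would perturb each element separately.

```python
from dataclasses import dataclass


@dataclass
class Ctx:
    """Hypothetical stand-in for OperationContext."""
    saved_tensors: tuple = ()

    def save_for_backward(self, *tensors):
        self.saved_tensors = tensors


class Cube:
    """Hypothetical op: y = x**3, so dy/dx = 3 * x**2 (needs the saved input)."""

    def op_forward(self, ctx, x):
        ctx.save_for_backward(x)  # state goes into ctx; only the output is returned
        return x ** 3

    def op_backward(self, ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 3 * x ** 2 * grad_output


def check_op(op, x, eps=1e-6, rel_tol=1e-4):
    """Compare op_backward with a central finite difference of op_forward."""
    ctx = Ctx()
    op.op_forward(ctx, x)
    # Probe with fresh contexts so the state saved in ctx is not overwritten.
    numeric = (op.op_forward(Ctx(), x + eps) - op.op_forward(Ctx(), x - eps)) / (2 * eps)
    analytic = op.op_backward(ctx, 1.0)  # unit upstream gradient
    ok = abs(analytic - numeric) <= rel_tol * max(1.0, abs(numeric))
    return ok, analytic, numeric
```

Probing with fresh contexts matters because op_forward writes into whatever context it receives; reusing ctx for the perturbed calls would replace the input saved for op_backward.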
Usage Examples
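```python
import torch
from transformer_engine.pytorch.ops.op import BasicOperation, OperationContext


# Define a custom fusible operation that doubles its input
class MyOperation(BasicOperation):
    def op_forward(self, ctx: OperationContext, input_: torch.Tensor, **kwargs) -> torch.Tensor:
        # **kwargs absorbs any extra keyword arguments the fuser passes
        ctx.save_for_backward(input_)  # saved only to illustrate the API
        return input_ * 2

    def op_backward(self, ctx: OperationContext, grad_output: torch.Tensor):
        (input_,) = ctx.saved_tensors  # not needed for a constant scale factor
        # Assumption: recent TE versions expect (grad_input, param_grads);
        # this op has no parameters, so param_grads is empty.
        return grad_output * 2, ()
```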