Implementation:NVIDIA TransformerEngine Ops Base

From Leeroopedia


| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch |
| Last Updated | 2026-02-07 14:00 GMT |

Overview

This module defines the base class hierarchy for all fusible operations: the abstract base class, basic operations, fused operations, and the operation context that carries state from the forward pass to the backward pass.

Description

OperationContext is a dataclass that stores forward-pass state (saved tensors, gradient requirements) for use in the backward pass. Tensors are registered through its save_for_backward API.

FusibleOperation is the abstract base class. It inherits from torch.nn.Module and defines the fuser_forward/fuser_backward interface, along with hooks for quantizer access (get_input_quantizer, get_grad_output_quantizer).

BasicOperation extends FusibleOperation with op_forward/op_backward methods, recipe state management, and parameter handling. Its default fuser integration delegates fuser_forward and fuser_backward to op_forward and op_backward.

FusedOperation combines multiple BasicOperations. It exposes their parameters and provides the basic_ops list.
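The delegation described above can be sketched without TransformerEngine or PyTorch. The following is a minimal toy model, not the library's code: `ToyContext`, `ToyBasicOp`, and `Square` are hypothetical names, plain floats stand in for tensors, and passing a list of contexts to the fuser methods is an assumption of this sketch.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class ToyContext:
    """Hypothetical stand-in for OperationContext."""

    saved_tensors: tuple = ()
    requires_grad: bool = True

    def save_for_backward(self, *tensors: Any) -> None:
        # Keep whatever the forward pass wants to reuse in backward.
        self.saved_tensors = tensors


class ToyBasicOp:
    """Hypothetical stand-in for BasicOperation."""

    def op_forward(self, ctx, x):
        raise NotImplementedError

    def op_backward(self, ctx, grad_output):
        raise NotImplementedError

    # Default fuser integration: treat the op as a "fusion" of one op
    # and delegate to op_forward / op_backward.
    def fuser_forward(self, ctxs, x):
        return self.op_forward(ctxs[0], x)

    def fuser_backward(self, ctxs, grad_output):
        return self.op_backward(ctxs[0], grad_output)


class Square(ToyBasicOp):
    """y = x ** 2, so backward needs the forward input."""

    def op_forward(self, ctx, x):
        ctx.save_for_backward(x)
        return x * x

    def op_backward(self, ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x


ctx = ToyContext()
op = Square()
y = op.fuser_forward([ctx], 3.0)    # forward stores 3.0 in ctx
dx = op.fuser_backward([ctx], 1.0)  # backward reads it back
print(y, dx)
```

A subclass only writes `op_forward` and `op_backward`; the fuser-facing methods come from the base class, which is the arrangement the description attributes to BasicOperation.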

Usage

These classes form the foundational type system of the ops framework. Every operation, basic or fused, inherits from them, and together they define the contract between operations and the fuser.
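As a rough illustration of what a contract between operations and a fuser can look like, the sketch below runs a list of toy operations forward in order and backward in reverse, giving each operation its own context. It is a simplified model under stated assumptions, not TransformerEngine's fuser: `toy_fuser_forward`, `toy_fuser_backward`, `Scale`, and `Square` are hypothetical names, and floats stand in for tensors.

```python
from dataclasses import dataclass


@dataclass
class ToyContext:
    """Hypothetical per-operation context."""

    saved_tensors: tuple = ()

    def save_for_backward(self, *tensors):
        self.saved_tensors = tensors


class Scale:
    """Toy basic op: y = factor * x. Needs no saved state."""

    def __init__(self, factor):
        self.factor = factor

    def op_forward(self, ctx, x):
        return self.factor * x

    def op_backward(self, ctx, grad_output):
        return self.factor * grad_output


class Square:
    """Toy basic op: y = x ** 2. Saves x for the backward pass."""

    def op_forward(self, ctx, x):
        ctx.save_for_backward(x)
        return x * x

    def op_backward(self, ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output


def toy_fuser_forward(ops, x):
    # One fresh context per op; each op sees only its own.
    ctxs = [ToyContext() for _ in ops]
    for op, ctx in zip(ops, ctxs):
        x = op.op_forward(ctx, x)
    return x, ctxs


def toy_fuser_backward(ops, ctxs, grad_output):
    # Chain rule: visit the ops in reverse order.
    for op, ctx in zip(reversed(ops), reversed(ctxs)):
        grad_output = op.op_backward(ctx, grad_output)
    return grad_output


ops = [Scale(3.0), Square()]
y, ctxs = toy_fuser_forward(ops, 2.0)    # (3 * 2) ** 2
dx = toy_fuser_backward(ops, ctxs, 1.0)  # 2 * (3 * 2) * 3
print(y, dx)
```

Because every operation exposes the same two methods and receives its own context, the driver loop needs no knowledge of what any individual operation computes.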

Code Reference

Source Location

Repository: NVIDIA/TransformerEngine
File: transformer_engine/pytorch/ops/op.py
Lines: 1-739

Signature

@dataclass
class OperationContext:
    saved_tensors: tuple = ()
    def save_for_backward(self, *tensors): ...

class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta):
    def fuser_forward(self, ctx, input, ...): ...
    def fuser_backward(self, ctx, grad_output): ...
    def get_input_quantizer(self): ...
    def get_grad_output_quantizer(self): ...

class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
    def op_forward(self, ctx, input, ...): ...
    def op_backward(self, ctx, grad_output): ...

class FusedOperation(FusibleOperation):
    def __init__(self, *, basic_ops: Iterable[BasicOperation]): ...
    @property
    def basic_ops(self) -> list[BasicOperation]: ...

Import

from transformer_engine.pytorch.ops.op import (
    OperationContext,
    FusibleOperation,
    BasicOperation,
    FusedOperation,
)

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| input | torch.Tensor | Yes | Input tensor to the operation |
| ctx | OperationContext | Yes | Context for saving state between forward and backward |

Outputs

| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Result of the operation |
| ctx | OperationContext | Updated context with saved tensors for backward pass |

Usage Examples

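from transformer_engine.pytorch.ops.op import BasicOperation, OperationContext

# Define a custom fusible operation that doubles its input
class MyOperation(BasicOperation):
    def op_forward(self, ctx: OperationContext, input, **kwargs):
        # Saved only to illustrate the API; y = 2 * x does not need x in backward
        ctx.save_for_backward(input)
        return input * 2

    def op_backward(self, ctx: OperationContext, grad_output):
        (input,) = ctx.saved_tensors
        # d(2x)/dx = 2. The empty tuple holds parameter gradients (this op has
        # no parameters), assuming op_backward returns (grad_input, param_grads).
        return grad_output * 2, ()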
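A hand-written backward pass is easy to get wrong, so it can help to compare it against a numerical derivative. The sketch below checks the doubling rule from the example using plain Python floats and a central finite difference; `forward`, `backward`, and `numerical_grad` are hypothetical helper names and do not use TransformerEngine.

```python
def forward(x):
    # Same math as the example's op_forward: y = 2 * x
    return 2.0 * x


def backward(grad_output):
    # Same math as the example's input gradient: dL/dx = 2 * dL/dy
    return 2.0 * grad_output


def numerical_grad(fn, x, h=1e-4):
    # Central finite difference approximation of d fn / d x
    return (fn(x + h) - fn(x - h)) / (2.0 * h)


x = 1.5
analytic = backward(1.0)
numeric = numerical_grad(forward, x)
matches = abs(analytic - numeric) < 1e-6
print(matches)
```

The same comparison applies to less trivial operations, where the analytic gradient depends on tensors saved in the context.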
