Implementation:Microsoft Onnxruntime TorchCustomFunctionKernelBase

From Leeroopedia


Knowledge Sources
Domains: Training, Operators, PyTorch_Interop
Last Updated: 2026-02-10 04:00 GMT

Overview

Implements the PythonOpBase and PythonOpGradBase kernel classes that enable calling PyTorch autograd custom functions from within ORT training graphs.

Description

The `torch_custom_function_kernel_base.cc` file implements the bridge between ONNX Runtime training graphs and PyTorch's `autograd.Function` mechanism. It is only compiled when `ENABLE_TRAINING_TORCH_INTEROP` is defined. Key components:

  • PythonOpBase::Init: Extracts all operator attributes from the graph node:
 - `func_name`: Name of the Python autograd function.
 - `training_mode` / `safe_run_mode`: Execution mode flags.
 - `input_convention`: String describing the position/type of each input.
 - `input_requires_grads`: Gradient requirements per input.
 - `input_tensor_types`: ONNX tensor types for tensor inputs.
 - Scalar inputs: bool, int, float scalars with position indices.
 - Tuple inputs: bool, int, float tuples stored as concatenated arrays with begin indices.
 - `input_pointer_scalars`: Raw pointer scalars (e.g., Python objects).
 - `output_tensor_types` and `tensor_reuse_map`: Output configuration and in-place optimization.
 - Calls `CreateConstArgs` and `CreateArgPositions` to build the argument layout.
  • PythonOpBase::RunForward: Creates OrtValue arguments from the kernel context inputs, then delegates to `TorchProxy::GetInstance().Forward()` with all constant args, positions, and mode flags. Validates that the returned output count matches the kernel's expected count.
  • PythonOpBase::SetOutputs: Registers the autograd context in `OrtTorchFunctionPool` and writes its context ID as the first output, then copies the returned OrtValues to the subsequent outputs, verifying that tensors reused in place keep the same address.
  • Constant argument handling: `AddPrimitiveTypeScalarArgs`, `AddInputTupleArgs`, `AddFloatTupleArgs`, and `AddPointerScalarArgs` build Python objects (via CPython API: `PyBool_FromLong`, `Py_BuildValue`, `PyTuple_New`/`PyTuple_SetItem`) for non-tensor inputs. Pointer scalars are not owned (to avoid GIL issues during destruction).
  • PythonOpGradBase::Init: Similar attribute extraction for the backward function, with `output_convention` (character 'd' for tensor outputs) and `output_tensor_requires_grads`.
  • PythonOpGradBase::RunBackward: Retrieves the autograd context from the function pool using the context ID from the first input, then calls `TorchProxy::GetInstance().Backward()`. Unregisters the context after completion.
  • PythonOpGradBase::SetOutputs: Maps backward outputs to kernel outputs based on `output_convention`, verifying in-place tensor address matching and gradient requirements.
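
The tuple layout mentioned under PythonOpBase::Init (values concatenated into one array, plus a begin index per tuple) can be illustrated with a minimal sketch. `SplitTuples` and its parameter names are hypothetical and are not taken from ONNX Runtime. The sketch assumes that tuple i spans from its begin index up to the next tuple's begin index, with the last tuple running to the end of the array; the actual attribute encoding may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical helper (not part of ONNX Runtime): rebuilds individual tuples
// from a flat value array plus one begin index per tuple.
// Assumption: tuple i occupies [begins[i], begins[i + 1]); the last tuple
// runs to the end of `values`.
std::vector<std::vector<int64_t>> SplitTuples(
    const std::vector<int64_t>& values,
    const std::vector<std::size_t>& begins) {
  std::vector<std::vector<int64_t>> tuples;
  tuples.reserve(begins.size());
  for (std::size_t i = 0; i < begins.size(); ++i) {
    const std::size_t begin = begins[i];
    const std::size_t end =
        (i + 1 < begins.size()) ? begins[i + 1] : values.size();
    if (begin > end || end > values.size()) {
      throw std::invalid_argument(
          "begin indices must be non-decreasing and in range");
    }
    // Copy the slice [begin, end) into its own tuple.
    tuples.emplace_back(values.begin() + static_cast<std::ptrdiff_t>(begin),
                        values.begin() + static_cast<std::ptrdiff_t>(end));
  }
  assert(tuples.size() == begins.size());
  return tuples;
}
```

Under these assumptions, `SplitTuples({1, 2, 3, 4, 5}, {0, 2, 2})` yields three tuples: `(1, 2)`, an empty tuple, and `(3, 4, 5)`. Storing one flat array with offsets keeps the attribute representation simple, since variable-length tuples cannot be expressed as a fixed-shape attribute.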

Usage

This is an internal implementation used by the ORT training graph executor when a training graph contains PythonOp/PythonOpGrad nodes that invoke PyTorch autograd custom functions.

Code Reference

Source Location

Signature

namespace onnxruntime::contrib {

void PythonOpBase::Init(const OpKernelInfo& info);
void PythonOpBase::Clear();
void PythonOpBase::RunForward(OpKernelContext* context, void** diff_ctx,
                              std::vector<OrtValue>& returned_ortvalues) const;
void PythonOpBase::SetOutputs(OpKernelContext* context, void* diff_ctx,
                              std::vector<OrtValue>& returned_args) const;
void PythonOpBase::AddPrimitiveTypeScalarArgs();
void PythonOpBase::AddInputTupleArgs();
void PythonOpBase::AddFloatTupleArgs();
void PythonOpBase::AddPointerScalarArgs();
void PythonOpBase::CreateConstArgs();
void PythonOpBase::CreateArgPositions();
void PythonOpBase::SetContextOutput(OpKernelContext* context, void* diff_ctx) const;
void PythonOpBase::SetOtherOutputs(OpKernelContext* context,
                                   std::vector<OrtValue>& returned_ortvalues) const;

void PythonOpGradBase::Init(const OpKernelInfo& info);
void PythonOpGradBase::RunBackward(OpKernelContext* context,
                                   std::vector<OrtValue>& returned_ortvalues) const;
void PythonOpGradBase::SetOutputs(OpKernelContext* context,
                                  std::vector<OrtValue>& returned_ortvalues) const;
void PythonOpGradBase::SetPositions();

}  // namespace onnxruntime::contrib

Import

#include "orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.h"

I/O Contract

| Method | Inputs | Outputs | Description |
|---|---|---|---|
| PythonOpBase::Init | OpKernelInfo (attributes) | void | Extracts all attributes and builds argument layout |
| PythonOpBase::RunForward | OpKernelContext* | diff_ctx (void**), returned_ortvalues | Calls Python forward function via TorchProxy |
| PythonOpGradBase::RunBackward | OpKernelContext* | returned_ortvalues | Calls Python backward function via TorchProxy |
| SetContextOutput | OpKernelContext*, diff_ctx | output[0] = context ID tensor | Registers autograd context and outputs its ID |
| SetOtherOutputs | OpKernelContext*, returned_ortvalues | output[1..N] = OrtValues | Sets remaining outputs with in-place verification |
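
As a rough illustration of the "argument layout" that Init builds, the sketch below splits a convention string into tensor and constant argument positions. `ArgPositions` and `SplitPositions` are hypothetical names, not ORT's. It assumes that `input_convention` uses the same 'd' marker that this page mentions for tensor entries in `output_convention`, and it treats every other character as a non-tensor constant; the real encoding is not documented here.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical result type: which call-argument slots hold tensors and
// which hold constants (scalars, tuples, pointer scalars).
struct ArgPositions {
  std::vector<std::size_t> tensor_positions;
  std::vector<std::size_t> const_positions;
};

// Assumption: 'd' marks a tensor argument (the only marker this page names);
// any other character is treated as a non-tensor constant.
ArgPositions SplitPositions(const std::string& convention) {
  ArgPositions result;
  for (std::size_t i = 0; i < convention.size(); ++i) {
    if (convention[i] == 'd') {
      result.tensor_positions.push_back(i);
    } else {
      result.const_positions.push_back(i);
    }
  }
  // Every position lands in exactly one of the two lists.
  assert(result.tensor_positions.size() + result.const_positions.size() ==
         convention.size());
  return result;
}
```

With this assumed encoding, `SplitPositions("dcdc")` places positions 0 and 2 in the tensor list and positions 1 and 3 in the constant list. Precomputing the positions once at initialization means the forward call only has to fill the tensor slots on each run.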

Usage Examples

// This is used internally by the ORT training executor.
// The forward pass for a PythonOp node:
//   1. Init() extracts function name, input types, scalars, tuples
//   2. RunForward() calls the Python autograd.Function.apply()
//   3. SetOutputs() stores the autograd context and returned tensors

// The backward pass for a PythonOpGrad node:
//   1. Init() extracts function name, output types, grad requirements
//   2. RunBackward() calls autograd.Function.backward(ctx, *grads)
//   3. SetOutputs() routes gradient tensors to appropriate graph outputs
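
The hand-off between the two passes (forward registers the autograd context and emits its ID, backward looks the context up by that ID and unregisters it afterwards) can be sketched with a small registry. `ContextPool` below is a hypothetical stand-in for the bookkeeping attributed to `OrtTorchFunctionPool`; its method names and ID scheme are assumptions, and it leaves out the Python reference counting and thread safety that a real pool would need.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <unordered_map>

// Hypothetical stand-in for the context bookkeeping attributed to
// OrtTorchFunctionPool; the real class and its method names are not shown
// on this page.
class ContextPool {
 public:
  // Forward side: store the opaque context pointer and return the ID that
  // would be written to output[0].
  int64_t Register(void* ctx) {
    assert(ctx != nullptr);  // Assumption: a null context is a caller bug.
    const int64_t id = next_id_++;
    contexts_.emplace(id, ctx);
    return id;
  }

  // Backward side: look up the context by the ID read from the first input.
  void* Get(int64_t id) const {
    const auto it = contexts_.find(id);
    if (it == contexts_.end()) throw std::out_of_range("unknown context ID");
    return it->second;
  }

  // Backward side, once the backward call has completed: drop the entry.
  void Unregister(int64_t id) {
    if (contexts_.erase(id) == 0) {
      throw std::out_of_range("unknown context ID");
    }
  }

  bool Contains(int64_t id) const { return contexts_.count(id) != 0; }

 private:
  int64_t next_id_ = 0;
  std::unordered_map<int64_t, void*> contexts_;
};
```

Passing an integer ID through the graph, rather than the pointer itself, lets the context travel between the PythonOp and PythonOpGrad nodes as an ordinary tensor value.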
