Implementation:Microsoft Onnxruntime TorchCustomFunctionKernelBase
| Knowledge Sources | |
|---|---|
| Domains | Training, Operators, PyTorch_Interop |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
Implements the PythonOpBase and PythonOpGradBase kernel classes that enable calling PyTorch autograd custom functions from within ORT training graphs.
Description
The `torch_custom_function_kernel_base.cc` file implements the bridge between ONNX Runtime training graphs and PyTorch's `autograd.Function` mechanism. It is only compiled when `ENABLE_TRAINING_TORCH_INTEROP` is defined. Key components:
- PythonOpBase::Init: Extracts all operator attributes from the graph node:
  - `func_name`: Name of the Python autograd function.
  - `training_mode` / `safe_run_mode`: Execution mode flags.
  - `input_convention`: String describing the position/type of each input.
  - `input_requires_grads`: Gradient requirements per input.
  - `input_tensor_types`: ONNX tensor types for tensor inputs.
  - Scalar inputs: bool, int, float scalars with position indices.
  - Tuple inputs: bool, int, float tuples stored as concatenated arrays with begin indices.
  - `input_pointer_scalars`: Raw pointer scalars (e.g., Python objects).
  - `output_tensor_types` and `tensor_reuse_map`: Output configuration and in-place optimization.
  - Calls `CreateConstArgs` and `CreateArgPositions` to build the argument layout.
- PythonOpBase::RunForward: Creates OrtValue arguments from the kernel context inputs, then delegates to `TorchProxy::GetInstance().Forward()` with all constant args, positions, and mode flags. Validates that the returned output count matches the kernel's expected count.
- PythonOpBase::SetOutputs: Registers the autograd context in `OrtTorchFunctionPool` and writes its context ID as the first output, then copies the returned OrtValues to the subsequent outputs, verifying that tensors marked for in-place reuse share the expected address.
- Constant argument handling: `AddPrimitiveTypeScalarArgs`, `AddInputTupleArgs`, `AddFloatTupleArgs`, and `AddPointerScalarArgs` build Python objects for non-tensor inputs using the CPython API (`PyBool_FromLong`, `Py_BuildValue`, `PyTuple_New`/`PyTuple_SetItem`). Pointer scalars are not owned by the kernel (to avoid GIL issues during destruction).
- PythonOpGradBase::Init: Extracts the analogous attributes for the backward function, including `output_convention` (where the character `d` marks a tensor output) and `output_tensor_requires_grads`.
- PythonOpGradBase::RunBackward: Retrieves the autograd context from the function pool using the context ID from the first input, then calls `TorchProxy::GetInstance().Backward()`. Unregisters the context after completion.
- PythonOpGradBase::SetOutputs: Maps backward outputs to kernel outputs based on `output_convention`, verifying in-place tensor address matching and gradient requirements.
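
The argument layout described above (a convention string plus tuples stored as concatenated arrays with begin indices) can be illustrated with a small standalone sketch. This is not the ORT implementation: `ArgLayout`, `SplitConvention`, and `SliceTuples` are hypothetical names, and the use of `c` to mark a constant argument is an assumption made for illustration (the text above only states that `d` marks a tensor).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical holder for the two position lists derived from a convention string.
struct ArgLayout {
  std::vector<std::size_t> tensor_positions;    // slots filled from kernel inputs at run time
  std::vector<std::size_t> constant_positions;  // slots filled once with prebuilt constants
};

// Assumption: 'd' marks a tensor argument and 'c' marks a constant argument.
inline ArgLayout SplitConvention(const std::string& convention) {
  ArgLayout layout;
  for (std::size_t i = 0; i < convention.size(); ++i) {
    if (convention[i] == 'd') {
      layout.tensor_positions.push_back(i);
    } else {
      assert(convention[i] == 'c');
      layout.constant_positions.push_back(i);
    }
  }
  return layout;
}

// Recover individual tuples from a flat array plus one begin index per tuple.
// Tuple k spans [begins[k], begins[k + 1]); the last tuple runs to the end of the array.
inline std::vector<std::vector<std::int64_t>> SliceTuples(
    const std::vector<std::int64_t>& flat, const std::vector<std::size_t>& begins) {
  std::vector<std::vector<std::int64_t>> tuples;
  for (std::size_t k = 0; k < begins.size(); ++k) {
    const std::size_t begin = begins[k];
    const std::size_t end = (k + 1 < begins.size()) ? begins[k + 1] : flat.size();
    assert(begin <= end && end <= flat.size());
    tuples.emplace_back(flat.begin() + static_cast<std::ptrdiff_t>(begin),
                        flat.begin() + static_cast<std::ptrdiff_t>(end));
  }
  return tuples;
}
```

Under these assumptions, `SplitConvention("dcdc")` yields tensor positions 0 and 2 and constant positions 1 and 3, and `SliceTuples({1, 2, 3, 4, 5}, {0, 2})` yields the tuples `(1, 2)` and `(3, 4, 5)`. Storing tuples flat keeps each attribute a plain integer or float list, which is presumably why the begin indices are needed to find the tuple boundaries.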
Usage
This is an internal implementation used by the ORT training graph executor when a training graph contains PythonOp/PythonOpGrad nodes that invoke PyTorch autograd custom functions.
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc
- Lines: 1-447
Signature
namespace onnxruntime::contrib {
void PythonOpBase::Init(const OpKernelInfo& info);
void PythonOpBase::Clear();
void PythonOpBase::RunForward(OpKernelContext* context, void** diff_ctx,
std::vector<OrtValue>& returned_ortvalues) const;
void PythonOpBase::SetOutputs(OpKernelContext* context, void* diff_ctx,
std::vector<OrtValue>& returned_args) const;
void PythonOpBase::AddPrimitiveTypeScalarArgs();
void PythonOpBase::AddInputTupleArgs();
void PythonOpBase::AddFloatTupleArgs();
void PythonOpBase::AddPointerScalarArgs();
void PythonOpBase::CreateConstArgs();
void PythonOpBase::CreateArgPositions();
void PythonOpBase::SetContextOutput(OpKernelContext* context, void* diff_ctx) const;
void PythonOpBase::SetOtherOutputs(OpKernelContext* context,
std::vector<OrtValue>& returned_ortvalues) const;
void PythonOpGradBase::Init(const OpKernelInfo& info);
void PythonOpGradBase::RunBackward(OpKernelContext* context,
std::vector<OrtValue>& returned_ortvalues) const;
void PythonOpGradBase::SetOutputs(OpKernelContext* context,
std::vector<OrtValue>& returned_ortvalues) const;
void PythonOpGradBase::SetPositions();
} // namespace onnxruntime::contrib
Import
#include "orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.h"
I/O Contract
| Method | Inputs | Outputs | Description |
|---|---|---|---|
| PythonOpBase::Init | OpKernelInfo (attributes) | void | Extracts all attributes and builds argument layout |
| PythonOpBase::RunForward | OpKernelContext* | diff_ctx (void**), returned_ortvalues | Calls Python forward function via TorchProxy |
| PythonOpGradBase::RunBackward | OpKernelContext* | returned_ortvalues | Calls Python backward function via TorchProxy |
| PythonOpBase::SetContextOutput | OpKernelContext*, diff_ctx | output[0] = context ID tensor | Registers autograd context and outputs its ID |
| PythonOpBase::SetOtherOutputs | OpKernelContext*, returned_ortvalues | output[1..N] = OrtValues | Sets remaining outputs with in-place verification |
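
The context hand-off between the forward and backward kernels (register a context, emit its ID, look it up later, then unregister it) can be sketched as a minimal ID registry. `ContextRegistry` is a hypothetical stand-in, not the `OrtTorchFunctionPool` API; it ignores thread safety and Python reference counting, both of which a real pool would have to handle.

```cpp
#include <cassert>
#include <cstdint>
#include <unordered_map>

// Hypothetical stand-in for a context pool: it hands out integer IDs for opaque
// context pointers so that only the ID has to travel through the graph.
class ContextRegistry {
 public:
  // Store the pointer and return a fresh ID (the forward pass would emit this ID).
  std::int64_t Register(void* ctx) {
    assert(ctx != nullptr);
    const std::int64_t id = next_id_++;
    contexts_.emplace(id, ctx);
    return id;
  }

  // Look up a context by ID; returns nullptr if the ID is unknown.
  void* Lookup(std::int64_t id) const {
    const auto it = contexts_.find(id);
    return it == contexts_.end() ? nullptr : it->second;
  }

  // Drop the entry once backward has finished; returns false if the ID was not registered.
  bool Unregister(std::int64_t id) { return contexts_.erase(id) == 1; }

 private:
  std::int64_t next_id_ = 0;
  std::unordered_map<std::int64_t, void*> contexts_;
};
```

In this sketch the forward side would call `Register` and write the returned ID to its first output, while the backward side would read that ID from its first input, call `Lookup`, and finish with `Unregister` so the entry does not outlive the backward call.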
Usage Examples
// This is used internally by the ORT training executor.
// The forward pass for a PythonOp node:
// 1. Init() extracts function name, input types, scalars, tuples
// 2. RunForward() calls the Python autograd.Function.apply()
// 3. SetOutputs() stores the autograd context and returned tensors
// The backward pass for a PythonOpGrad node:
// 1. Init() extracts function name, output types, grad requirements
// 2. RunBackward() calls autograd.Function.backward(ctx, *grads)
// 3. SetOutputs() routes gradient tensors to appropriate graph outputs
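
The in-place verification that both `SetOutputs` steps perform can be pictured as a pointer comparison driven by a reuse map. The sketch below is illustrative only: `ReuseMapHolds` is a hypothetical helper, and the encoding (entry `i` holds the index of the input that output `i` must alias, or `-1` for no reuse) is an assumption about how such a map could be laid out, not a statement about the actual `tensor_reuse_map` attribute.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Assumption: reuse_map[i] is the index of the input whose buffer output i must share,
// or -1 when output i is not declared as reusing any input buffer.
inline bool ReuseMapHolds(const std::vector<int>& reuse_map,
                          const std::vector<const void*>& input_buffers,
                          const std::vector<const void*>& output_buffers) {
  // Precondition: one reuse entry per output.
  assert(reuse_map.size() == output_buffers.size());
  for (std::size_t i = 0; i < reuse_map.size(); ++i) {
    const int src = reuse_map[i];
    if (src < 0) continue;  // no in-place reuse declared for this output
    const std::size_t input_index = static_cast<std::size_t>(src);
    if (input_index >= input_buffers.size()) return false;            // map points past the inputs
    if (output_buffers[i] != input_buffers[input_index]) return false;  // addresses differ
  }
  return true;
}
```

A check of this kind would catch the case where the graph was optimized on the assumption that a Python function writes its result into an input buffer, but the function actually returned a freshly allocated tensor.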