Implementation:Microsoft Onnxruntime LazyTensorAccelerator
| Knowledge Sources | |
|---|---|
| Domains | Training, LazyTensor, Optimization |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
Implements the Accelerator class for the ORT Lazy Tensor integration, which compiles and executes PyTorch JIT subgraphs through ONNX Runtime with caching, ONNX export, and multiple execution modes.
Description
The `accelerator.cc` file implements the core execution engine for the ORT Lazy Tensor system, which intercepts PyTorch operations and runs them through ONNX Runtime for acceleration. Key components:
- IsFusable: Validates that a JIT node's inputs meet fusion requirements (tensors, non-constant floats/ints, or constants).
- Accelerator::Supported: Checks whether a given ATen operation is supported for ORT acceleration. About 30 operations are currently supported, including `embedding`, `tanh`, `gelu`, `native_layer_norm`, `native_dropout`, `convolution`, `relu`, `mm`, `bmm`, `add`, `mul`, `sub`, `div`, `reshape`, `permute`, `max_pool2d_with_indices`, `_log_softmax`, and comparison ops.
- Accelerator::Run: Dispatches execution based on `RunType()`: "ort" (ORT only), "pytorch" (PyTorch only), or "debug" (both, with result comparison).
- Accelerator::OrtRun: Executes the subgraph through ORT. It takes the inputs from the JIT stack, looks up a compiled callable in a cache keyed by the inputs' `CompleteArgumentSpec` (compiling a new one on a cache miss), runs it, and pushes the outputs back onto the stack.
- Accelerator::PytorchRun: Falls back to PyTorch's `GraphExecutor` for execution, disabling ONNX fusion during the run.
- Accelerator::DebugRun: Runs both ORT and PyTorch, then compares results using configurable tolerances (`LORT_RELATIVE_TOLERANCE`, `LORT_ABSOLUTE_TOLERANCE`).
- Accelerator::Compile: The compilation pipeline:
  1. Runs an example through PyTorch to capture input/output types.
  2. Creates an ORT `InferenceSession`.
  3. Exports the subgraph to ONNX via the Python module `onnxruntime.training.experimental.exporter`.
  4. Detects the device shared by the inputs (CPU or a specific GPU).
  5. Initializes the session with the appropriate execution providers (EPs), including CUDA when available.
  6. Creates a lambda that converts between PyTorch `IValue`s and ORT `OrtValue`s for execution.
- Helper functions: `CheckArgs`, `SetArgTypes`, `ExportToOnnx`, `CreateSession`, `CheckAndGetTensorDevice`, `InitializeSession`.
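The two admission checks described above (`IsFusable` and `Supported`) can be illustrated with a self-contained sketch. `MockValue`, `MockNode`, `IsFusableSketch`, and `SupportedSketch` are hypothetical stand-ins for `torch::jit::Value`, `torch::jit::Node`, and the real functions. Op kinds are plain strings rather than interned `torch::jit::Symbol`s, only a subset of the listed ops is included, and requiring a node to pass both checks is an assumption about how the two combine.

```cpp
#include <cassert>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical stand-ins for torch::jit::Value / torch::jit::Node.
enum class ValueType { Tensor, Float, Int, List, String };

struct MockValue {
  ValueType type;
  bool is_constant;  // true if the value is produced by a constant node
};

struct MockNode {
  std::string kind;  // e.g. "aten::relu"
  std::vector<MockValue> inputs;
};

// Mirrors the rule described for IsFusable: every input must be a tensor,
// a float/int scalar, or a constant of any type.
bool IsFusableSketch(const MockNode& node) {
  for (const MockValue& v : node.inputs) {
    const bool ok_type = v.type == ValueType::Tensor ||
                         v.type == ValueType::Float ||
                         v.type == ValueType::Int;
    if (!ok_type && !v.is_constant) return false;
  }
  return true;
}

// Allow-list lookup standing in for Accelerator::Supported (subset of ops).
// Assumption: a supported op must also pass the fusability check.
bool SupportedSketch(const MockNode& node) {
  static const std::unordered_set<std::string> kSupported = {
      "aten::embedding", "aten::tanh", "aten::gelu", "aten::native_layer_norm",
      "aten::native_dropout", "aten::convolution", "aten::relu", "aten::mm",
      "aten::bmm", "aten::add", "aten::mul", "aten::sub", "aten::div",
      "aten::reshape", "aten::permute", "aten::max_pool2d_with_indices",
      "aten::_log_softmax"};
  return kSupported.count(node.kind) > 0 && IsFusableSketch(node);
}
```

With this rule, a `reshape` whose shape list is computed at runtime is rejected, while the same op with a constant shape list is accepted.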
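Step 4 of the compile pipeline, detecting the shared device (presumably the job of the `CheckAndGetTensorDevice` helper), amounts to checking that all tensor inputs agree on one device. The sketch below is written under that reading. `MockDevice` and `SharedDeviceSketch` are hypothetical, and throwing on an empty or mixed set of devices is this sketch's choice; the real helper may handle those cases differently.

```cpp
#include <cassert>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for c10::Device.
enum class DeviceType { CPU, CUDA };

struct MockDevice {
  DeviceType type;
  int index;  // -1 for CPU, GPU ordinal for CUDA
  bool operator==(const MockDevice& o) const {
    return type == o.type && index == o.index;
  }
};

// Returns the single device shared by all tensor inputs.
// Throws if there are no tensors or if they live on different devices.
MockDevice SharedDeviceSketch(const std::vector<MockDevice>& tensor_devices) {
  if (tensor_devices.empty()) {
    throw std::invalid_argument("no tensor inputs to infer a device from");
  }
  const MockDevice& first = tensor_devices.front();
  for (const MockDevice& d : tensor_devices) {
    if (!(d == first)) {
      throw std::invalid_argument("tensor inputs are on different devices");
    }
  }
  return first;
}
```

The detected device is what step 5 would then use to decide whether a CUDA execution provider belongs in the session.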
Usage
This module is used internally by the ORT Lazy Tensor backend to accelerate PyTorch model execution through ONNX Runtime. It is activated when the LORT (Lazy ORT) backend is selected in PyTorch.
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/lazy_tensor/accelerator.cc
- Lines: 1-515
Signature
```cpp
namespace onnxruntime::lazytensor {
bool IsFusable(const torch::jit::Node* node);

// Accelerator methods
bool Accelerator::Supported(const torch::jit::Node* node);
void Accelerator::OrtRun(torch::jit::Stack& stack);
void Accelerator::PytorchRun(torch::jit::Stack& stack);
void Accelerator::DebugRun(torch::jit::Stack& stack);
void Accelerator::Run(torch::jit::Stack& stack);
void Accelerator::ExampleRun(at::ArrayRef<c10::IValue> inputs);
CompiledObject Accelerator::Compile(torch::jit::CompleteArgumentSpec spec,
                                    at::ArrayRef<c10::IValue>& args);
}  // namespace onnxruntime::lazytensor
```
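`Compile` takes a `CompleteArgumentSpec`, and `OrtRun` reuses compiled callables keyed by that spec, so a subgraph is compiled once per distinct input signature. The following is a minimal sketch of that memoization pattern, assuming a spec can be reduced to one (dtype code, shape) pair per input. `ArgSpec`, `Compiled`, and `CachingRunner` are hypothetical names, and the real `CompleteArgumentSpec` captures more per-input detail (for example strides and device).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <utility>
#include <vector>

// Hypothetical stand-in for CompleteArgumentSpec:
// one (dtype code, shape) entry per input tensor.
using ArgSpec = std::vector<std::pair<int, std::vector<int64_t>>>;

// Hypothetical stand-in for CompiledObject: a callable over flat float buffers.
using Compiled = std::function<std::vector<float>(const std::vector<float>&)>;

class CachingRunner {
 public:
  explicit CachingRunner(std::function<Compiled(const ArgSpec&)> compile)
      : compile_(std::move(compile)) {}

  std::vector<float> Run(const ArgSpec& spec, const std::vector<float>& input) {
    auto it = cache_.find(spec);
    if (it == cache_.end()) {
      // Cache miss: compile once for this spec and remember the result.
      it = cache_.emplace(spec, compile_(spec)).first;
    }
    return it->second(input);
  }

  std::size_t cache_size() const { return cache_.size(); }

 private:
  std::function<Compiled(const ArgSpec&)> compile_;
  std::map<ArgSpec, Compiled> cache_;  // ordered map: ArgSpec has operator<
};
```

Using `std::map` avoids writing a hash for the nested key; a production cache would more likely hash the spec, which is cheaper to compare.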
Import
```cpp
#include "orttraining/lazy_tensor/accelerator.h"
```
I/O Contract
| Method | Inputs | Outputs | Description |
|---|---|---|---|
| Supported | torch::jit::Node* | bool | Returns whether the ATen op is supported for ORT acceleration |
| Run | torch::jit::Stack& | modified Stack | Executes subgraph via configured mode (ort/pytorch/debug) |
| Compile | CompleteArgumentSpec, IValue args | CompiledObject | Compiles subgraph: exports to ONNX, creates ORT session, returns callable |
| OrtRun | torch::jit::Stack& | modified Stack | Executes compiled subgraph through ORT InferenceSession |
| PytorchRun | torch::jit::Stack& | modified Stack | Falls back to PyTorch GraphExecutor |
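The three run modes and the debug-mode comparison can be sketched together. `RunSketch`, `AllCloseSketch`, `Tensor`, and `Backend` are hypothetical. The closeness rule assumes `torch.allclose`-style semantics (`|a - b| <= atol + rtol * |b|`, with the PyTorch output as the reference), where `rtol` and `atol` stand in for `LORT_RELATIVE_TOLERANCE` and `LORT_ABSOLUTE_TOLERANCE`. Returning the ORT result and throwing on a mismatch in debug mode are choices of this sketch; the real `DebugRun` may report differences another way.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <stdexcept>
#include <string>
#include <vector>

using Tensor = std::vector<double>;  // hypothetical stand-in for at::Tensor
using Backend = std::function<Tensor(const Tensor&)>;

// Element-wise closeness in the style of torch.allclose:
// |a - b| <= atol + rtol * |b|, with b as the reference.
bool AllCloseSketch(const Tensor& a, const Tensor& b, double rtol, double atol) {
  if (a.size() != b.size()) return false;
  for (std::size_t i = 0; i < a.size(); ++i) {
    if (std::fabs(a[i] - b[i]) > atol + rtol * std::fabs(b[i])) return false;
  }
  return true;
}

// Dispatch on a run-type string, analogous to Accelerator::Run and RunType().
Tensor RunSketch(const std::string& run_type, const Tensor& input,
                 const Backend& ort, const Backend& pytorch,
                 double rtol, double atol) {
  if (run_type == "ort") return ort(input);
  if (run_type == "pytorch") return pytorch(input);
  if (run_type == "debug") {
    Tensor ort_out = ort(input);
    Tensor ref_out = pytorch(input);  // PyTorch result is the reference
    if (!AllCloseSketch(ort_out, ref_out, rtol, atol)) {
      throw std::runtime_error("ORT and PyTorch results differ");
    }
    return ort_out;
  }
  throw std::invalid_argument("unknown run type: " + run_type);
}
```

Debug mode pays for both executions on every call, which is why it is suited to validating ORT results rather than to normal runs.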
Usage Examples
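```cpp
#include "orttraining/lazy_tensor/accelerator.h"

using onnxruntime::lazytensor::Accelerator;

// Internal usage within the lazy tensor framework.
// `node`, `subgraph`, and `input_tensor` are placeholders supplied by the caller.

// Check op support before adding `node` to an ORT fusion group.
if (Accelerator::Supported(node)) {
  // Node can be included in the ORT fusion group.
}

// Execute: inputs go on the JIT stack; outputs replace them after the run.
Accelerator acc(subgraph);
torch::jit::Stack stack;
stack.push_back(input_tensor);
acc.Run(stack);  // Dispatches to ORT, PyTorch, or debug mode
```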
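The example relies on the TorchScript stack convention: a subgraph's inputs are the topmost entries of the `torch::jit::Stack`, they are popped before execution, and the outputs are pushed in their place. Below is a sketch of that convention, with doubles standing in for `c10::IValue`; `Stack` and `RunOnStackSketch` are hypothetical names, not part of the ORT or PyTorch API.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for torch::jit::Stack (a std::vector<c10::IValue>);
// doubles play the role of IValues here.
using Stack = std::vector<double>;

// Runs a subgraph with `num_inputs` inputs: the inputs are the last
// `num_inputs` entries of the stack; they are popped and the outputs pushed.
template <typename Fn>
void RunOnStackSketch(Stack& stack, std::size_t num_inputs, Fn subgraph) {
  if (stack.size() < num_inputs) throw std::out_of_range("stack underflow");
  Stack inputs(stack.end() - static_cast<std::ptrdiff_t>(num_inputs),
               stack.end());
  stack.resize(stack.size() - num_inputs);  // drop the consumed inputs
  for (double out : subgraph(inputs)) stack.push_back(out);
}
```

Entries below the subgraph's inputs are left untouched, so several fused subgraphs can share one stack.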