Implementation:Microsoft Onnxruntime LazyTensorFusion
| Knowledge Sources | |
|---|---|
| Domains | Training, LazyTensor, GraphOptimization |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
Implements the OrtFuser graph pass that merges fusable PyTorch JIT nodes into fusion groups for execution by the ORT Lazy Tensor Accelerator.
Description
The `fusion.cc` file implements a custom graph fusion pass (`OrtFuser`) for the PyTorch JIT IR, used by the ORT Lazy Tensor system. The fuser identifies operations that can be accelerated by ORT and merges them into subgraph-based fusion groups. Key components:
- OrtFuser struct: The core fusion engine, operating on a JIT `Block`. It uses `AliasDb` to check that reordering nodes is safe and a `FusionCallback` to decide whether a node is fusable. Configuration includes the fusion group `Symbol` kind and `subgraph_arg_limit_` (default 128), which caps the number of inputs and outputs a fusion group can accumulate.
- mergeNodeIntoGroup: Inserts a producer node into an existing fusion group's subgraph. Tensor inputs and non-constant scalar inputs (floats, ints) become inputs of the group, as do int-list inputs of `_grad_sum_to_size`; constant inputs are inlined into the subgraph as constants. The node is then cloned into the subgraph with its inputs remapped to the corresponding subgraph values.
- createSingletonFusionGroup: Wraps a single node in a new fusion group. The group is inserted before the node, the node is merged into the group's subgraph, its output becomes the group's output (with `AliasDb` updated accordingly), and all uses of the original node are redirected to the group.
- mergeFusionGroups: Merges two fusion groups by extracting all nodes from the producer group into the outer graph, destroying the producer, then inlining the temporary nodes into the consumer group.
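- tryFuse: Attempts to fuse a producer into a consumer. Fusion proceeds only if the producer is fusable, `AliasDb::moveBeforeTopologicallyValid` can move the producer directly before the consumer without violating dependencies, and the combined argument count stays within `subgraph_arg_limit_`. A consumer that is not yet a fusion group is first wrapped in a singleton group; a producer that is itself a fusion group is merged via `mergeFusionGroups`.
- scanNode / run: `scanNode` tries to fuse the producers of a consumer's inputs into that consumer. `run()` walks the block in reverse topological order and repeats the pass until no further fusion occurs, because one fusion can enable others. After fusion, `optimizeFusedGraphs` applies dead code elimination, common subexpression elimination (CSE), and constant pooling to each fused subgraph.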
- OrtFuseGraph: The public entry point that creates an `AliasDb`, instantiates the `OrtFuser`, runs the fusion pass, and validates with `torch::jit::Lint`.
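To make the control flow of `run`, `scanNode`, and `tryFuse` concrete, here is a toy model in standard C++. It does not use `torch::jit`: the `ToyNode`/`ToyGraph` types and the `FuseToy`, `TryFuse`, `CountConsumers`, and `Describe` functions are hypothetical. It keeps the reverse scan, the rescan after each successful fusion, the fixed-point outer loop, and an argument-count guard modeled on the one in PyTorch's `GraphFuser` (assumed, not confirmed, to match `OrtFuser`). It replaces `AliasDb`-based reordering with a much cruder rule: a producer is fused only when the consumer is its sole user.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Hypothetical toy IR, not torch::jit. Nodes are stored in topological order,
// so every producer id is smaller than the id of any node that reads it.
struct ToyNode {
  std::vector<std::string> ops;  // a single op, or several once fused into a group
  std::vector<int> inputs;       // producer node ids; negative ids are graph inputs
  bool alive = true;             // false once merged into a consumer
};

using ToyGraph = std::vector<ToyNode>;
using OpPredicate = std::function<bool(const std::string&)>;

// A node (or toy group) is fusable when every op it contains is supported.
bool IsFusable(const ToyNode& n, const OpPredicate& supported) {
  return std::all_of(n.ops.begin(), n.ops.end(), supported);
}

// Number of live nodes that read the output of node `p`.
int CountConsumers(const ToyGraph& g, int p) {
  int consumers = 0;
  for (const ToyNode& n : g) {
    if (n.alive && std::find(n.inputs.begin(), n.inputs.end(), p) != n.inputs.end()) {
      ++consumers;
    }
  }
  return consumers;
}

// Toy tryFuse: merge producer `p` into consumer `c` if all guards pass.
bool TryFuse(ToyGraph& g, int c, int p, const OpPredicate& supported, std::size_t arg_limit) {
  if (p < 0) return false;  // graph inputs have no producer node
  assert(p < c && "toy graph must be topologically ordered");
  ToyNode& consumer = g[c];
  ToyNode& producer = g[p];
  if (!IsFusable(producer, supported)) return false;
  // Crude stand-in for AliasDb::moveBeforeTopologicallyValid and extra group
  // outputs: only fuse a producer whose sole consumer is `c`.
  if (CountConsumers(g, p) != 1) return false;
  // Argument guard (assumed to mirror GraphFuser): consumer inputs + outputs
  // plus producer inputs + outputs, with one output per toy node.
  if (consumer.inputs.size() + 1 + producer.inputs.size() + 1 > arg_limit) return false;
  // Replace the edge p -> c with the producer's own inputs (deduplicated).
  std::vector<int> merged = producer.inputs;
  for (int in : consumer.inputs) {
    if (in != p && std::find(merged.begin(), merged.end(), in) == merged.end()) {
      merged.push_back(in);
    }
  }
  consumer.inputs = merged;
  consumer.ops.insert(consumer.ops.begin(), producer.ops.begin(), producer.ops.end());
  producer.alive = false;
  return true;
}

// Toy run(): reverse scan repeated until a full pass makes no change.
void FuseToy(ToyGraph& g, const OpPredicate& supported, std::size_t arg_limit) {
  bool any_changed = true;
  while (any_changed) {
    any_changed = false;
    for (int c = static_cast<int>(g.size()) - 1; c >= 0; --c) {
      if (!g[c].alive || !IsFusable(g[c], supported)) continue;
      bool fused = true;
      while (fused) {  // rescan the grown node, as scanNode does after a fusion
        fused = false;
        std::vector<int> producers = g[c].inputs;  // copy: TryFuse edits inputs
        std::sort(producers.begin(), producers.end(), std::greater<int>());
        for (int p : producers) {
          if (TryFuse(g, c, p, supported, arg_limit)) {
            fused = true;
            any_changed = true;
            break;
          }
        }
      }
    }
  }
}

// Renders each live node as "op1+op2+...".
std::vector<std::string> Describe(const ToyGraph& g) {
  std::vector<std::string> out;
  for (const ToyNode& n : g) {
    if (!n.alive) continue;
    std::string s;
    for (const std::string& op : n.ops) s += (s.empty() ? "" : "+") + op;
    out.push_back(s);
  }
  return out;
}
```

On the chain `add -> relu -> mul -> sort -> neg`, with every op except `sort` supported and a limit of 128, `FuseToy` leaves three nodes: `add+relu+mul`, `sort`, and `neg`. With a limit of 5 the second merge is rejected (2 + 1 + 2 + 1 = 6 arguments), so `add` stays on its own next to `relu+mul`.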
Usage
This module is called by the ORT Lazy Tensor infrastructure to fuse supported operations in PyTorch JIT graphs before executing them through ONNX Runtime.
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/lazy_tensor/fusion.cc
- Lines: 1-405
Signature
```cpp
namespace onnxruntime::lazytensor {

struct OrtFuser {
  using FusionCallback = std::function<bool(OrtFuser*, torch::jit::Node*)>;

  OrtFuser(torch::jit::AliasDb* aliasDb, torch::jit::Block* block,
           FusionCallback callback, torch::jit::Symbol kind,
           bool strict_fuser_check = false, size_t subgraph_arg_limit = 128);

  bool isFusable(torch::jit::Node* node);
  torch::jit::Node* mergeNodeIntoGroup(torch::jit::Node* group, torch::jit::Node* n);
  torch::jit::Node* createSingletonFusionGroup(torch::jit::Node* n);
  void mergeFusionGroups(torch::jit::Node* consumer_group, torch::jit::Node* producer_group);
  at::optional<torch::jit::Node*> tryFuse(torch::jit::Node* consumer, torch::jit::Value* producer);
  void run();
};

void OrtFuseGraph(std::shared_ptr<torch::jit::Graph>& graph,
                  const std::function<bool(torch::jit::Node*)>& fn,
                  torch::jit::Symbol kind, size_t arg_limit);

}  // namespace onnxruntime::lazytensor
```
Import
```cpp
#include "orttraining/lazy_tensor/fusion.h"
```
I/O Contract
| Function | Inputs | Outputs | Description |
|---|---|---|---|
| OrtFuseGraph | `shared_ptr<Graph>`, predicate `fn`, `Symbol kind`, `arg_limit` | void (graph modified in place) | Top-level entry point that fuses supported nodes into fusion groups |
| mergeNodeIntoGroup | fusion group `Node*`, producer `Node*` | `Node*` (the merged node inside the subgraph) | Inserts a node into an existing fusion group's subgraph |
| tryFuse | consumer `Node*`, producer `Value*` | `optional<Node*>` | Attempts to fuse the producer into the consumer; returns the fusion group node on success, or an empty optional otherwise |
| run | (none) | void | Iterates fusion until convergence, then optimizes subgraphs |
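The `mergeNodeIntoGroup` row hides a per-input decision. The sketch below models that decision on its own, following the input categories listed in the Description section. The enum and function names are hypothetical, the check order is borrowed from PyTorch's `GraphFuser::mergeNodeIntoGroup` (assumed, not confirmed, to carry over to `OrtFuser`), and the `Unsupported` result stands in for the assertion the real pass is assumed to make.

```cpp
#include <cassert>

// Hypothetical stand-ins for the JIT value types mentioned in the text.
enum class ValueType { Tensor, Float, Int, IntList, Other };

struct ToyValue {
  ValueType type;
  bool is_constant;  // produced by a constant node
};

enum class InputHandling { GroupInput, InlinedConstant, Unsupported };

// Decide how one input of a producer node is handled when the producer is
// merged into a fusion group. Check order is an assumption (see above).
InputHandling ClassifyInput(const ToyValue& v, bool node_is_grad_sum_to_size) {
  // Tensors always become inputs of the fusion group.
  if (v.type == ValueType::Tensor) return InputHandling::GroupInput;
  // Non-constant float/int scalars are passed in as group inputs too.
  const bool is_scalar = v.type == ValueType::Float || v.type == ValueType::Int;
  if (is_scalar && !v.is_constant) return InputHandling::GroupInput;
  // Special case: the int-list argument of _grad_sum_to_size.
  if (node_is_grad_sum_to_size && v.type == ValueType::IntList) {
    return InputHandling::GroupInput;
  }
  // Any other constant is cloned into the subgraph as a constant.
  if (v.is_constant) return InputHandling::InlinedConstant;
  // Anything else is not handled; the real pass is assumed to assert here.
  return InputHandling::Unsupported;
}
```

For example, `ClassifyInput({ValueType::Float, false}, false)` yields `GroupInput`, whereas a constant `Int` yields `InlinedConstant`. Inlining constants keeps them out of the group's argument list, which matters because that list counts against `subgraph_arg_limit_`.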
Usage Examples
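```cpp
#include <memory>

#include "orttraining/lazy_tensor/accelerator.h"  // assumed to declare Accelerator::Supported
#include "orttraining/lazy_tensor/fusion.h"

using namespace onnxruntime::lazytensor;

// Hypothetical wrapper: fuse the nodes of `graph` that the ORT accelerator supports.
void FuseForOrt(std::shared_ptr<torch::jit::Graph>& graph) {
  auto is_supported = [](torch::jit::Node* n) { return Accelerator::Supported(n); };
  // Merge supported nodes into fusion groups, with at most 128 arguments per group.
  OrtFuseGraph(graph, is_supported, torch::jit::prim::FusionGroup, 128);
  // graph now contains fusion groups that can be executed by the Accelerator
}
```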
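In the example, the predicate only needs to recognize individual ops. Assuming `OrtFuseGraph` builds its `FusionCallback` the way PyTorch's `CustomFuseGraph` does, it also treats nodes whose kind equals the fusion-group `Symbol` as fusable, so a group created by an earlier fusion can keep absorbing producers and can itself be merged into a consumer. The sketch below models that wrapping with a hypothetical `KindOnlyNode` that carries only a kind string and a hypothetical `MakeFusionCallback` helper; it is not the ORT code.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <utility>

// Hypothetical stand-in for torch::jit::Node: only the kind is modeled.
struct KindOnlyNode {
  std::string kind;
};

using NodePredicate = std::function<bool(const KindOnlyNode&)>;

// Wrap a caller-supplied predicate so that existing fusion groups of
// `group_kind` are also fusable (modeled on PyTorch's CustomFuseGraph;
// assumed, not confirmed, for OrtFuseGraph).
NodePredicate MakeFusionCallback(NodePredicate supported, std::string group_kind) {
  return [supported = std::move(supported),
          group_kind = std::move(group_kind)](const KindOnlyNode& n) {
    return supported(n) || n.kind == group_kind;
  };
}
```

With a predicate that accepts only `aten::add`, the wrapped callback accepts `aten::add` and `prim::FusionGroup` but still rejects `aten::sort`. Without the extra clause, a consumer would stop growing after its first fusion, because the freshly created group would fail the fusability check on the rescan.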