Implementation:Microsoft Onnxruntime LazyTensorFusion

From Leeroopedia


Knowledge Sources
Domains: Training, LazyTensor, GraphOptimization
Last Updated: 2026-02-10 04:00 GMT

Overview

Implements the OrtFuser graph pass that merges fusable PyTorch JIT nodes into fusion groups for execution by the ORT Lazy Tensor Accelerator.

Description

The `fusion.cc` file implements a custom graph fusion pass (`OrtFuser`) for the PyTorch JIT IR, used by the ORT Lazy Tensor system. The fuser identifies operations that can be accelerated by ORT and merges them into subgraph-based fusion groups. Key components:

  • OrtFuser struct: The core fusion engine operating on a JIT `Block`. It uses `AliasDb` for safety analysis and a `FusionCallback` to determine fusability. Configuration includes the fusion group `Symbol` kind and a `subgraph_arg_limit_` (default 128) to prevent overly large fusion groups.
  • mergeNodeIntoGroup: Inserts a producer node into an existing fusion group's subgraph. Handles tensor inputs (added to group inputs), non-constant scalar inputs (floats, ints), constant inputs (inlined as constants in the subgraph), and special cases like `_grad_sum_to_size` with int list inputs. Remaps all internal connections.
  • createSingletonFusionGroup: Creates a new fusion group containing a single node, establishing the initial subgraph with proper input/output connections and alias tracking.
  • mergeFusionGroups: Merges two fusion groups by extracting all nodes from the producer group into the outer graph, destroying the producer, then inlining the temporary nodes into the consumer group.
  • tryFuse: Attempts to fuse a producer into a consumer. Checks fusability, topological ordering via `AliasDb::moveBeforeTopologicallyValid`, and argument count limits. Creates singleton groups as needed.
  • scanNode / run: The main fusion loop. `scanNode` visits nodes in reverse topological order and tries to fuse each node with its producers. `run()` repeats the scan until a full pass produces no new fusion, because one fusion can make further fusions possible. After fusion, `optimizeFusedGraphs` applies dead code elimination, common subexpression elimination (CSE), and constant pooling to each fused subgraph.
  • OrtFuseGraph: The public entry point that creates an `AliasDb`, instantiates the `OrtFuser`, runs the fusion pass, and validates with `torch::jit::Lint`.
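
The scan-until-convergence loop and the argument limit described above can be illustrated with a toy model. This is a minimal sketch, not the code in `fusion.cc`: `ToyOp`, `ArgCount`, `ScanOnce`, and `FuseToFixedPoint` are hypothetical names, and the model assumes a straight-line program in which each op consumes only the previous op's result. That assumption removes the alias and topological-ordering checks that the real pass delegates to `AliasDb`.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical toy op in a straight-line program: op i consumes the result of
// op i-1. This is not the torch::jit::Node API.
struct ToyOp {
  bool fusable;            // what the fusion callback would answer
  std::size_t extra_args;  // inputs besides the previous op's result
  int group;               // fusion group id; -1 while ungrouped
};

// Arguments one side of a fusion needs: one chain input plus the extra inputs
// of the op itself (if ungrouped) or of every member of its group.
std::size_t ArgCount(const std::vector<ToyOp>& ops, std::size_t i) {
  if (ops[i].group == -1) return 1 + ops[i].extra_args;
  std::size_t total = 1;
  for (const ToyOp& op : ops)
    if (op.group == ops[i].group) total += op.extra_args;
  return total;
}

// One scan from the last op to the first. Returns true if anything was fused.
bool ScanOnce(std::vector<ToyOp>& ops, std::size_t arg_limit) {
  bool changed = false;
  for (std::size_t i = ops.size(); i-- > 1;) {
    const std::size_t p = i - 1;  // the producer feeding op i
    if (!ops[i].fusable || !ops[p].fusable) continue;
    if (ops[i].group != -1 && ops[i].group == ops[p].group) continue;
    // The link between the two sides becomes internal, hence the -1.
    if (ArgCount(ops, i) + ArgCount(ops, p) - 1 > arg_limit) continue;
    // Create a singleton group for the consumer on demand.
    if (ops[i].group == -1) ops[i].group = static_cast<int>(i);
    const int old_group = ops[p].group;
    ops[p].group = ops[i].group;
    // If the producer already belonged to a group, move all of its members.
    if (old_group != -1)
      for (ToyOp& op : ops)
        if (op.group == old_group) op.group = ops[i].group;
    changed = true;
  }
  return changed;
}

// Repeat the scan until a full pass finds nothing left to fuse.
void FuseToFixedPoint(std::vector<ToyOp>& ops, std::size_t arg_limit) {
  while (ScanOnce(ops, arg_limit)) {
  }
}
```

In this toy a second pass never finds new work and only confirms convergence. The repeated scan matters in the real pass, where fusing nodes rewrites the graph and can expose new candidates.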

Usage

This module is called by the ORT Lazy Tensor infrastructure to fuse supported operations in PyTorch JIT graphs before executing them through ONNX Runtime.

Code Reference

Source Location

Signature

namespace onnxruntime::lazytensor {

struct OrtFuser {
  using FusionCallback = std::function<bool(OrtFuser*, torch::jit::Node*)>;

  OrtFuser(torch::jit::AliasDb* aliasDb, torch::jit::Block* block,
           FusionCallback callback, torch::jit::Symbol kind,
           bool strict_fuser_check = false, size_t subgraph_arg_limit = 128);

  bool isFusable(torch::jit::Node* node);
  torch::jit::Node* mergeNodeIntoGroup(torch::jit::Node* group, torch::jit::Node* n);
  torch::jit::Node* createSingletonFusionGroup(torch::jit::Node* n);
  void mergeFusionGroups(torch::jit::Node* consumer_group, torch::jit::Node* producer_group);
  at::optional<torch::jit::Node*> tryFuse(torch::jit::Node* consumer, torch::jit::Value* producer);
  void run();
};

void OrtFuseGraph(std::shared_ptr<torch::jit::Graph>& graph,
                  const std::function<bool(torch::jit::Node*)>& fn,
                  torch::jit::Symbol kind, size_t arg_limit);

}  // namespace onnxruntime::lazytensor
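
The signature shows two callback shapes: `OrtFuseGraph` accepts a one-argument predicate, while `OrtFuser` expects a two-argument `FusionCallback`. The sketch below shows one plausible way to bridge the two with a lambda. It uses stand-in types (`ToyNode`, `ToyFuser`, `Adapt` are hypothetical and not part of `fusion.h`), and the real `OrtFuseGraph` may apply additional conditions when it wraps the predicate.

```cpp
#include <functional>
#include <utility>

// Stand-in types for illustration only; the real ones are torch::jit::Node
// and onnxruntime::lazytensor::OrtFuser.
struct ToyNode {
  bool supported;
};
struct ToyFuser;

using NodePredicate = std::function<bool(ToyNode*)>;
using FusionCallback = std::function<bool(ToyFuser*, ToyNode*)>;

struct ToyFuser {
  explicit ToyFuser(FusionCallback callback) : callback_(std::move(callback)) {}

  // Delegates the fusability decision to the stored callback.
  bool isFusable(ToyNode* node) { return callback_(this, node); }

 private:
  FusionCallback callback_;
};

// Wrap a one-argument predicate so it fits the two-argument callback slot.
// The fuser pointer is ignored here; a richer callback could consult it.
FusionCallback Adapt(NodePredicate fn) {
  return [fn = std::move(fn)](ToyFuser*, ToyNode* node) { return fn(node); };
}
```

Keeping the fuser pointer in the callback signature lets a callback inspect fuser state when it needs to, while simple callers can still pass a plain node predicate.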

Import

#include "orttraining/lazy_tensor/fusion.h"

I/O Contract

| Function | Inputs | Outputs | Description |
|---|---|---|---|
| OrtFuseGraph | shared_ptr<Graph>, predicate fn, Symbol kind, arg_limit | void (graph modified in-place) | Top-level entry to fuse supported nodes into subgroups |
| mergeNodeIntoGroup | fusion group Node*, producer Node* | Node* (merged inner node) | Inserts a node into an existing fusion group subgraph |
| tryFuse | consumer Node*, producer Value* | optional<Node*> | Attempts to fuse producer into consumer, returns the group node |
| run | (none) | void | Iterates fusion until convergence, then optimizes subgraphs |
Usage Examples

#include "orttraining/lazy_tensor/fusion.h"

using namespace onnxruntime::lazytensor;

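// Predicate deciding which nodes may be fused: here, whatever the Accelerator supports.
auto is_supported = [](torch::jit::Node* n) {
  return Accelerator::Supported(n);
};

// `graph` is assumed to be a std::shared_ptr<torch::jit::Graph> built elsewhere.
// Fuse supported operations into prim::FusionGroup nodes, with an argument limit of 128.
OrtFuseGraph(graph, is_supported, torch::jit::prim::FusionGroup, 128);
// graph now contains fusion groups that can be executed by the Accelerator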
