Principle:Onnx Onnx Shape Inference
| Knowledge Sources | |
|---|---|
| Domains | Type_System, Static_Analysis |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
A static analysis mechanism that propagates type and shape information through an ONNX computation graph, inferring the shapes of intermediate tensors from input shapes and operator semantics.
Description
Shape inference is a graph-level analysis pass that computes the type and shape of every intermediate tensor in the graph from the types and shapes of the graph inputs and the semantics of each operator. Each ONNX operator has a registered shape inference function that takes the input types and operator attributes and produces the output types. The inference engine applies these functions in topological order, so every node's input types are known before the node is visited. This propagates type and shape information through the entire graph.
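The per-operator functions can be pictured as plain functions from input shapes and attributes to output shapes. The sketch below is hypothetical and simplified. It tracks shapes only, not element types, and the names `infer_relu`, `infer_transpose`, `infer_add` and `REGISTRY` are illustrative. In ONNX itself these functions are written in C++ and attached to each operator schema. `None` marks a dimension, or a whole shape, that cannot be determined statically. `Add` uses numpy-style broadcasting, which ONNX's `Add` follows.

```python
from typing import Callable, Dict, List, Optional, Sequence

# A shape is a list of dims; None marks an unknown dim (or a wholly unknown shape).
Shape = Optional[List[Optional[int]]]
InferFn = Callable[[Sequence[Shape], dict], List[Shape]]


def infer_relu(inputs: Sequence[Shape], attrs: dict) -> List[Shape]:
    # Element-wise op: the output has the same shape as the input.
    return [inputs[0]]


def infer_transpose(inputs: Sequence[Shape], attrs: dict) -> List[Shape]:
    x = inputs[0]
    if x is None:
        return [None]
    # Default perm reverses the dimensions, as ONNX Transpose does.
    perm = attrs.get("perm", list(reversed(range(len(x)))))
    return [[x[p] for p in perm]]


def _broadcast_dim(a: Optional[int], b: Optional[int]) -> Optional[int]:
    if a == 1:
        return b
    if b == 1:
        return a
    # An unknown dim is assumed compatible with the known one.
    if a is None:
        return b
    if b is None:
        return a
    if a != b:
        raise ValueError(f"incompatible dims {a} and {b}")
    return a


def infer_add(inputs: Sequence[Shape], attrs: dict) -> List[Shape]:
    a, b = inputs
    if a is None or b is None:
        return [None]
    # Right-align the shapes by left-padding the shorter one with 1s.
    rank = max(len(a), len(b))
    a = [1] * (rank - len(a)) + list(a)
    b = [1] * (rank - len(b)) + list(b)
    return [[_broadcast_dim(x, y) for x, y in zip(a, b)]]


# Hypothetical registry mapping op_type to its inference function.
REGISTRY: Dict[str, InferFn] = {
    "Relu": infer_relu,
    "Transpose": infer_transpose,
    "Add": infer_add,
}

print(REGISTRY["Add"]([[None, 3], [4, 1]], {}))
```

Unknown dimensions are represented explicitly rather than rejected. This lets a function return partial information, such as a known rank with some unknown sizes, for downstream nodes to use.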
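The inferred shapes are stored in the graph's value_info field. This gives runtimes information they can use for optimizations such as memory planning and kernel selection. It also allows earlier error detection: for example, incompatible shapes between connected operators are caught before the model runs. Shape inference is not guaranteed to be complete. When an output shape depends on runtime tensor values, or an operator's inference function cannot resolve it, the affected dimensions (or the whole shape) remain unknown or symbolic.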
Usage
Use this principle after loading or constructing an ONNX model to enrich it with inferred shape information. This is particularly valuable before deployment to a runtime (since many runtimes benefit from known shapes), after model transformations that may invalidate existing shape info, and as part of validation (via checker.check_model with full_check=True).
Theoretical Basis
Shape inference is a forward dataflow analysis over the computation graph:
$$\forall n \in \text{nodes}(G):\quad \text{output\_types}(n) = f_{\text{op}}\big(\text{input\_types}(n),\ \text{attrs}(n)\big)$$
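where $G$ is the graph and $f_{\text{op}}$ is the shape inference function registered for the operator type of node $n$. Nodes are visited in topological order, so $\text{input\_types}(n)$ is already known when $n$ is processed.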
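As a worked instance, consider a single Transpose node $n$ with attribute perm = (1, 0) whose input is a float tensor of shape [2, 3]:

$$\text{output\_types}(n) = f_{\text{Transpose}}\big((\text{float}[2,3]),\ \{\text{perm} = (1,0)\}\big) = (\text{float}[3,2])$$

Output dimension $i$ takes the size of input dimension $\text{perm}[i]$, so the two axes swap.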
Pseudo-code:
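```python
# Abstract shape inference algorithm (pseudo-code; helper functions are placeholders)
type_map = {inp.name: inp.type for inp in graph.inputs}
type_map.update({init.name: init.type for init in graph.initializers})  # constant weights
output_names = {out.name for out in graph.outputs}
value_info = {}  # types of intermediate tensors produced by nodes
for node in topological_order(graph.nodes):
    # Inputs are already typed because their producers were visited first
    input_types = [type_map[name] for name in node.inputs]
    output_types = infer_op_output_types(node.op_type, input_types, node.attrs)
    for name, type_info in zip(node.outputs, output_types):
        type_map[name] = type_info
        if name not in output_names:  # graph outputs are declared separately
            value_info[name] = type_info
graph.value_info = value_info  # store inferred intermediate types
```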
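The pseudo-code can be made concrete with a toy, self-contained sketch. Its graph format (a list of dicts with `op_type`, `inputs`, `outputs`, `attrs`) and its helper names are hypothetical. It shows the forward dataflow pass described above. Kahn's algorithm produces the topological order, each node's toy inference function runs once its inputs are typed, and an operator with no registered function leaves its outputs unknown (`None`). The nodes are deliberately listed out of order.

```python
from collections import deque


def infer_op(op_type, input_shapes, attrs, num_outputs):
    """Toy per-operator inference; unregistered ops yield unknown (None) shapes."""
    if op_type == "Relu":
        return [input_shapes[0]]
    if op_type == "Transpose":
        x = input_shapes[0]
        if x is None:
            return [None]
        perm = attrs.get("perm", list(reversed(range(len(x)))))
        return [[x[p] for p in perm]]
    return [None] * num_outputs


def topological_order(nodes):
    """Kahn's algorithm: a node becomes ready once all of its producers are visited."""
    producer = {out: i for i, n in enumerate(nodes) for out in n["outputs"]}
    indegree = [0] * len(nodes)
    consumers = {i: [] for i in range(len(nodes))}
    for i, n in enumerate(nodes):
        for name in n["inputs"]:
            if name in producer:  # graph inputs have no producer node
                indegree[i] += 1
                consumers[producer[name]].append(i)
    ready = deque(i for i, d in enumerate(indegree) if d == 0)
    order = []
    while ready:
        i = ready.popleft()
        order.append(nodes[i])
        for j in consumers[i]:
            indegree[j] -= 1
            if indegree[j] == 0:
                ready.append(j)
    if len(order) != len(nodes):
        raise ValueError("graph contains a cycle")
    return order


def infer_shapes(graph_inputs, nodes, graph_outputs):
    type_map = dict(graph_inputs)  # seed with the declared input shapes
    value_info = {}  # shapes of intermediate tensors only
    for node in topological_order(nodes):
        input_shapes = [type_map.get(name) for name in node["inputs"]]
        outs = infer_op(node["op_type"], input_shapes, node.get("attrs", {}), len(node["outputs"]))
        for name, shape in zip(node["outputs"], outs):
            type_map[name] = shape
            if name not in graph_outputs:
                value_info[name] = shape
    return type_map, value_info


nodes = [
    {"op_type": "Relu", "inputs": ["T"], "outputs": ["Z"]},
    {"op_type": "Transpose", "inputs": ["X"], "outputs": ["T"], "attrs": {"perm": [1, 0]}},
    {"op_type": "CustomOp", "inputs": ["Z"], "outputs": ["W"]},  # no inference function
]
type_map, value_info = infer_shapes({"X": [2, 3]}, nodes, {"W"})
print(value_info, type_map["W"])
```

The `CustomOp` output stays `None` because no inference function is registered for that operator. This is one way inference can end up incomplete, as noted in the Description.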