Implementation:Onnx Onnx Shape Inference Infer Shapes
| Knowledge Sources | |
|---|---|
| Domains | Type_System, Static_Analysis |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Concrete tool for propagating shape and type information through ONNX models, wrapping a C++ shape inference engine.
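To make "propagating shape information" concrete, here is a toy sketch in plain Python with no ONNX involved. It walks a hypothetical list of `(op, inputs, outputs)` tuples in topological order and derives each output shape from its input shapes. `Add` uses NumPy-style broadcasting and `Relu` passes its input shape through. The names `broadcast`, `RULES`, and `propagate` are illustrative assumptions. The real C++ engine also tracks element types and symbolic dimensions, and it covers the full operator set.

```python
from itertools import zip_longest

Shape = list[int]


def broadcast(a: Shape, b: Shape) -> Shape:
    """NumPy-style multidirectional broadcasting of two static shapes."""
    out = []
    # Compare dimensions right-to-left, padding the shorter shape with 1s.
    for da, db in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if da == db or db == 1:
            out.append(da)
        elif da == 1:
            out.append(db)
        else:
            raise ValueError(f"incompatible dimensions {da} and {db}")
    return list(reversed(out))


# Hypothetical per-operator shape rules (a tiny subset for illustration).
RULES = {
    "Add": lambda shapes: broadcast(shapes[0], shapes[1]),
    "Relu": lambda shapes: list(shapes[0]),
}


def propagate(
    inputs: dict[str, Shape],
    nodes: list[tuple[str, list[str], list[str]]],
) -> dict[str, Shape]:
    """Return inferred shapes for every value produced by `nodes`."""
    known = dict(inputs)
    inferred = {}
    for op, ins, outs in nodes:  # nodes must already be topologically sorted
        shape = RULES[op]([known[name] for name in ins])
        for name in outs:
            known[name] = shape
            inferred[name] = shape
    return inferred


value_info = propagate(
    {"X": [2, 3], "B": [3]},
    [("Add", ["X", "B"], ["S"]), ("Relu", ["S"], ["Z"])],
)
print(value_info)  # {'S': [2, 3], 'Z': [2, 3]}
```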
Description
The infer_shapes function applies shape inference to a ModelProto (or its serialized bytes) and returns a new ModelProto whose graph's value_info field is populated with the inferred type and shape information. It is a thin Python wrapper around the C++ shape inference engine (onnx.onnx_cpp2py_export.shape_inference). The wrapper serializes the model to bytes, passes them to the C++ engine, and deserializes the result. Protocol Buffers cannot serialize a message larger than 2GB, so models above that size need the companion infer_shapes_path function, which works directly on file paths.
Usage
Import infer_shapes when you need to enrich a model with shape information, for example before runtime optimization or validation. Use infer_shapes_path for models with external data that exceed 2GB. infer_shapes returns a new model and does not modify its input in place.
Code Reference
Source Location
- Repository: onnx
- File: onnx/shape_inference.py
- Lines: 32-70 (infer_shapes), 73-108 (infer_shapes_path)
Signature
```python
from __future__ import annotations

import os

from onnx import ModelProto

# Signatures and docstrings only; function bodies are omitted.


def infer_shapes(
    model: ModelProto | bytes,
    check_type: bool = False,
    strict_mode: bool = False,
    data_prop: bool = False,
) -> ModelProto:
    """Apply shape inference to the provided ModelProto.

    Args:
        model: ModelProto or serialized bytes.
        check_type: Check type equality for input/output.
        strict_mode: If True, raise an error when inference fails;
            otherwise, stop without raising.
        data_prop: Enable data propagation for a limited set of operators.

    Returns:
        New ModelProto with inferred shape information in value_info.
    """


def infer_shapes_path(
    model_path: str | os.PathLike,
    output_path: str | os.PathLike = "",
    check_type: bool = False,
    strict_mode: bool = False,
    data_prop: bool = False,
) -> None:
    """Apply shape inference for models >2GB (path-based).

    Args:
        model_path: Input model file path.
        output_path: Output path (default "": overwrite the input file).
        check_type: Check type equality.
        strict_mode: If True, raise an error when inference fails.
        data_prop: Enable data propagation.
    """
```
Import
```python
from onnx import shape_inference
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | ModelProto or bytes | Yes | Model to infer shapes for |
| check_type | bool | No | Check type equality for inputs/outputs (default: False) |
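| strict_mode | bool | No | If True, raise an error when inference fails; otherwise, stop without raising (default: False) |
| data_prop | bool | No | Enable data propagation for a limited set of operators, so that values computed inside the graph (e.g. the output of Shape) can inform shape inference (default: False) |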
Outputs
| Name | Type | Description |
|---|---|---|
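| return | ModelProto | (infer_shapes) New model with value_info populated with inferred shapes/types |
| return | None | (infer_shapes_path) Nothing is returned; the inferred model is written to output_path, or back to model_path when output_path is empty |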
Usage Examples
Basic Shape Inference
```python
import onnx
from onnx import shape_inference

model = onnx.load_model("model.onnx")

# Apply shape inference; this returns a new ModelProto
inferred_model = shape_inference.infer_shapes(model)

# Inspect the inferred value_info entries
for vi in inferred_model.graph.value_info:
    print(f"{vi.name}: {vi.type}")
```
Path-based for Large Models
```python
from onnx import shape_inference

# For models >2GB stored with external data
shape_inference.infer_shapes_path(
    "large_model.onnx",
    output_path="large_model_inferred.onnx",
    strict_mode=True,
)
```