Implementation:Onnx Onnx Compose Expand Out Dim
| Knowledge Sources | |
|---|---|
| Domains | Model_Composition, Shape_Manipulation |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Concrete tool, provided by the ONNX compose module, for inserting an extra dimension into every output of an ONNX model.
Description
The expand_out_dim function appends one Unsqueeze node for each output in the model, adding a dimension of extent 1 at the specified index. To do so, it renames the original output edges, appends a Constant node holding the index together with the Unsqueeze nodes (which take over the original output names), and updates each output ValueInfoProto with the expanded shape. By default the function works on a copy of the model; with inplace=True it mutates the model passed in.
Usage
Import this function when a model's outputs need an additional dimension to be compatible with a downstream model during composition. It is commonly used to add a batch dimension (dim_idx=0) before merging models.
Code Reference
Source Location
- Repository: onnx
- File: onnx/compose.py
- Lines: 718-751
Signature
def expand_out_dim(
    model: ModelProto,
    dim_idx: int,
    inplace: bool | None = False,
) -> ModelProto:
    """Inserts an extra dimension with extent 1 to each output.

    Args:
        model: Model whose outputs to expand.
        dim_idx: Index of the dimension to insert.
            Negative values count from the back.
        inplace: If True, mutate directly; otherwise copy first.

    Returns:
        ModelProto with expanded output dimensions.
    """
Import
from onnx import compose
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | ModelProto | Yes | Model whose outputs to expand |
| dim_idx | int | Yes | Dimension index to insert (negative for counting from back) |
| inplace | bool | No | Mutate directly or copy (default: False) |
Outputs
| Name | Type | Description |
|---|---|---|
| return | ModelProto | Model with Unsqueeze nodes appended, outputs have an extra dimension |
Usage Examples
Add Batch Dimension
import onnx
from onnx import compose
model = onnx.load_model("model.onnx")
# Add batch dimension at index 0
expanded = compose.expand_out_dim(model, dim_idx=0)
# Original output [10] becomes [1, 10]
for out in expanded.graph.output:
    dims = [d.dim_value for d in out.type.tensor_type.shape.dim]
    print(f"{out.name}: {dims}")