Implementation:Alibaba MNN Torch Model Export
| Field | Value |
|---|---|
| Implementation Name | Torch_Model_Export |
| Type | External Tool Doc |
| Category | Model_Conversion_Pipeline |
| Source | N/A (user-side PyTorch / TensorFlow / ONNX APIs) |
| External Dependencies | torch, tensorflow, onnx, onnxruntime |
Summary
This implementation covers the user-side APIs for serializing trained models into formats that MNNConvert can consume. Since these are external framework APIs (not part of MNN itself), this page serves as a reference for the export step that precedes MNN model conversion.
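To show where these exports lead, the sketch below assembles the MNNConvert invocation that consumes an exported file. The flag names (`-f`, `--modelFile`, `--MNNModel`, `--bizCode`) follow MNN's converter documentation, but the file paths and the `biz` code are placeholders; treat this as a sketch of the downstream step, not a verified command line.

```python
# Sketch: assemble the MNNConvert command that consumes an exported model.
# Flag names follow MNN's converter docs; paths and "biz" are placeholders.
def mnn_convert_command(framework, model_file, output_file, biz_code="biz"):
    # framework is one of MNNConvert's -f values, e.g. TF, CAFFE, ONNX, TFLITE
    return [
        "MNNConvert",
        "-f", framework,
        "--modelFile", model_file,
        "--MNNModel", output_file,
        "--bizCode", biz_code,
    ]

cmd = mnn_convert_command("ONNX", "resnet18.onnx", "resnet18.mnn")
print(" ".join(cmd))
```

The list form is deliberate: it can be passed straight to `subprocess.run(cmd)` without shell quoting concerns.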
API Signatures
PyTorch TorchScript Tracing
torch.jit.trace(model, example_input) -> torch.jit.ScriptModule
- model (torch.nn.Module) -- The trained model in eval() mode.
- example_input (torch.Tensor or tuple) -- A representative input tensor (or tuple of tensors) with the correct shape, dtype, and device.
- Returns -- A ScriptModule containing the traced computation graph.
PyTorch ONNX Export
torch.onnx.export(
model, # torch.nn.Module in eval mode
args, # tuple of example inputs
f, # str or file-like, output path (e.g., "model.onnx")
opset_version=13, # int, ONNX opset version
input_names=None, # list of str, names for input tensors
output_names=None, # list of str, names for output tensors
dynamic_axes=None # dict, axes that can vary at runtime
)
- model (torch.nn.Module) -- The trained model in eval() mode.
- args (tuple) -- Example input tensors for tracing.
- f (str) -- Output file path for the .onnx file.
- opset_version (int, default 13) -- The ONNX opset version to target.
- input_names (list[str]) -- Optional names for graph input nodes.
- output_names (list[str]) -- Optional names for graph output nodes.
- dynamic_axes (dict) -- Maps input/output names to axes that may have variable sizes (e.g., batch dimension).
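The dynamic_axes argument is a plain nested dict mapping tensor names to {axis_index: axis_label}. As an illustration, the hypothetical helper below marks axis 0 (the batch dimension) of every named input and output as variable; the helper name is this page's invention, not a torch API.

```python
def batch_dynamic_axes(input_names, output_names):
    # Hypothetical helper: mark axis 0 (batch) of each named tensor as variable,
    # producing the dict shape torch.onnx.export expects for dynamic_axes.
    return {name: {0: "batch_size"} for name in (*input_names, *output_names)}

axes = batch_dynamic_axes(["input"], ["output"])
# axes == {"input": {0: "batch_size"}, "output": {0: "batch_size"}}
```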
TensorFlow SavedModel / Frozen Graph
# TensorFlow 2.x SavedModel export
tf.saved_model.save(model, export_dir)
# Convert SavedModel to frozen graph (.pb)
from tensorflow.python.tools import freeze_graph
freeze_graph.freeze_graph(
    input_graph=None,           # not needed when loading from a SavedModel
    input_saver=None,
    input_binary=False,
    input_checkpoint=None,
    output_node_names="output_node",  # comma-separated output op names
    restore_op_name=None,       # deprecated, unused
    filename_tensor_name=None,  # deprecated, unused
    output_graph="frozen_model.pb",
    clear_devices=True,
    initializer_nodes="",
    input_saved_model_dir=export_dir
)
TFLite Conversion
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
with open("model.tflite", "wb") as f:
f.write(tflite_model)
Key Parameters
| Parameter | Type | Description |
|---|---|---|
| model | nn.Module / TF model | Trained model object, must be in eval/inference mode |
| example_input | Tensor | Representative input for tracing; shape must match expected inference input |
| opset_version | int | ONNX opset version (default: 13). Higher versions support more operators |
| dynamic_axes | dict | Specifies which tensor axes can vary (e.g., batch size) |
| input_names | list[str] | Human-readable names assigned to model inputs in the exported graph |
| output_names | list[str] | Human-readable names assigned to model outputs in the exported graph |
Inputs
- Trained model checkpoint or in-memory model object:
  - PyTorch: .pt, .pth checkpoint files loaded via torch.load()
  - TensorFlow: SavedModel directory or .pb frozen graph
  - Caffe: .caffemodel + .prototxt pair
  - Pre-exported: .onnx, .tflite
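Because each input format implies a different source framework for the converter, a small lookup by file extension can route checkpoints automatically. The mapping below mirrors the list above; the function itself is illustrative, not part of any framework.

```python
from pathlib import Path

# Illustrative mapping from checkpoint extension to source framework,
# mirroring the input formats listed above.
_EXT_TO_FRAMEWORK = {
    ".pt": "TorchScript", ".pth": "TorchScript",
    ".pb": "TensorFlow",
    ".caffemodel": "Caffe", ".prototxt": "Caffe",
    ".onnx": "ONNX",
    ".tflite": "TFLite",
}

def detect_source_format(path):
    # Returns the framework name, or None for unknown extensions
    return _EXT_TO_FRAMEWORK.get(Path(path).suffix.lower())
```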
Outputs
- TorchScript -- .pt or .torchscript file containing the serialized ScriptModule
- ONNX -- .onnx file containing the model graph and weights in ONNX protobuf format
- TensorFlow frozen graph -- .pb file with graph definition and constant weights
- Caffe -- .caffemodel (weights) + .prototxt (architecture)
- TFLite -- .tflite FlatBuffer file
Usage Examples
Export PyTorch Model to ONNX
import torch
import torchvision
# Load and prepare model
model = torchvision.models.resnet18(pretrained=True)
model.eval()
# Create example input
example_input = torch.randn(1, 3, 224, 224)
# Export to ONNX
torch.onnx.export(
model,
(example_input,),
"resnet18.onnx",
opset_version=13,
input_names=["input"],
output_names=["output"],
dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
)
Export PyTorch Model via TorchScript Tracing
import torch
import torchvision
# Load a model (any trained nn.Module works here)
model = torchvision.models.resnet18(pretrained=True)
model.eval()
example_input = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(model, example_input)
traced_model.save("model.pt")
Export TensorFlow Model to TFLite
import tensorflow as tf
converter = tf.lite.TFLiteConverter.from_saved_model("saved_model_dir")
tflite_model = converter.convert()
with open("model.tflite", "wb") as f:
f.write(tflite_model)
Common Pitfalls
- Forgetting eval mode -- Exporting a model in training mode keeps dropout active and makes batch norm use per-batch statistics instead of its running statistics, producing different inference results.
- Incorrect example input shape -- The traced graph is specialized to the input dimensions used during tracing; use dynamic_axes for variable dimensions.
- Unsupported operators -- Not all PyTorch/TF operations have ONNX equivalents. Check operator support before export.
- Opset version mismatch -- Older opset versions may not support certain operations. When in doubt, use opset 13 or higher.