Implementation:Alibaba MNN Torch Model Export

From Leeroopedia


Field Value
Implementation Name Torch_Model_Export
Type External Tool Doc
Category Model_Conversion_Pipeline
Source N/A (user-side PyTorch / TensorFlow / ONNX APIs)
External Dependencies torch, tensorflow, onnx, onnxruntime

Summary

This implementation covers the user-side APIs for serializing trained models into formats that MNNConvert can consume. Since these are external framework APIs (not part of MNN itself), this page serves as a reference for the export step that precedes MNN model conversion.

API Signatures

PyTorch TorchScript Tracing

torch.jit.trace(model, example_input) -> torch.jit.ScriptModule
  • model (torch.nn.Module) -- The trained model in eval() mode.
  • example_input (torch.Tensor or tuple) -- A representative input tensor (or tuple of tensors) with the correct shape, dtype, and device.
  • Returns -- A ScriptModule containing the traced computation graph.

PyTorch ONNX Export

torch.onnx.export(
    model,              # torch.nn.Module in eval mode
    args,               # tuple of example inputs
    f,                  # str or file-like, output path (e.g., "model.onnx")
    opset_version=13,   # int, ONNX opset version
    input_names=None,   # list of str, names for input tensors
    output_names=None,  # list of str, names for output tensors
    dynamic_axes=None   # dict, axes that can vary at runtime
)
  • model (torch.nn.Module) -- The trained model in eval() mode.
  • args (tuple) -- Example input tensors for tracing.
  • f (str) -- Output file path for the .onnx file.
  • opset_version (int) -- The ONNX opset version to target; the examples on this page use 13. (torch's own default varies by release, so pass it explicitly.)
  • input_names (list[str]) -- Optional names for graph input nodes.
  • output_names (list[str]) -- Optional names for graph output nodes.
  • dynamic_axes (dict) -- Maps input/output names to axes that may have variable sizes (e.g., batch dimension).

TensorFlow SavedModel / Frozen Graph

# TensorFlow 2.x SavedModel export
tf.saved_model.save(model, export_dir)

# Convert SavedModel to frozen graph (.pb)
# Note: freeze_graph() has many required positional parameters; when reading
# from a SavedModel directory, the graph/checkpoint ones can be empty strings.
from tensorflow.python.tools import freeze_graph
freeze_graph.freeze_graph(
    input_graph="", input_saver="", input_binary=False,
    input_checkpoint="", output_node_names="output_node",
    restore_op_name="", filename_tensor_name="",
    output_graph="frozen_model.pb", clear_devices=True,
    initializer_nodes="",
    input_saved_model_dir=export_dir
)

TFLite Conversion

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
with open("model.tflite", "wb") as f:
    f.write(tflite_model)

Key Parameters

Parameter Type Description
model nn.Module / TF model Trained model object, must be in eval/inference mode
example_input Tensor Representative input for tracing; shape must match expected inference input
opset_version int ONNX opset version (the examples here use 13). Higher versions support more operators
dynamic_axes dict Specifies which tensor axes can vary (e.g., batch size)
input_names list[str] Human-readable names assigned to model inputs in the exported graph
output_names list[str] Human-readable names assigned to model outputs in the exported graph

Inputs

  • Trained model checkpoint or in-memory model object:
    • PyTorch: .pt, .pth checkpoint files loaded via torch.load()
    • TensorFlow: SavedModel directory or .pb frozen graph
    • Caffe: .caffemodel + .prototxt pair
    • Pre-exported: .onnx, .tflite
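A PyTorch .pt/.pth checkpoint usually holds only a state_dict, so the architecture must be rebuilt in code before the weights can be loaded and the model exported. A minimal sketch of that flow; the Sequential model and the checkpoint.pth path are stand-ins for illustration:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in architecture; substitute your own model class.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Simulate a finished training run that left a checkpoint on disk.
torch.save(model.state_dict(), "checkpoint.pth")

# Export-time flow: rebuild the same architecture, then load the weights.
restored = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
state_dict = torch.load("checkpoint.pth", map_location="cpu")
restored.load_state_dict(state_dict)
restored.eval()  # switch off dropout / batch-norm training behavior
```

At this point `restored` is ready to be passed to torch.jit.trace or torch.onnx.export.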

Outputs

  • TorchScript -- .pt or .torchscript file containing the serialized ScriptModule
  • ONNX -- .onnx file containing the model graph and weights in ONNX protobuf format
  • TensorFlow frozen graph -- .pb file with graph definition and constant weights
  • Caffe -- .caffemodel (weights) + .prototxt (architecture)
  • TFLite -- .tflite FlatBuffer file

Usage Examples

Export PyTorch Model to ONNX

import torch
import torchvision

# Load and prepare model (torchvision >= 0.13 uses the `weights` argument;
# older releases used the now-deprecated `pretrained=True`)
model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
model.eval()

# Create example input
example_input = torch.randn(1, 3, 224, 224)

# Export to ONNX
torch.onnx.export(
    model,
    (example_input,),
    "resnet18.onnx",
    opset_version=13,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
)

Export PyTorch Model via TorchScript Tracing

import torch

# `model` is the trained torch.nn.Module prepared in the previous example
model.eval()
example_input = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(model, example_input)
traced_model.save("model.pt")
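A saved ScriptModule can be reloaded with torch.jit.load and compared against eager execution to confirm the serialization round-trip. A minimal self-contained sketch; the stand-in model and file name are illustrative:

```python
import torch
import torch.nn as nn

# Small stand-in module; any traceable torch.nn.Module works the same way.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
model.eval()

example_input = torch.randn(1, 4)
traced = torch.jit.trace(model, example_input)
traced.save("traced.pt")

# Reload the serialized ScriptModule and confirm it matches eager execution.
reloaded = torch.jit.load("traced.pt")
assert torch.allclose(model(example_input), reloaded(example_input))
```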

Export TensorFlow Model to TFLite

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model("saved_model_dir")
tflite_model = converter.convert()
with open("model.tflite", "wb") as f:
    f.write(tflite_model)
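The resulting FlatBuffer can be smoke-tested with tf.lite.Interpreter before handing it to MNNConvert. A self-contained sketch using a tiny stand-in Keras model; all names here are illustrative:

```python
import numpy as np
import tensorflow as tf

# Tiny stand-in model so the check is self-contained.
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(2)])

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open("model.tflite", "wb") as f:
    f.write(tflite_model)

# Run the FlatBuffer with the TFLite interpreter and compare outputs.
interpreter = tf.lite.Interpreter(model_path="model.tflite")
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]

x = np.random.randn(1, 4).astype(np.float32)
interpreter.set_tensor(inp["index"], x)
interpreter.invoke()
tflite_out = interpreter.get_tensor(out["index"])
assert np.allclose(tflite_out, model(x).numpy(), atol=1e-5)
```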

Common Pitfalls

  • Forgetting eval mode -- Exporting a model in training mode keeps dropout active and makes batch norm use per-batch statistics instead of its running averages, so the exported graph produces different inference results.
  • Incorrect example input shape -- The traced graph will be specialized to the input dimensions used during tracing; use dynamic_axes for variable dimensions.
  • Unsupported operators -- Not all PyTorch/TF operations have ONNX equivalents. Check operator support before export.
  • Opset version mismatch -- Older opset versions may not support certain operations. When in doubt, use opset 13 or higher.
