Implementation:LaurentMazare Tch rs Torch Jit Trace
| Knowledge Sources | |
|---|---|
| Domains | Model_Deployment, Interoperability |
| Last Updated | 2026-02-08 14:00 GMT |
Overview
External PyTorch API (torch.jit.trace) for exporting PyTorch models as TorchScript via tracing. The export is a prerequisite for Rust-based inference with tch-rs.
Description
torch.jit.trace records the operations performed by a PyTorch model on example inputs, producing a TorchScript module that can be serialized to disk. The resulting .pt file is loaded in Rust via CModule::load. The tch-rs repository provides Python helper scripts (e.g., examples/jit/resnet.py) demonstrating the export process.
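Because tracing records only the operations that actually run for the example input, Python control flow that depends on tensor values is frozen at trace time: only the branch taken for the example input ends up in the exported module. The minimal sketch below illustrates this with a toy module. SignDependent is a hypothetical name used only for illustration and is not part of tch-rs or PyTorch.

```python
import warnings

import torch
import torch.nn as nn


class SignDependent(nn.Module):
    """Hypothetical toy module with data-dependent control flow."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Python-level branch on a tensor value: tracing evaluates it only once.
        if x.sum() > 0:
            return x * 2
        return -x


model = SignDependent().eval()

with warnings.catch_warnings():
    # Tracing emits a TracerWarning for the tensor-to-bool conversion above.
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(model, torch.ones(3))

negative = -torch.ones(3)
eager_out = model(negative)    # eager mode takes the `-x` branch -> tensor([1., 1., 1.])
traced_out = traced(negative)  # the trace replays the recorded `x * 2` -> tensor([-2., -2., -2.])
```

Models with branches like this are usually exported with torch.jit.script instead, which compiles the Python source (including the if statement) rather than recording a single execution.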
Usage
Run the export script once, before any Rust inference. The exported .pt file is the path passed to CModule::load in Rust.
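Before switching to Rust, it can be worth reloading the exported file in Python with torch.jit.load. Like CModule::load, it needs only the .pt file and not the original model class. The sketch below assumes a tiny stand-in network (not ResNet-18) so that it runs without downloading weights, and checks that the reloaded module reproduces the eager outputs.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in model (assumption): a tiny CNN instead of ResNet-18, so nothing is downloaded.
torch.manual_seed(0)
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
).eval()
example = torch.rand(1, 3, 32, 32)

with torch.no_grad():
    traced = torch.jit.trace(model, example)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pt")
    traced.save(path)
    # Reload from the file alone; the Python class definition is not consulted.
    loaded = torch.jit.load(path)

with torch.no_grad():
    expected = model(example)
    actual = loaded(example)

max_abs_diff = (expected - actual).abs().max().item()
```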
Code Reference
Source Location
- Repository: tch-rs
- File: examples/jit/resnet.py
- Lines: 1-8
Signature
```python
import torch
import torchvision
model = torchvision.models.resnet18(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model.pt")
```
Import
```python
import torch
import torchvision
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
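| model | nn.Module | Yes | PyTorch model instance, switched to evaluation mode with model.eval() before tracing so that layers such as batch normalization and dropout use inference behavior |
| example | Tensor | Yes | Example input tensor with the shape and dtype the model expects (e.g., a float32 tensor of shape (1, 3, 224, 224) for ResNet-18) |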
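The example input above is a single tensor because ResNet-18's forward takes one argument. For a model whose forward takes several positional tensors, torch.jit.trace accepts a tuple of example inputs in argument order. The AddScaled module below is a hypothetical illustration, not code from tch-rs.

```python
import torch
import torch.nn as nn


class AddScaled(nn.Module):
    """Hypothetical two-input module used only to illustrate tuple inputs."""

    def __init__(self) -> None:
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(0.5))

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return a + self.scale * b


model = AddScaled().eval()
example_a = torch.rand(2, 4)
example_b = torch.rand(2, 4)

# Several positional inputs are passed as one tuple, in the same order as forward().
traced = torch.jit.trace(model, (example_a, example_b))

a, b = torch.ones(2, 4), torch.full((2, 4), 2.0)
out = traced(a, b)  # 1 + 0.5 * 2 = 2 in every position
```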
Outputs
| Name | Type | Description |
|---|---|---|
| .pt file | File | Serialized TorchScript module containing architecture and weights |
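To see what the .pt file carries, the sketch below saves a small stand-in model (an assumption, chosen to avoid downloads), opens the file as a zip archive, and compares the reloaded parameters with the originals. Exact archive entry names vary between PyTorch versions, so the sketch only collects them rather than relying on specific names.

```python
import os
import tempfile
import zipfile

import torch
import torch.nn as nn

# Stand-in model (assumption): small and randomly initialized.
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2)).eval()
traced = torch.jit.trace(model, torch.rand(1, 4))

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pt")
    traced.save(path)

    # A saved TorchScript module is a zip archive bundling serialized code and tensor data.
    is_zip = zipfile.is_zipfile(path)
    with zipfile.ZipFile(path) as archive:
        entries = archive.namelist()

    loaded = torch.jit.load(path)

# Weights travel with the file: parameter names and values match the original model.
original_state = model.state_dict()
loaded_state = loaded.state_dict()
same_keys = set(original_state) == set(loaded_state)
same_values = all(torch.equal(original_state[k], loaded_state[k]) for k in original_state)

# The architecture is stored as TorchScript code; loaded.code shows it as Python-like source.
source = loaded.code
```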
Usage Examples
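```python
import torch
import torchvision

# Export ResNet-18 with pretrained ImageNet weights.
# The weights= argument requires torchvision >= 0.13; older releases use pretrained=True.
model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
model.eval()  # inference behavior for batch normalization
example = torch.rand(1, 3, 224, 224)
traced = torch.jit.trace(model, example)
traced.save("resnet18.pt")
```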