Workflow:LaurentMazare Tch rs Pretrained Image Classification

From Leeroopedia


Knowledge Sources
Domains: Deep_Learning, Image_Classification, Computer_Vision, Rust_ML
Last Updated: 2026-02-08 13:00 GMT

Overview

End-to-end process for classifying images using pretrained vision model weights (ResNet, VGG, EfficientNet, DINOv2, etc.) loaded into tch-rs model architectures.

Description

This workflow demonstrates how to perform image classification inference in Rust using pretrained weights from well-known vision architectures. It covers loading and preprocessing an input image to the standard ImageNet dimensions, instantiating a model architecture with the correct number of output classes, loading serialized weights into the VarStore, running the forward pass to produce class probabilities, and extracting the top-k predictions with human-readable class names. The library supports over 15 model architectures including ResNet, VGG, DenseNet, Inception, MobileNet, SqueezeNet, EfficientNet, ConvMixer, and DINOv2.

Usage

Execute this workflow when you have a pretrained weight file (in .ot or .safetensors format) and want to classify an image against the ImageNet class taxonomy. This is the standard inference pipeline for deploying vision models in Rust without any training.

Execution Steps

Step 1: Load and preprocess the input image

Read an image file from disk and resize it to the standard input dimensions expected by the model (typically 224x224 for most architectures). The imagenet module provides a helper function that handles loading, resizing, and normalization in a single call.

Key considerations:

  • Images are loaded as [C, H, W] tensors, with pixel values scaled to the [0, 1] range before normalization
  • Standard ImageNet normalization is applied (mean and std per channel)
  • The load_image_and_resize224 helper handles the full preprocessing pipeline
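The normalization step can be illustrated without the library. The following is a minimal sketch in plain Rust, assuming a flat 3-channel [C, H, W] buffer of f32 values already scaled to [0, 1]; the function name normalize_chw is hypothetical and is not part of tch. The mean and standard deviation values are the per-channel ImageNet statistics conventionally used with pretrained models.

```rust
// Per-channel ImageNet statistics in RGB order.
const IMAGENET_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
const IMAGENET_STD: [f32; 3] = [0.229, 0.224, 0.225];

/// Hypothetical helper: normalizes a flat [C, H, W] buffer (C = 3) in place.
/// `pixels` must hold `3 * height * width` values in the [0, 1] range.
fn normalize_chw(pixels: &mut [f32], height: usize, width: usize) {
    let plane = height * width;
    assert!(plane > 0, "image must have at least one pixel");
    assert_eq!(pixels.len(), 3 * plane, "expected a 3-channel CHW buffer");
    // Each chunk of `plane` values is one channel in CHW layout.
    for (channel, values) in pixels.chunks_mut(plane).enumerate() {
        for value in values.iter_mut() {
            *value = (*value - IMAGENET_MEAN[channel]) / IMAGENET_STD[channel];
        }
    }
}
```

In the workflow itself this arithmetic is handled by the load_image_and_resize224 helper, so the sketch is only meant to show what "standard ImageNet normalization" computes for each channel.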

Step 2: Instantiate the model architecture

Create the selected model architecture under a new VarStore. The model constructor is chosen based on the weight file name (e.g., resnet18, vgg16, efficientnet-b0). Each constructor takes a VarStore path and the number of output classes (1000 for ImageNet).

Key considerations:

  • The architecture must match the weight file exactly
  • All model constructors follow the same signature pattern: (path, num_classes) -> impl ModuleT
  • DINOv2 models use a different constructor that does not take a class count parameter
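Choosing a constructor from the weight file name can be sketched as a simple lookup on the file stem. This is an illustrative sketch in plain Rust, not the library's code: the Architecture enum and the architecture_from_weights function are hypothetical names, and only the three file names mentioned above are covered.

```rust
use std::path::Path;

/// Hypothetical subset of supported architectures, for illustration only.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum Architecture {
    ResNet18,
    Vgg16,
    EfficientNetB0,
}

/// Picks an architecture from the weight file's stem,
/// e.g. "weights/resnet18.ot" maps to `Architecture::ResNet18`.
/// Returns `None` when the name is not recognized.
fn architecture_from_weights(path: &Path) -> Option<Architecture> {
    // `file_stem` drops the directory and the final extension (.ot, .safetensors).
    let stem = path.file_stem()?.to_str()?;
    match stem {
        "resnet18" => Some(Architecture::ResNet18),
        "vgg16" => Some(Architecture::Vgg16),
        "efficientnet-b0" => Some(Architecture::EfficientNetB0),
        _ => None,
    }
}
```

Returning None for an unknown name lets the caller report an unsupported weight file before any model is built, which matters because the architecture must match the weight file exactly.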

Step 3: Load pretrained weights

Call VarStore::load with the path to the weight file. This deserializes the saved tensors and copies them into the model's parameters by matching hierarchical path names. The .ot format is the named-tensor archive written through libtorch, the PyTorch C++ library that tch wraps; .safetensors uses the Safetensors format.

Key considerations:

  • Weight file format is detected by extension (.ot or .safetensors)
  • Parameter names in the weight file must match the VarStore path hierarchy
  • Loading performs an in-place copy into existing parameter tensors
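The name-matching copy can be modeled with ordinary maps. The sketch below is a toy stand-in in plain Rust, assuming parameters and saved tensors are both keyed by their full path name; NamedTensor and load_by_name are hypothetical names, and the error behavior (failing on a missing name or a shape mismatch) is an assumption of the sketch rather than a statement about VarStore::load.

```rust
use std::collections::HashMap;

/// Toy stand-in for a named tensor: a shape plus flat data.
#[derive(Debug, Clone, PartialEq)]
struct NamedTensor {
    shape: Vec<usize>,
    data: Vec<f32>,
}

/// Copies each saved tensor into the model parameter with the same name.
/// Fails if a parameter has no saved counterpart or the shapes differ.
/// Tensors in `saved` that match no parameter are ignored.
fn load_by_name(
    params: &mut HashMap<String, NamedTensor>,
    saved: &HashMap<String, NamedTensor>,
) -> Result<(), String> {
    for (name, param) in params.iter_mut() {
        let source = saved
            .get(name)
            .ok_or_else(|| format!("cannot find {} in weight file", name))?;
        if source.shape != param.shape {
            return Err(format!(
                "shape mismatch for {}: expected {:?}, found {:?}",
                name, param.shape, source.shape
            ));
        }
        // Overwrite the existing parameter storage rather than replacing the entry.
        param.data.clone_from(&source.data);
    }
    Ok(())
}
```

The sketch shows why the architecture from Step 2 must be built first: the copy only fills parameters that already exist under the expected names.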

Step 4: Run the forward pass

Add a batch dimension to the preprocessed image tensor with unsqueeze, then apply the model in evaluation mode (train=false). Softmax over the class dimension converts the output logits to probabilities.

What happens:

  • Input tensor is unsqueezed from [C, H, W] to [1, C, H, W]
  • forward_t is called with train=false to disable dropout and use batch-norm running stats
  • Softmax converts raw logits to a probability distribution over classes
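The softmax conversion can be written out directly. This is a minimal sketch in plain Rust over a single row of logits, assuming f32 values; the standalone softmax function is illustrative and is not the tensor method used in the workflow. It subtracts the maximum logit before exponentiating, a common trick to avoid overflow that leaves the result unchanged.

```rust
/// Numerically stable softmax over one row of logits.
/// Returns an empty vector for empty input.
fn softmax(logits: &[f32]) -> Vec<f32> {
    // Shifting by the maximum keeps every exponent at or below zero.
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}
```

The output values are non-negative and sum to 1, so they can be read as a probability distribution over the classes, and the ordering of the logits is preserved.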

Step 5: Extract and display top-k predictions

Use the imagenet::top utility to select the k highest-probability classes and look up their human-readable labels in the 1000-class ImageNet taxonomy. The function returns (probability, class_name) pairs sorted by descending confidence.

Key considerations:

  • The top function operates on the softmax output tensor
  • Class names are embedded in the imagenet module as a static lookup table
  • Typical usage extracts top-5 predictions
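The selection logic behind a top-k lookup can be sketched in a few lines. The following plain-Rust example assumes a slice of probabilities and a parallel slice of labels; the top_k function is a hypothetical illustration and not the implementation of imagenet::top.

```rust
/// Returns the `k` highest-probability (probability, label) pairs,
/// sorted by descending probability.
fn top_k(probabilities: &[f32], labels: &[&str], k: usize) -> Vec<(f32, String)> {
    assert_eq!(
        probabilities.len(),
        labels.len(),
        "one label is required per class"
    );
    let mut pairs: Vec<(f32, String)> = probabilities
        .iter()
        .zip(labels)
        .map(|(&p, &label)| (p, label.to_string()))
        .collect();
    // `total_cmp` gives a total order on floats, so the sort cannot panic on NaN.
    pairs.sort_by(|a, b| b.0.total_cmp(&a.0));
    pairs.truncate(k);
    pairs
}
```

Calling top_k with k = 5 on a 1000-element probability vector and the matching label table would mirror the typical top-5 usage described above. Sorting the full list is simple and cheap at this size; a partial selection would only matter for much larger label sets.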
