Workflow:LaurentMazare Tch rs Transfer Learning

From Leeroopedia
Revision as of 11:01, 16 February 2026 by Admin (talk | contribs) (Auto-imported from workflows/LaurentMazare_Tch_rs_Transfer_Learning.md)


Knowledge Sources
Domains Deep_Learning, Transfer_Learning, Image_Classification, Rust_ML
Last Updated 2026-02-08 13:00 GMT

Overview

End-to-end process for fine-tuning a pretrained ImageNet model on a custom image dataset using transfer learning with frozen feature extraction in tch-rs.

Description

This workflow demonstrates how to adapt a pretrained vision model (ResNet-18) to classify images from a new domain-specific dataset. The approach loads a pretrained model without its final classification layer, uses it as a frozen feature extractor to compute embeddings for all training and test images, then trains a lightweight linear classifier on top of these features. Because the backbone runs only once per image and only the linear layer is optimized, this needs far less training time and labeled data than training the full network from scratch.

Usage

Use this workflow when you have a small to medium custom image dataset organized into class directories and want to leverage pretrained ImageNet features for classification. It is appropriate when training data is limited (hundreds to thousands of images per class) and the visual domain is not radically different from natural images.

Execution Steps

Step 1: Load the custom dataset

Load images from a directory tree in which each class has its own subdirectory. The imagenet::load_from_dir function scans the train and val subdirectories of the base directory, loads each image, resizes it to 224x224, and returns train and test tensors with corresponding integer labels.

Key considerations:

  • Directory structure must follow the convention: base_dir/train/class_name/image_files and base_dir/val/class_name/image_files
  • Images are automatically resized and normalized to ImageNet standards
  • The dataset object contains train_images, train_labels, test_images, test_labels, and the label count
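The labeling and normalization conventions above can be sketched in plain Rust. This is an illustration only, not the tch-rs implementation: labels_from_paths and normalize_pixel are hypothetical helper names, the sorted-name label order is an assumption, and the mean and standard deviation are the commonly used ImageNet per-channel statistics.

```rust
use std::collections::BTreeSet;
use std::path::Path;

/// Commonly used ImageNet per-channel statistics (RGB, for values in [0, 1]).
pub const IMAGENET_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
pub const IMAGENET_STD: [f32; 3] = [0.229, 0.224, 0.225];

/// The class of an image is taken to be the name of its parent directory.
fn class_of(path: &str) -> Option<String> {
    Path::new(path).parent()?.file_name()?.to_str().map(str::to_owned)
}

/// Returns the sorted class names and one integer label per input path.
/// Assumption: labels are assigned in sorted class-name order.
pub fn labels_from_paths(paths: &[&str]) -> (Vec<String>, Vec<i64>) {
    let unique: BTreeSet<String> = paths.iter().filter_map(|p| class_of(p)).collect();
    let classes: Vec<String> = unique.into_iter().collect();
    let labels = paths
        .iter()
        .filter_map(|p| {
            let class = class_of(p)?;
            classes.iter().position(|c| *c == class).map(|i| i as i64)
        })
        .collect();
    (classes, labels)
}

/// Scales an 8-bit RGB pixel to [0, 1], then standardizes each channel.
pub fn normalize_pixel(rgb: [u8; 3]) -> [f32; 3] {
    let mut out = [0.0f32; 3];
    for c in 0..3 {
        out[c] = (rgb[c] as f32 / 255.0 - IMAGENET_MEAN[c]) / IMAGENET_STD[c];
    }
    out
}
```

The integer labels index into the sorted class list, so the label count equals the number of class subdirectories.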

Step 2: Load the pretrained feature extractor

Instantiate a ResNet-18 model without its final fully-connected layer using the resnet18_no_final_layer constructor. Load pretrained weights into this truncated model. The output of this model is a 512-dimensional feature vector per image.

Key considerations:

  • The no_final_layer variant removes the classification head, exposing the penultimate feature layer
  • Weights are loaded from a .ot file containing pretrained ImageNet parameters
  • The feature dimension depends on the architecture (512 for ResNet-18)
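As a rough sizing aid, the feature dimension determines how many parameters the new classification head will have. The sketch below is plain Rust with hypothetical helper names (feature_dim, head_params). The dimensions listed are the standard penultimate-layer widths of the torchvision-style ResNet family, and only the ResNet-18 value is confirmed by this workflow.

```rust
/// Width of the penultimate feature layer for common ResNet variants.
/// Returns None for architectures this sketch does not know about.
pub fn feature_dim(arch: &str) -> Option<usize> {
    match arch {
        // Basic-block ResNets end in 512 channels.
        "resnet18" | "resnet34" => Some(512),
        // Bottleneck ResNets end in 2048 channels.
        "resnet50" | "resnet101" | "resnet152" => Some(2048),
        _ => None,
    }
}

/// Parameter count of a linear head: one weight per (feature, class) pair
/// plus one bias per class.
pub fn head_params(feature_dim: usize, classes: usize) -> usize {
    feature_dim * classes + classes
}
```

For a 10-class dataset on ResNet-18 features, the head has 512 * 10 + 10 = 5130 trainable parameters.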

Step 3: Pre-compute feature embeddings

Run all training and test images through the frozen feature extractor in a no_grad context to produce compact feature tensors. This is done once upfront so the linear classifier trains on precomputed features rather than re-running the CNN on every epoch.

What happens:

  • no_grad disables gradient tracking for memory efficiency
  • apply_t with train=false ensures deterministic behavior (batch-norm uses running stats)
  • The result is a [N, 512] feature matrix for N images
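Two ideas from this step can be sketched in plain Rust, without tch-rs: running a fixed extractor once over every image, and why evaluation-mode batch normalization is deterministic. The function names are hypothetical, the extractor is any closure standing in for the frozen CNN, and the batch-norm sketch omits the learned scale and shift and the running-statistics update.

```rust
/// Runs a frozen extractor once per image and stores the resulting rows.
/// The returned matrix has one feature vector per input image.
pub fn precompute<F>(images: &[Vec<f32>], extract: F) -> Vec<Vec<f32>>
where
    F: Fn(&[f32]) -> Vec<f32>,
{
    images.iter().map(|img| extract(img.as_slice())).collect()
}

/// Normalizes one channel across a batch.
/// train = true: use the statistics of the current batch, so each output
/// depends on the other items in the batch.
/// train = false: use the stored running statistics, so each output depends
/// only on its own input.
pub fn batch_norm(xs: &[f32], running_mean: f32, running_var: f32, train: bool) -> Vec<f32> {
    const EPS: f32 = 1e-5;
    let (mean, var) = if train {
        let n = xs.len() as f32;
        let mean = xs.iter().sum::<f32>() / n;
        // Biased (population) variance of the batch.
        let var = xs.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
        (mean, var)
    } else {
        (running_mean, running_var)
    };
    xs.iter().map(|x| (x - mean) / (var + EPS).sqrt()).collect()
}
```

With train = false the same image yields the same feature vector regardless of which other images share its batch, which is what makes the cached features reusable across epochs.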

Step 4: Train a linear classifier

Create a new VarStore with a single linear layer mapping from the feature dimension (512) to the number of target classes. Train this small model using SGD on the precomputed features with cross-entropy loss.

Key considerations:

  • Only the linear layer weights are trained; the CNN backbone is never updated
  • Training is very fast since the input is a low-dimensional feature matrix
  • SGD with a small learning rate (1e-3) is typically sufficient
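To make the training step concrete, here is a minimal plain-Rust sketch of a linear layer trained with full-batch gradient descent on softmax cross-entropy. It is not the tch-rs code: the Linear struct, softmax, and sgd_step are hypothetical names, weights start at zero for simplicity, and no mini-batching or momentum is used.

```rust
/// Linear layer: logits = W x + b, with W stored as [classes][dim].
pub struct Linear {
    pub w: Vec<Vec<f32>>,
    pub b: Vec<f32>,
}

impl Linear {
    /// Zero-initialized layer (an assumption made to keep the sketch simple).
    pub fn new(dim: usize, classes: usize) -> Self {
        Linear { w: vec![vec![0.0; dim]; classes], b: vec![0.0; classes] }
    }

    pub fn logits(&self, x: &[f32]) -> Vec<f32> {
        self.w
            .iter()
            .zip(&self.b)
            .map(|(row, b)| row.iter().zip(x).map(|(w, v)| w * v).sum::<f32>() + b)
            .collect()
    }
}

/// Numerically stable softmax (subtracts the maximum logit first).
pub fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

/// One gradient step over the whole feature matrix.
/// Returns the mean cross-entropy loss measured before the update.
pub fn sgd_step(model: &mut Linear, xs: &[Vec<f32>], ys: &[usize], lr: f32) -> f32 {
    let n = xs.len() as f32;
    let classes = model.b.len();
    let dim = model.w[0].len();
    let mut grad_w = vec![vec![0.0f32; dim]; classes];
    let mut grad_b = vec![0.0f32; classes];
    let mut loss = 0.0f32;
    for (x, &y) in xs.iter().zip(ys) {
        let p = softmax(&model.logits(x));
        loss -= p[y].max(1e-12).ln();
        for c in 0..classes {
            // Gradient of cross-entropy w.r.t. logit c: p_c - 1[c == y].
            let g = p[c] - if c == y { 1.0 } else { 0.0 };
            grad_b[c] += g;
            for (gw, xj) in grad_w[c].iter_mut().zip(x) {
                *gw += g * xj;
            }
        }
    }
    for c in 0..classes {
        model.b[c] -= lr * grad_b[c] / n;
        for (w, g) in model.w[c].iter_mut().zip(&grad_w[c]) {
            *w -= lr * g / n;
        }
    }
    loss / n
}
```

Calling sgd_step repeatedly on the precomputed feature rows plays the role of the training loop. Each step costs only a few small matrix-vector products, which is why this stage is fast compared with running the CNN.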

Step 5: Evaluate accuracy

Compute classification accuracy on the precomputed test features by running them through the trained linear layer and comparing argmax predictions against ground-truth labels.

Key considerations:

  • Evaluation uses the same precomputed features from Step 3
  • accuracy_for_logits computes the fraction of correct predictions
  • The expected accuracy depends on domain similarity to ImageNet
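The accuracy computation amounts to an argmax per row followed by a comparison with the labels. The plain-Rust sketch below mirrors that idea with hypothetical helpers (argmax, accuracy). It is not the tch-rs accuracy_for_logits method, and returning 0.0 for an empty input is a choice made here.

```rust
/// Index of the largest value in a row of logits (first one wins on ties).
/// Returns 0 for an empty row.
pub fn argmax(row: &[f32]) -> usize {
    let mut best = 0;
    for (i, v) in row.iter().enumerate() {
        if *v > row[best] {
            best = i;
        }
    }
    best
}

/// Fraction of rows whose argmax matches the ground-truth label.
/// Returns 0.0 for an empty input rather than dividing by zero.
pub fn accuracy(logits: &[Vec<f32>], targets: &[usize]) -> f64 {
    if logits.is_empty() {
        return 0.0;
    }
    let correct = logits
        .iter()
        .zip(targets)
        .filter(|(row, &target)| argmax(row.as_slice()) == target)
        .count();
    correct as f64 / logits.len() as f64
}
```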

GitHub URL

Workflow Repository