Implementation:LaurentMazare Tch rs Resnet18
| Knowledge Sources | |
|---|---|
| Domains | Computer_Vision, Model_Architecture |
| Last Updated | 2026-02-08 14:00 GMT |
Overview
Concrete tool for instantiating a ResNet-18 image-classification model, provided by the tch::vision::resnet module.
Description
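resnet::resnet18 returns the ResNet-18 model as a FuncT, a closure wrapper that implements ModuleT. The network starts with a stem of a 7x7 convolution (stride 2), batch normalization, ReLU, and 3x3 max pooling (stride 2). It continues with 4 residual layers of 2 basic blocks each, with 64, 128, 256, and 512 channels respectively, then adaptive average pooling to 1x1 and a final fully-connected layer mapping the 512 pooled features to num_classes. All parameters are registered under the provided VarStore path.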
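The stage-by-stage feature-map sizes follow from plain convolution arithmetic. The std-only Rust sketch below does not call tch; it assumes the standard ResNet-18 hyperparameters (7x7 stem convolution with stride 2 and padding 3, 3x3 max pooling with stride 2 and padding 1, and a stride-2 3x3 convolution with padding 1 in the first block of layers 2 to 4) and traces only channels and spatial resolution.

```rust
/// Output side length of a convolution or pooling window:
/// floor((input + 2 * padding - kernel) / stride) + 1.
fn out_size(input: i64, kernel: i64, stride: i64, padding: i64) -> i64 {
    (input + 2 * padding - kernel) / stride + 1
}

/// Traces (stage, channels, side) through ResNet-18 for a square input.
/// Assumed hyperparameters: see the lead-in; this is not tch source code.
fn trace_shapes(side: i64) -> Vec<(&'static str, i64, i64)> {
    let mut shapes = Vec::new();
    let mut hw = out_size(side, 7, 2, 3); // stem conv: 7x7, stride 2
    shapes.push(("conv1", 64, hw));
    hw = out_size(hw, 3, 2, 1); // stem max pool: 3x3, stride 2
    shapes.push(("maxpool", 64, hw));
    shapes.push(("layer1", 64, hw)); // stride 1: resolution unchanged
    for (name, channels) in [("layer2", 128), ("layer3", 256), ("layer4", 512)] {
        hw = out_size(hw, 3, 2, 1); // first block halves the resolution
        shapes.push((name, channels, hw));
    }
    shapes.push(("avgpool", 512, 1)); // adaptive pooling always yields 1x1
    shapes
}

fn main() {
    for (stage, c, hw) in trace_shapes(224) {
        println!("{stage:>8}: [{c}, {hw}, {hw}]");
    }
}
```

For a 224x224 input this gives sides of 112, 56, 56, 28, 14, and 7 before pooling, so the final linear layer always receives exactly 512 features per image.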
Usage
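Use this to create a ResNet-18 model for image classification. To use pretrained weights, call VarStore::load after instantiation, since load fills the variables that resnet18 has already registered in the VarStore. The model expects a float tensor of shape [batch, 3, 224, 224]; a single image returned by imagenet::load_image_and_resize224 has shape [3, 224, 224] and needs a leading batch dimension, for example via unsqueeze(0).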
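Why the call order matters can be illustrated with a toy model. This is a minimal sketch, assuming that VarStore::load copies each saved tensor into an already-registered variable of the same name and fails when a registered name is missing from the file. ToyVarStore and its methods are hypothetical stand-ins operating on plain vectors, not tch's API.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for a variable store (not tch's VarStore).
struct ToyVarStore {
    vars: HashMap<String, Vec<f32>>,
}

impl ToyVarStore {
    fn new() -> Self {
        ToyVarStore { vars: HashMap::new() }
    }

    /// "Instantiating the model": register a zero-initialized parameter by name.
    fn register(&mut self, name: &str, len: usize) {
        self.vars.insert(name.to_string(), vec![0.0; len]);
    }

    /// Copy saved values into the variables that already exist.
    /// A registered name missing from `saved` is an error; extra saved names are ignored.
    /// Returns how many variables were filled.
    fn load(&mut self, saved: &HashMap<String, Vec<f32>>) -> Result<usize, String> {
        for (name, value) in self.vars.iter_mut() {
            let src = saved.get(name).ok_or(format!("cannot find {name}"))?;
            value.clone_from(src);
        }
        Ok(self.vars.len())
    }
}

fn main() {
    let mut saved = HashMap::new();
    saved.insert("fc.weight".to_string(), vec![0.5, -0.5]);

    // Loading before anything is registered succeeds but fills nothing.
    let mut too_early = ToyVarStore::new();
    println!("too early: {:?}", too_early.load(&saved));

    // Correct order: register (instantiate the model), then load.
    let mut vs = ToyVarStore::new();
    vs.register("fc.weight", 2);
    let loaded = vs.load(&saved);
    println!("loaded: {:?}, fc.weight = {:?}", loaded, vs.vars["fc.weight"]);
}
```

In this toy, loading into an empty store reports success while changing nothing, so a model built afterwards would keep its random initialization. The "fc.weight" name mirrors tch's dot-separated variable naming and is used here only for illustration.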
Code Reference
Source Location
- Repository: tch-rs
- File: src/vision/resnet.rs
- Lines: 78-80
Signature
```rust
pub fn resnet18(p: &nn::Path, num_classes: i64) -> FuncT<'static>
```
Import
```rust
use tch::vision::resnet;
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| p | &nn::Path | Yes | VarStore path for parameter registration |
| num_classes | i64 | Yes | Number of output classes (1000 for ImageNet) |
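Because num_classes only changes the final fully-connected layer, the model size can be worked out by hand. This sketch assumes the torchvision-style layout that pretrained weights are built for: bias-free convolutions, two trainable values (scale and shift) per batch-norm channel, and a 1x1 convolution plus batch norm on the shortcut of the first block in layers 2 to 4. Batch-norm running statistics are buffers and are not counted here.

```rust
/// Weights of a bias-free k x k convolution.
fn conv(c_in: u64, c_out: u64, k: u64) -> u64 {
    c_in * c_out * k * k
}

/// Trainable batch-norm parameters: one scale and one shift per channel.
fn bn(c: u64) -> u64 {
    2 * c
}

/// One basic block: two 3x3 conv + BN pairs, plus a 1x1 conv + BN projection
/// shortcut. In ResNet-18 the shortcut is needed exactly when channels change.
fn basic_block(c_in: u64, c_out: u64) -> u64 {
    let main_path = conv(c_in, c_out, 3) + bn(c_out) + conv(c_out, c_out, 3) + bn(c_out);
    let shortcut = if c_in != c_out { conv(c_in, c_out, 1) + bn(c_out) } else { 0 };
    main_path + shortcut
}

/// Trainable parameters of ResNet-18 for a given number of output classes.
fn resnet18_params(num_classes: u64) -> u64 {
    let stem = conv(3, 64, 7) + bn(64);
    let mut body = 0;
    let mut c_in = 64;
    for c_out in [64, 128, 256, 512] {
        // Each of the four residual layers holds two basic blocks.
        body += basic_block(c_in, c_out) + basic_block(c_out, c_out);
        c_in = c_out;
    }
    let fc = 512 * num_classes + num_classes; // weight [num_classes, 512] plus bias
    stem + body + fc
}

fn main() {
    println!("ImageNet (1000 classes): {}", resnet18_params(1000));
    println!("10 classes: {}", resnet18_params(10));
}
```

With num_classes = 1000 the function returns 11,689,512, of which the fully-connected layer contributes 513,000 (512 x 1000 weights plus 1000 biases); everything else is independent of num_classes.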
Outputs
| Name | Type | Description |
|---|---|---|
| model (return value) | FuncT<'static> | ResNet-18 model implementing ModuleT; call forward_t(xs, train), passing train = false for inference |
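The train flag of forward_t matters here because of batch normalization (ResNet-18 has no dropout). The toy below is a simplified single-channel sketch, not tch code: in training mode it normalizes with the current batch statistics and updates running averages (momentum 0.1, the PyTorch default, with the unbiased variance used for the running update), while in evaluation mode it uses the stored running statistics. Scale and shift are fixed at 1 and 0 for brevity, and ToyBatchNorm is a hypothetical type.

```rust
/// Hypothetical single-channel batch norm illustrating the `train` flag.
struct ToyBatchNorm {
    running_mean: f64,
    running_var: f64,
    momentum: f64,
    eps: f64,
}

impl ToyBatchNorm {
    fn new() -> Self {
        ToyBatchNorm { running_mean: 0.0, running_var: 1.0, momentum: 0.1, eps: 1e-5 }
    }

    fn forward_t(&mut self, xs: &[f64], train: bool) -> Vec<f64> {
        let (mean, var) = if train {
            assert!(xs.len() > 1, "training mode needs more than one value per channel");
            let n = xs.len() as f64;
            let mean = xs.iter().sum::<f64>() / n;
            let var = xs.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
            // Running statistics change only in training mode.
            let unbiased = var * n / (n - 1.0);
            self.running_mean = (1.0 - self.momentum) * self.running_mean + self.momentum * mean;
            self.running_var = (1.0 - self.momentum) * self.running_var + self.momentum * unbiased;
            (mean, var) // normalize with the batch's own statistics
        } else {
            (self.running_mean, self.running_var) // normalize with stored statistics
        };
        xs.iter().map(|x| (x - mean) / (var + self.eps).sqrt()).collect()
    }
}

fn main() {
    let mut bn = ToyBatchNorm::new();
    let batch = [2.0, 4.0, 6.0];
    println!("train: {:?}", bn.forward_t(&batch, true));
    println!("eval:  {:?}", bn.forward_t(&batch, false));
}
```

Passing true at inference time would therefore normalize each input by its own batch and also overwrite running statistics that came from the pretrained weights.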
Usage Examples
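```rust
use tch::{nn, nn::ModuleT, vision::imagenet, vision::resnet, Device, Kind};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Build the model first so its parameters are registered in `vs`.
    let mut vs = nn::VarStore::new(Device::Cpu);
    let model = resnet::resnet18(&vs.root(), imagenet::CLASS_COUNT);
    // Load pretrained weights into the registered parameters.
    vs.load("resnet18.ot")?;
    // Inference: [3, 224, 224] -> [1, 3, 224, 224], with train = false.
    let image = imagenet::load_image_and_resize224("photo.jpg")?;
    let output = model.forward_t(&image.unsqueeze(0), false);
    let probs = output.softmax(-1, Kind::Float); // shape [1, 1000]
    println!("{:?}", probs.size());
    Ok(())
}
```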
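The probs tensor holds one probability per ImageNet class for each image in the batch. As a library-independent sketch of what the last step computes and how the most likely classes are usually read off, here is a numerically stable softmax and a top-k selection over plain slices; the softmax and top_k functions are illustrative, not part of tch. Within tch, the imagenet module provides a top helper (used as imagenet::top(&probs, 5) in the repository's pretrained-models example) that returns (probability, class name) pairs.

```rust
/// Numerically stable softmax: subtract the max logit before exponentiating.
fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|&l| (l - max).exp()).collect();
    let total: f64 = exps.iter().sum();
    exps.iter().map(|e| e / total).collect()
}

/// Indices and probabilities of the `k` most likely classes, highest first.
/// Assumes no NaN values (partial_cmp would return None and unwrap would panic).
fn top_k(probs: &[f64], k: usize) -> Vec<(usize, f64)> {
    let mut indexed: Vec<(usize, f64)> = probs.iter().cloned().enumerate().collect();
    indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    indexed.truncate(k);
    indexed
}

fn main() {
    let probs = softmax(&[1.0, 3.0, 2.0]);
    for (class, p) in top_k(&probs, 2) {
        println!("class {class}: {:.1}%", 100.0 * p);
    }
}
```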