Implementation:LaurentMazare Tch rs VarStore Load
| Knowledge Sources | |
|---|---|
| Domains | Deep_Learning, Model_Serialization |
| Last Updated | 2026-02-08 14:00 GMT |
Overview
Concrete tool for loading saved model weights from a file into an nn::VarStore, the variable container of tch's nn module.
Description
VarStore::load reads named tensors from a file and copies each one into the VarStore variable with the same name. The file format is inferred from the extension:
- .safetensors: safetensors format
- .bin or .pt: PyTorch pickle format
- anything else (typically .ot): libtorch C++ format
The copies run under no_grad, so autograd does not record them. On MPS devices, a workaround temporarily switches the VarStore to the CPU, loads the weights there, and then moves it back to MPS, restoring the device even if loading fails.
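The extension rule above can be sketched in plain Rust. This is a minimal, std-only illustration, not tch's source: the WeightFormat enum and the detect_format function are hypothetical names, and the sketch only decides which reader would be used; it never opens the file.

```rust
use std::path::Path;

/// Hypothetical labels for the three readers described above.
#[derive(Debug, PartialEq)]
enum WeightFormat {
    SafeTensors,
    /// PyTorch pickle-format files (.bin / .pt).
    Pickle,
    /// libtorch C++ format, the fallback for any other extension.
    Libtorch,
}

/// Picks a format from the file extension only (hypothetical helper).
fn detect_format<P: AsRef<Path>>(path: P) -> WeightFormat {
    match path.as_ref().extension().and_then(|ext| ext.to_str()) {
        Some("safetensors") => WeightFormat::SafeTensors,
        Some("bin") | Some("pt") => WeightFormat::Pickle,
        // .ot, unknown extensions, and extension-less paths all land here.
        _ => WeightFormat::Libtorch,
    }
}

fn main() {
    for p in ["model.safetensors", "model.bin", "model.pt", "resnet18.ot", "weights"] {
        println!("{p} -> {:?}", detect_format(p));
    }
}
```

Matching on the extension string keeps the choice cheap and predictable; the trade-off is that a file with a misleading extension is handed to the wrong reader and fails there.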
Usage
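Call load on a mutable VarStore after all model layers have been defined: only variables that already exist in the store are filled, so layers created afterwards keep their initial values. Every variable in the store must have a tensor with the same name in the file; if one is missing, load returns an error.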
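The name-matching requirement can be modeled with plain maps. In this hypothetical, std-only sketch, a HashMap<String, Vec<f32>> stands in for both the VarStore's variables and the tensors read from the file, and load_into is an illustrative name, not a tch function. The sketch assumes that only the store's own variable names are looked up: each store variable is overwritten by the file entry of the same name, a store variable with no counterpart in the file produces an error, and extra file entries are ignored.

```rust
use std::collections::HashMap;

/// Stand-in for a tensor: a flat buffer of values (hypothetical).
type FakeTensor = Vec<f32>;

/// Copies every store variable from `file` by name (hypothetical helper).
/// Fails with the name of a store variable that has no entry in `file`.
fn load_into(
    store: &mut HashMap<String, FakeTensor>,
    file: &HashMap<String, FakeTensor>,
) -> Result<(), String> {
    for (name, var) in store.iter_mut() {
        match file.get(name) {
            // Overwrite the variable's values in place under the same name.
            Some(src) => *var = src.clone(),
            None => return Err(format!("missing tensor for variable {name}")),
        }
    }
    // Entries that exist only in `file` are never looked at.
    Ok(())
}

fn main() {
    let mut store: HashMap<String, FakeTensor> = HashMap::new();
    store.insert("fc.weight".to_string(), vec![0.0; 2]);
    store.insert("fc.bias".to_string(), vec![0.0]);

    let mut file: HashMap<String, FakeTensor> = HashMap::new();
    file.insert("fc.weight".to_string(), vec![0.5, -0.5]);
    file.insert("fc.bias".to_string(), vec![0.1]);
    file.insert("unused.extra".to_string(), vec![9.0]); // ignored by load_into

    load_into(&mut store, &file).expect("all store names are present");
    println!("{:?}", store["fc.weight"]);

    // A store variable with no counterpart in the file makes loading fail.
    store.insert("head.weight".to_string(), vec![0.0; 3]);
    println!("{:?}", load_into(&mut store, &file));
}
```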
Code Reference
Source Location
- Repository: tch-rs
- File: src/nn/var_store.rs
- Lines: 235-249
Signature
```rust
pub fn load<T: AsRef<std::path::Path>>(&mut self, path: T) -> Result<(), TchError>
```
Import
```rust
use tch::nn;
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| path | T: AsRef<Path> | Yes | Path to the weight file; the extension selects the format (.safetensors, .bin, .pt; any other extension, typically .ot, is read as libtorch format) |
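Because path is generic over AsRef<Path>, callers can pass a string literal, a String, a PathBuf, or a &Path without converting first. The std-only sketch below uses a hypothetical stand-in, describe_weights, with the same bound as load, to show which argument types are accepted; it only formats the path and its extension and does not touch the file system.

```rust
use std::path::{Path, PathBuf};

/// Hypothetical stand-in with the same generic bound as VarStore::load.
/// It reports the path and its extension instead of reading a file.
fn describe_weights<T: AsRef<Path>>(path: T) -> String {
    let p = path.as_ref();
    let ext = p.extension().and_then(|e| e.to_str()).unwrap_or("<none>");
    format!("{} (extension: {})", p.display(), ext)
}

fn main() {
    let owned = String::from("model.safetensors");
    let buf = PathBuf::from("weights").join("resnet18.ot");

    println!("{}", describe_weights("model.pt")); // &str
    println!("{}", describe_weights(&owned)); // &String
    println!("{}", describe_weights(buf.as_path())); // &Path
    println!("{}", describe_weights(buf)); // PathBuf, moved in
}
```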
Outputs
| Name | Type | Description |
|---|---|---|
| return value | Result<(), TchError> | Ok(()) on success; Err(TchError) if the file cannot be read, its format is invalid, or a VarStore variable has no matching tensor in the file |
Usage Examples
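```rust
use tch::{nn, vision::imagenet, vision::resnet, Device, TchError};

fn main() -> Result<(), TchError> {
    // Define every layer first so all variables exist before loading.
    let mut vs = nn::VarStore::new(Device::Cpu);
    let _model = resnet::resnet18(&vs.root(), imagenet::CLASS_COUNT);
    // Load pretrained weights; the .ot extension selects the libtorch format.
    vs.load("resnet18.ot")?;
    // The same call reads safetensors files, chosen by extension:
    // vs.load("model.safetensors")?;
    Ok(())
}
```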