Heuristic:LaurentMazare Tch rs MPS Weight Loading Workaround
| Knowledge Sources | |
|---|---|
| Domains | Debugging, Infrastructure |
| Last Updated | 2026-02-08 13:00 GMT |
Overview
When loading model weights onto an Apple MPS device, temporarily switch the VarStore to CPU for the load operation, then switch it back to MPS. This works around a libtorch limitation that prevents loading weights directly onto MPS.
Description
A known limitation in libtorch prevents direct weight loading onto MPS (Metal Performance Shaders) devices. The tch-rs library implements an automatic workaround in `VarStore::load()`: when the target device is MPS, the VarStore is temporarily moved to CPU, the weights are loaded, and then the VarStore is moved back to MPS. This workaround is implemented defensively, ensuring the device is always restored to MPS even if the load operation fails. This pattern is essential for any Apple Silicon user running GPU-accelerated inference.
Usage
This heuristic is automatically applied by `VarStore::load()` when the VarStore device is `Device::Mps`. Users do not need to manually implement this workaround. However, understanding it is important for debugging weight loading issues on Apple Silicon and for custom loading code that bypasses `VarStore::load()`.
The Insight (Rule of Thumb)
- Action: When loading weights on MPS, call `vs.set_device(Device::Cpu)` before loading, then `vs.set_device(Device::Mps)` after.
- Value: N/A (pattern-based, not a numeric parameter).
- Trade-off: Adds a CPU-to-MPS device transfer after loading, which increases loading time but ensures correctness.
- Compatibility: Only affects MPS (Apple Silicon). CPU and CUDA loading work directly.
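For custom loading code that bypasses `VarStore::load()`, the rule above can be sketched as a small wrapper. This is a minimal, self-contained illustration and not tch-rs code: `Device`, `Store`, and `load_via_cpu` are hypothetical stand-ins that only model the device switch, and the `loader` closure stands for whatever custom deserialization is used.

```rust
// Hypothetical stand-ins for illustration; these are NOT the tch-rs types.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Device {
    Cpu,
    Mps,
}

#[derive(Debug)]
struct Store {
    device: Device,
}

impl Store {
    // Models the device switch only; no tensors are moved in this sketch.
    fn set_device(&mut self, device: Device) {
        self.device = device;
    }
}

/// Runs `loader` with the store on CPU when the target device is MPS,
/// then moves the store back to MPS whether or not the load succeeded.
fn load_via_cpu<E>(
    store: &mut Store,
    loader: impl FnOnce(&mut Store) -> Result<(), E>,
) -> Result<(), E> {
    if store.device != Device::Mps {
        // CPU (and, in the real library, CUDA) can load directly.
        return loader(store);
    }
    store.set_device(Device::Cpu);
    // Keep the result instead of using `?`, so the restore below always runs.
    let result = loader(store);
    store.set_device(Device::Mps);
    result
}
```

Calling `load_via_cpu(&mut store, |s| my_loader(s))` with a store on `Device::Mps` leaves the store on `Device::Mps` afterwards for both `Ok` and `Err` results. A panic inside the loader is not covered by this sketch.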
Reasoning
libtorch's MPS backend does not fully support all tensor deserialization paths. The underlying C++ library may fail or produce incorrect results when loading tensors directly onto the MPS device. By routing through CPU, the standard and well-tested CPU deserialization path is used, and then the tensors are transferred to MPS via the standard device migration API. The defensive error handling (ensuring MPS is restored even on load failure) prevents the VarStore from being left in an inconsistent CPU state.
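The value of the defensive error handling is easiest to see next to the variant it avoids. The sketch below is a simplified model, not the library's code: `Device`, `Store`, and `failing_load` are hypothetical stand-ins, and `failing_load` simulates a deserialization error. The first function uses `?` and returns before the device is restored. The second holds the result, restores the device, and only then returns.

```rust
// Hypothetical stand-ins for illustration; these are NOT the tch-rs types.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Device {
    Cpu,
    Mps,
}

struct Store {
    device: Device,
}

// Hypothetical loader that always fails, standing in for a failed load.
fn failing_load(_store: &mut Store) -> Result<(), String> {
    Err("corrupt weight file".to_string())
}

// Pitfall: `?` returns early on error, so the restore line is never reached.
fn load_early_exit(store: &mut Store) -> Result<(), String> {
    store.device = Device::Cpu;
    failing_load(store)?;
    store.device = Device::Mps;
    Ok(())
}

// Defensive variant: hold the result, restore the device, then return it.
fn load_restoring(store: &mut Store) -> Result<(), String> {
    store.device = Device::Cpu;
    let or_error = failing_load(store);
    store.device = Device::Mps;
    or_error
}
```

Both functions report the same error to the caller. They differ only in the state left behind: after `load_early_exit` the store is still on `Device::Cpu`, while after `load_restoring` it is back on `Device::Mps`.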
Code Evidence
MPS workaround from `src/nn/var_store.rs:235-249`:
    pub fn load<T: AsRef<std::path::Path>>(&mut self, path: T) -> Result<(), TchError> {
        if self.device != Device::Mps {
            self.load_internal(path)
        } else {
            // Current workaround to allow loading in MPS device.
            // On new libtorch releases check if direct loading becomes possible and revert
            // See (https://github.com/LaurentMazare/tch-rs/issues/609#issuecomment-1427071598).
            self.set_device(Device::Cpu);
            let or_error = self.load_internal(path);
            // Be cautious not to early exit so as to ensure that the device is set back to Mps
            // even on errors.
            self.set_device(Device::Mps);
            or_error
        }
    }