Heuristic:LaurentMazare Tch rs MPS Weight Loading Workaround

From Leeroopedia





Knowledge Sources
Domains: Debugging, Infrastructure
Last Updated: 2026-02-08 13:00 GMT

Overview

When loading model weights on Apple MPS devices, temporarily switch the VarStore to CPU for the load operation, then switch back to MPS, to avoid libtorch MPS loading limitations.

Description

A known limitation in libtorch prevents direct weight loading onto MPS (Metal Performance Shaders) devices. The tch-rs library works around this automatically in `VarStore::load()`: when the target device is MPS, the VarStore is temporarily moved to CPU, the weights are loaded, and the VarStore is then moved back to MPS. The implementation is defensive: it holds the load result instead of returning early, so the device is restored to MPS even if the load fails. The pattern matters to anyone running tch-rs models on an Apple GPU through MPS.

Usage

This heuristic is automatically applied by `VarStore::load()` when the VarStore device is `Device::Mps`. Users do not need to manually implement this workaround. However, understanding it is important for debugging weight loading issues on Apple Silicon and for custom loading code that bypasses `VarStore::load()`.

The Insight (Rule of Thumb)

  • Action: When loading weights on MPS, call `vs.set_device(Device::Cpu)` before loading, then `vs.set_device(Device::Mps)` after.
  • Value: N/A (pattern-based, not a numeric parameter).
  • Trade-off: Adds a CPU-to-MPS device transfer after loading, which increases loading time but ensures correctness.
  • Compatibility: Only MPS (Apple Silicon) is affected. CPU and CUDA loading work directly.
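
For custom loading code that bypasses `VarStore::load()`, the action above can be wrapped in a small helper. The sketch below is a minimal model and does not depend on tch-rs: `Device` and `Store` are simplified, hypothetical stand-ins for tch's `Device` and `VarStore`, and `load_via_cpu` is an illustrative name, not a library function. With the real crate, the same shape would presumably apply to a `VarStore` and its `set_device` method.

```rust
// Hypothetical stand-ins for tch's `Device` and `VarStore` (assumption:
// simplified to the two devices relevant here; not the real tch-rs types).
#[derive(Debug, Clone, Copy, PartialEq)]
enum Device {
    Cpu,
    Mps,
}

struct Store {
    device: Device,
}

impl Store {
    fn set_device(&mut self, device: Device) {
        self.device = device;
    }
}

/// Runs `loader` with the store on CPU when the store targets MPS,
/// then restores MPS whether or not the loader succeeded.
fn load_via_cpu<E>(
    store: &mut Store,
    loader: impl FnOnce(&mut Store) -> Result<(), E>,
) -> Result<(), E> {
    if store.device != Device::Mps {
        // Non-MPS targets need no detour.
        return loader(store);
    }
    store.set_device(Device::Cpu);
    // Keep the result instead of using `?`, so the restore below always runs.
    let result = loader(store);
    store.set_device(Device::Mps);
    result
}
```

The loader is passed in as a closure so the same wrapper can serve any custom deserialization routine. Note that the helper restores the device on an `Err` return only; a panic inside the loader would still skip the restore.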

Reasoning

libtorch's MPS backend does not support all tensor deserialization paths, so the underlying C++ library may fail or produce incorrect results when loading tensors directly onto the MPS device. Routing the load through CPU uses the standard, well-tested CPU deserialization path; the tensors are then transferred to MPS through the usual device migration API. The defensive error handling (restoring MPS even when the load fails) prevents the VarStore from being left on CPU while the caller still expects it to be on MPS.
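
To illustrate why the early exit matters, the sketch below contrasts a version that uses `?` with one that holds the result until the device is restored. It is a toy model under stated assumptions: `Device`, `Store`, and `failing_load` are hypothetical stand-ins, not tch-rs APIs, and the loader always fails so the difference is visible.

```rust
// Hypothetical, simplified stand-ins; not the real tch-rs types.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Device {
    Cpu,
    Mps,
}

struct Store {
    device: Device,
}

// Stand-in for a weight load that fails (for example, an unreadable file).
fn failing_load(_store: &mut Store) -> Result<(), String> {
    Err("could not read weights".to_string())
}

// Early exit: `?` returns before the device is restored.
fn load_with_early_exit(store: &mut Store) -> Result<(), String> {
    store.device = Device::Cpu;
    failing_load(store)?;
    store.device = Device::Mps; // skipped when the load fails
    Ok(())
}

// Defensive: hold the result, restore the device, then return the result.
fn load_defensively(store: &mut Store) -> Result<(), String> {
    store.device = Device::Cpu;
    let result = failing_load(store);
    store.device = Device::Mps;
    result
}
```

After a failed load, the early-exit version leaves the store on `Device::Cpu`, while the defensive version reports the same error with the store back on `Device::Mps`.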

Code Evidence

MPS workaround from `src/nn/var_store.rs:235-249`:

pub fn load<T: AsRef<std::path::Path>>(&mut self, path: T) -> Result<(), TchError> {
    if self.device != Device::Mps {
        self.load_internal(path)
    } else {
        // Current workaround to allow loading in MPS device.
        // On new libtorch releases check if direct loading becomes possible and revert
        // See (https://github.com/LaurentMazare/tch-rs/issues/609#issuecomment-1427071598).
        self.set_device(Device::Cpu);
        let or_error = self.load_internal(path);
        // Be cautious not to early exit so as to ensure that the device is set back to Mps
        // even on errors.
        self.set_device(Device::Mps);
        or_error
    }
}
