Implementation:LaurentMazare Tch rs PyTensor
| Knowledge Sources | |
|---|---|
| Domains | FFI, Python Interop, Tensor |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
PyTensor is a PyO3 wrapper struct that enables seamless conversion of PyTorch tensors between Python and Rust through the FromPyObject and IntoPyObject trait implementations.
Description
The PyTensor type is a newtype wrapper around tch::Tensor defined in the pyo3-tch crate. It serves as the bridge layer that allows Rust code using tch-rs to receive torch.Tensor objects from Python and return them back. The struct implements Deref targeting tch::Tensor, so it can be used transparently wherever a tensor reference is expected.
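The newtype-plus-Deref arrangement can be illustrated without PyO3 or libtorch. The std-only sketch below models the pattern under stated assumptions and is not the crate's code: FakeTensor is a hypothetical stand-in for tch::Tensor, Wrapper plays the role of PyTensor, and numel is a hypothetical helper that takes the inner type by reference.

```rust
use std::ops::Deref;

/// Hypothetical stand-in for `tch::Tensor`; only a shape is modelled.
pub struct FakeTensor {
    shape: Vec<i64>,
}

impl FakeTensor {
    pub fn new(shape: Vec<i64>) -> Self {
        FakeTensor { shape }
    }

    /// Returns the shape as `Vec<i64>`, like `tch::Tensor::size()`.
    pub fn size(&self) -> Vec<i64> {
        self.shape.clone()
    }
}

/// Newtype playing the role of `PyTensor`.
pub struct Wrapper(pub FakeTensor);

impl Deref for Wrapper {
    type Target = FakeTensor;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

/// Hypothetical helper taking the inner type; `&Wrapper` coerces to it.
pub fn numel(t: &FakeTensor) -> i64 {
    t.size().iter().product()
}
```

Calling size() on a Wrapper resolves through Deref to FakeTensor::size, and passing &Wrapper to numel relies on deref coercion, which is the same mechanism that lets a PyTensor be used wherever a &tch::Tensor is expected.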
The conversion from Python to Rust (FromPyObject) works by extracting the raw CPython pointer from the Python object and passing it to tch::Tensor::pyobject_unpack, which reconstructs the Rust tensor. If the object is not a torch.Tensor, a TypeError naming the object's actual type is raised; if unpacking itself fails, the tch::TchError is converted into a ValueError via wrap_tch_err.
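The extraction path boils down to turning a Result<Option<T>, E> into one of two Python exceptions. The std-only sketch below models that control flow under stated assumptions: FakePyObject, UnpackError, PyErrSketch, Wrapper and unpack are all hypothetical stand-ins, and the contract that unpacking yields Ok(None) for a non-tensor object and Err for a low-level failure is assumed here so that it matches the TypeError/ValueError split listed in the I/O contract.

```rust
/// Hypothetical stand-in for `tch::TchError`.
#[derive(Debug)]
pub struct UnpackError(pub String);

/// Hypothetical stand-in for the two Python exceptions PyO3 would raise.
#[derive(Debug, PartialEq)]
pub enum PyErrSketch {
    TypeError(String),
    ValueError(String),
}

/// Hypothetical stand-in for a Python object handed to the extractor.
pub struct FakePyObject {
    pub type_name: String,
    pub tensor: Option<Vec<f32>>, // Some(..) only if the object is a tensor
    pub corrupt: bool,            // simulates a low-level unpack failure
}

/// Newtype playing the role of `PyTensor`.
pub struct Wrapper(pub Vec<f32>);

/// Assumed unpack contract: Err on failure, Ok(None) for a non-tensor,
/// Ok(Some(..)) on success.
fn unpack(ob: &FakePyObject) -> Result<Option<Vec<f32>>, UnpackError> {
    if ob.corrupt {
        return Err(UnpackError("unpack failed".to_string()));
    }
    Ok(ob.tensor.clone())
}

/// Maps the two failure modes onto ValueError and TypeError respectively.
pub fn extract(ob: &FakePyObject) -> Result<Wrapper, PyErrSketch> {
    unpack(ob)
        // Low-level failure: Debug-format it into a ValueError.
        .map_err(|e| PyErrSketch::ValueError(format!("{e:?}")))?
        // Not a tensor: raise a TypeError naming the actual type.
        .ok_or_else(|| {
            PyErrSketch::TypeError(format!("expected a torch.Tensor, got {}", ob.type_name))
        })
        .map(Wrapper)
}
```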
The conversion from Rust to Python (IntoPyObject) calls pyobject_wrap on the inner tensor to produce a raw Python object pointer, then takes ownership of it as a PyO3-managed Bound<'py, PyAny>. If wrapping fails, the error is discarded and Python None is returned instead of raising an exception.
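The None fallback can be sketched with Result::map_or_else. In this std-only model, FakePyValue and wrap are hypothetical stand-ins for the returned Python value and for pyobject_wrap, and the rule that an empty payload fails is arbitrary, chosen only to exercise the failure branch.

```rust
/// Hypothetical stand-in for the value handed back to Python.
#[derive(Debug, PartialEq)]
pub enum FakePyValue {
    Tensor(Vec<f32>),
    PyNone,
}

/// Hypothetical stand-in for `pyobject_wrap`; failing on an empty payload is
/// an arbitrary rule used only to exercise the error branch.
fn wrap(data: Vec<f32>) -> Result<Vec<f32>, String> {
    if data.is_empty() {
        Err("wrap failed".to_string())
    } else {
        Ok(data)
    }
}

/// Never returns Err: a wrap failure is swallowed and replaced by None.
pub fn into_py(data: Vec<f32>) -> Result<FakePyValue, String> {
    wrap(data).map_or_else(|_err| Ok(FakePyValue::PyNone), |d| Ok(FakePyValue::Tensor(d)))
}
```

Because the error value is discarded, a Python caller receives None rather than an exception when wrapping fails, so the failure is only visible if the caller checks the returned value.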
A helper function wrap_tch_err converts tch::TchError into a Python ValueError, providing error propagation across the FFI boundary.
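Error propagation with wrap_tch_err follows the usual map_err plus ? pattern. The sketch below is std-only and hypothetical: FakeTchError stands in for tch::TchError, ValueErrorSketch for a Python ValueError, and parse and load for a fallible tch call such as loading a tensor from disk; wrap_err mirrors the documented behaviour of building the message from the error's Debug output.

```rust
/// Hypothetical stand-in for `tch::TchError`.
#[derive(Debug)]
pub enum FakeTchError {
    FileFormat(String),
}

/// Hypothetical stand-in for a Python `ValueError` carrying its message.
#[derive(Debug, PartialEq)]
pub struct ValueErrorSketch(pub String);

/// Mirrors the documented behaviour: Debug-format any error into a ValueError.
pub fn wrap_err<E: std::fmt::Debug>(err: E) -> ValueErrorSketch {
    ValueErrorSketch(format!("{err:?}"))
}

/// Hypothetical fallible call; accepts only paths ending in ".pt".
fn parse(path: &str) -> Result<Vec<f32>, FakeTchError> {
    if path.ends_with(".pt") {
        Ok(vec![0.0])
    } else {
        Err(FakeTchError::FileFormat(path.to_string()))
    }
}

/// `map_err` converts the error type, `?` propagates it to the caller.
pub fn load(path: &str) -> Result<Vec<f32>, ValueErrorSketch> {
    let data = parse(path).map_err(wrap_err)?;
    Ok(data)
}
```

In the sketch, using Debug rather than Display means the message carries the variant name (FileFormat("weights.txt") for a bad path), which helps when debugging at the cost of a less polished message.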
Usage
Use PyTensor when writing PyO3 #[pyfunction] or #[pymethods] that need to accept or return PyTorch tensors. Declare function parameters as PyTensor to automatically extract incoming Python tensors, and return PyTensor to automatically convert Rust tensors back to Python.
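Declaring a parameter as PyTensor is enough because the binding layer dispatches on the declared type through a conversion trait. The std-only sketch below is a conceptual model of that dispatch, not PyO3's generated code: Dyn, Extract, IntoDyn, TensorArg, call1 and double are all hypothetical names standing in for a Python object, FromPyObject, IntoPyObject, PyTensor, a macro-generated wrapper and a user #[pyfunction].

```rust
/// Hypothetical dynamic value standing in for a Python object.
#[derive(Debug, Clone, PartialEq)]
pub enum Dyn {
    Tensor(Vec<f32>),
    Int(i64),
}

/// Hypothetical analogue of `FromPyObject`.
pub trait Extract: Sized {
    fn extract(v: &Dyn) -> Result<Self, String>;
}

/// Hypothetical analogue of `IntoPyObject`.
pub trait IntoDyn {
    fn into_dyn(self) -> Dyn;
}

/// Newtype playing the role of `PyTensor`.
pub struct TensorArg(pub Vec<f32>);

impl Extract for TensorArg {
    fn extract(v: &Dyn) -> Result<Self, String> {
        match v {
            Dyn::Tensor(t) => Ok(TensorArg(t.clone())),
            other => Err(format!("expected a tensor, got {other:?}")),
        }
    }
}

impl IntoDyn for TensorArg {
    fn into_dyn(self) -> Dyn {
        Dyn::Tensor(self.0)
    }
}

/// Conceptual model of a generated wrapper: extract the argument by its
/// declared type, call the user function, convert the result back.
pub fn call1<A, R, F>(f: F, arg: &Dyn) -> Result<Dyn, String>
where
    A: Extract,
    R: IntoDyn,
    F: Fn(A) -> R,
{
    let a = A::extract(arg)?;
    Ok(f(a).into_dyn())
}

/// A "user function" that only ever sees the typed wrapper.
pub fn double(t: TensorArg) -> TensorArg {
    TensorArg(t.0.iter().map(|x| x * 2.0).collect())
}
```

call1 never names TensorArg; the A: Extract bound selects the conversion at compile time, which is why changing a parameter's declared type is all it takes to change how the incoming argument is extracted.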
Code Reference
Source Location
- Repository: LaurentMazare_Tch_rs
- File: pyo3-tch/src/lib.rs
Signature
```rust
pub struct PyTensor(pub tch::Tensor);

impl std::ops::Deref for PyTensor {
    type Target = tch::Tensor;
    fn deref(&self) -> &Self::Target;
}

pub fn wrap_tch_err(err: tch::TchError) -> PyErr;

impl<'source> FromPyObject<'source> for PyTensor {
    fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self>;
}

impl<'py> IntoPyObject<'py> for PyTensor {
    type Output = Bound<'py, Self::Target>;
    type Target = PyAny;
    type Error = PyErr;
    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error>;
}
```
Import
```rust
use pyo3_tch::PyTensor;
use pyo3_tch::wrap_tch_err;
```
I/O Contract
FromPyObject (Python to Rust)
| Input | Type | Description |
|---|---|---|
| ob | &Bound<'source, PyAny> | A Python object expected to be a torch.Tensor |

| Output | Type | Description |
|---|---|---|
| Success | PyTensor | Wraps the extracted tch::Tensor |
| Error (wrong type) | PyTypeError | "expected a torch.Tensor, got {type_}" |
| Error (unpack failure) | PyValueError | Forwarded from tch::TchError via wrap_tch_err |
IntoPyObject (Rust to Python)
| Input | Type | Description |
|---|---|---|
| self | PyTensor | The Rust tensor wrapper to convert |
| py | Python<'py> | The Python GIL token |

| Output | Type | Description |
|---|---|---|
| Success | Bound<'py, PyAny> | A Python torch.Tensor object |
| Wrap failure | Bound<'py, PyAny> | Falls back to py.None() |
wrap_tch_err
| Input | Type | Description |
|---|---|---|
| err | tch::TchError | A tch-rs error |

| Output | Type | Description |
|---|---|---|
| Return | PyErr | A PyValueError with the debug-formatted error message |
Usage Examples
```rust
use pyo3::prelude::*;
use pyo3_tch::{PyTensor, wrap_tch_err};

// Accept a Python torch.Tensor and return a new one
#[pyfunction]
fn double_tensor(tensor: PyTensor) -> PyResult<PyTensor> {
    // &*tensor derefs the PyTensor to &tch::Tensor; the product is a new tensor
    let result = &*tensor * 2.0;
    Ok(PyTensor(result))
}

// Use wrap_tch_err to convert tch errors to Python exceptions
#[pyfunction]
fn load_tensor(path: &str) -> PyResult<PyTensor> {
    let tensor = tch::Tensor::load(path).map_err(wrap_tch_err)?;
    Ok(PyTensor(tensor))
}

// Deref allows transparent access to tch::Tensor methods
#[pyfunction]
fn tensor_shape(tensor: PyTensor) -> Vec<i64> {
    tensor.size() // calls tch::Tensor::size() via Deref
}
```