Implementation:LaurentMazare Tch rs PyObject Tensor Bridge
| Knowledge Sources | |
|---|---|
| Domains | Python Interoperability, FFI, Tensor Operations |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
The python module provides FFI-backed functions that wrap Tensor objects as Python PyObject pointers and unwrap them again, enabling tensor exchange between Rust and Python. Two of the three functions are unsafe because they accept caller-supplied raw pointers.
Description
This module is conditionally compiled behind the python-extension feature flag and provides the bridge for passing tensors between Rust and Python when building Python extensions with tch-rs.
The module defines CPyObject as a type alias for torch_sys::python::C_pyobject, an opaque type standing in for a Python object. The functions below only ever handle it through *mut CPyObject raw pointers.
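The sketch below illustrates what "opaque" means here. It declares a zero-sized #[repr(C)] stand-in named C_pyobject. This follows a common Rust FFI idiom and is an assumption, not a copy of the torch_sys definition. The helper as_opaque is hypothetical. The only point is that Rust code carries the address around without ever reading through it.

```rust
use std::ffi::c_void;

/// Hypothetical stand-in for torch_sys::python::C_pyobject: a zero-sized,
/// #[repr(C)] struct that can only be used behind a raw pointer.
#[allow(non_camel_case_types)]
#[repr(C)]
pub struct C_pyobject {
    _private: [u8; 0],
}

/// Mirrors the alias described above (but over the local stand-in).
pub type CPyObject = C_pyobject;

/// Hypothetical helper: reinterpret an address received from C (for example a
/// CPython PyObject*) as an opaque pointer. Nothing is dereferenced, so this
/// cast is safe on its own.
pub fn as_opaque(addr: *mut c_void) -> *mut CPyObject {
    addr.cast::<CPyObject>()
}

fn main() {
    let mut backing = 0u64; // stands in for some foreign object
    let raw = &mut backing as *mut u64 as *mut c_void;
    let obj = as_opaque(raw);
    // Only the address travels; the opaque type itself has no size or fields.
    assert_eq!(obj as usize, raw as usize);
    assert_eq!(std::mem::size_of::<CPyObject>(), 0);
    println!("opaque pointer: {:p}", obj);
}
```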
Three functions are provided:
- pyobject_check is an unsafe free function that takes a raw *mut CPyObject pointer and returns Result<bool, TchError>, indicating whether the Python object is a wrapped PyTorch tensor. It calls through to torch_sys::python::thp_variable_check.
- Tensor::pyobject_wrap is a safe method on Tensor that wraps the tensor's internal C pointer in a Python object, returning Result<*mut CPyObject, TchError>. It calls torch_sys::python::thp_variable_wrap.
- Tensor::pyobject_unpack is an unsafe associated function that takes a *mut CPyObject and attempts to extract a Tensor from it. It first calls pyobject_check, propagating any error. If the object is not a tensor, it returns Ok(None). Otherwise it calls torch_sys::python::thp_variable_unpack, wraps the resulting pointer with Tensor::from_ptr, and returns Ok(Some(tensor)).
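The check-then-unpack flow described for pyobject_unpack can be modeled in plain Rust. In this sketch, MockObj, check, and unpack_raw are hypothetical stand-ins for the Python object and the two FFI calls. The raw pointers and the unsafety are deliberately left out so that only the control flow remains.

```rust
// Hypothetical model of the pyobject_unpack control flow; none of these
// names are tch-rs APIs.

#[derive(Debug, PartialEq)]
pub struct TchError(String);

/// Stand-in for "some Python object": either a tensor or anything else.
#[derive(Debug, PartialEq)]
pub enum MockObj {
    Tensor(Vec<f32>),
    Other,
}

#[derive(Debug, PartialEq)]
pub struct Tensor(Vec<f32>);

/// Plays the role of thp_variable_check.
fn check(obj: &MockObj) -> Result<bool, TchError> {
    Ok(matches!(obj, MockObj::Tensor(_)))
}

/// Plays the role of thp_variable_unpack.
fn unpack_raw(obj: &MockObj) -> Result<Vec<f32>, TchError> {
    match obj {
        MockObj::Tensor(data) => Ok(data.clone()),
        MockObj::Other => Err(TchError("not a tensor".into())),
    }
}

/// Same shape as the described pyobject_unpack: errors from the check
/// propagate with `?`, a non-tensor yields Ok(None), and only a tensor is
/// actually unpacked.
pub fn unpack(obj: &MockObj) -> Result<Option<Tensor>, TchError> {
    if !check(obj)? {
        return Ok(None);
    }
    let data = unpack_raw(obj)?;
    Ok(Some(Tensor(data)))
}

fn main() {
    assert_eq!(unpack(&MockObj::Other), Ok(None));
    assert_eq!(
        unpack(&MockObj::Tensor(vec![1.0, 2.0])),
        Ok(Some(Tensor(vec![1.0, 2.0])))
    );
    println!("ok");
}
```

Returning Ok(None) rather than an error for non-tensors lets a caller try several interpretations of an argument without treating "wrong type" as a failure.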
All three functions route their FFI calls through the unsafe_torch_err! macro. After each call, the macro checks whether the C++ side reported an exception and, if so, returns it as a TchError instead of the call's result.
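The error-checking pattern can be sketched as follows. Assume the C++ side records a caught exception in an error slot that Rust reads and clears after every call. LAST_ERR, check_last_err, checked_call!, and fake_ffi are all hypothetical, and the sketch does not reproduce the actual tch-rs or torch-sys internals.

```rust
use std::cell::RefCell;

thread_local! {
    // Hypothetical error slot that the C++ side would fill when it catches an exception.
    static LAST_ERR: RefCell<Option<String>> = RefCell::new(None);
}

#[derive(Debug, PartialEq)]
pub struct TchError(String);

/// Take (and clear) any recorded error, turning it into a TchError.
fn check_last_err() -> Result<(), TchError> {
    match LAST_ERR.with(|slot| slot.borrow_mut().take()) {
        Some(msg) => Err(TchError(msg)),
        None => Ok(()),
    }
}

/// Hypothetical analogue of unsafe_torch_err!: evaluate the call, then
/// propagate a recorded exception with `?` before yielding the value.
macro_rules! checked_call {
    ($call:expr) => {{
        let value = $call;
        check_last_err()?;
        value
    }};
}

/// Stand-in for an FFI function that "throws" by writing to the error slot.
fn fake_ffi(fail: bool) -> i32 {
    if fail {
        LAST_ERR.with(|slot| *slot.borrow_mut() = Some("c++ exception".to_string()));
        0
    } else {
        42
    }
}

fn call(fail: bool) -> Result<i32, TchError> {
    let v = checked_call!(fake_ffi(fail));
    Ok(v)
}

fn main() {
    assert_eq!(call(false), Ok(42));
    assert_eq!(call(true), Err(TchError("c++ exception".to_string())));
    // The slot was cleared by the failing call, so the next call succeeds.
    assert_eq!(call(false), Ok(42));
}
```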
Usage
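Use these functions when building Python extensions in Rust that need to accept or return PyTorch tensors. The caller must pass pointers to live, valid Python objects and must hold the Python GIL while these functions run, because they call into the CPython C API. A pointer returned by pyobject_wrap is a new reference that the caller owns.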
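One way to confine the caller's safety obligation to a single place is to encode it in a type. In this sketch, ValidPyObject is a hypothetical newtype and CPyObject is a local stand-in for tch::python::CPyObject. Constructing a ValidPyObject is the one unsafe step where the caller vouches for the pointer. Only null is actually rejected; everything else is taken on trust.

```rust
use std::marker::PhantomData;
use std::ptr::NonNull;

/// Local stand-in for tch::python::CPyObject (assumption for this sketch).
#[repr(C)]
pub struct CPyObject {
    _private: [u8; 0],
}

/// Hypothetical wrapper: holding one means someone already promised that the
/// pointer refers to a live Python object for the lifetime 'py.
#[derive(Clone, Copy)]
pub struct ValidPyObject<'py> {
    ptr: NonNull<CPyObject>,
    _scope: PhantomData<&'py CPyObject>,
}

impl<'py> ValidPyObject<'py> {
    /// # Safety
    /// `ptr` must point to a live Python object for all of 'py, and the GIL
    /// must be held whenever the pointer is handed to Python or C++ code.
    pub unsafe fn new(ptr: *mut CPyObject) -> Option<Self> {
        // Null is the only condition that can be checked here.
        NonNull::new(ptr).map(|ptr| ValidPyObject { ptr, _scope: PhantomData })
    }

    pub fn as_ptr(self) -> *mut CPyObject {
        self.ptr.as_ptr()
    }
}

fn main() {
    assert!(unsafe { ValidPyObject::new(std::ptr::null_mut()) }.is_none());
    let mut backing = 0u8; // stands in for a real Python object
    let raw = &mut backing as *mut u8 as *mut CPyObject;
    let obj = unsafe { ValidPyObject::new(raw) }.expect("non-null");
    assert_eq!(obj.as_ptr(), raw);
}
```

In a real extension, such a value would then be passed as, for example, unsafe { Tensor::pyobject_unpack(obj.as_ptr()) }, using the real tch::python::CPyObject instead of the local stand-in.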
Code Reference
Source Location
- Repository: LaurentMazare_Tch_rs
- File: src/wrappers/python.rs
Signature
```rust
// Signatures only; function bodies omitted.
pub type CPyObject = torch_sys::python::C_pyobject;

/// Check whether a Python object wraps a tensor.
/// # Safety
/// Undefined behavior if the pointer is not a valid PyObject.
pub unsafe fn pyobject_check(pyobject: *mut CPyObject) -> Result<bool, TchError>;

impl Tensor {
    /// Wrap a tensor in a Python object.
    pub fn pyobject_wrap(&self) -> Result<*mut CPyObject, TchError>;

    /// Unwrap a tensor stored in a Python object. Returns Ok(None) if not a tensor.
    /// # Safety
    /// Undefined behavior if the pointer is not a valid PyObject.
    pub unsafe fn pyobject_unpack(pyobject: *mut CPyObject) -> Result<Option<Self>, TchError>;
}
```
Import
```rust
use tch::python::{pyobject_check, CPyObject};
use tch::Tensor;
```
I/O Contract
| Function | Input | Output | Safety |
|---|---|---|---|
| pyobject_check | *mut CPyObject | Result<bool, TchError> | Unsafe: pointer must be a valid PyObject; GIL must be held |
| pyobject_wrap | &self (Tensor) | Result<*mut CPyObject, TchError> | Safe signature, but the GIL must still be held; returns a new reference |
| pyobject_unpack | *mut CPyObject | Result<Option<Tensor>, TchError> | Unsafe: pointer must be a valid PyObject; GIL must be held |

| Feature Gate | Required |
|---|---|
| python-extension | Yes; the module is not available without this feature |
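Because the module is feature-gated, a downstream crate has to opt in explicitly. The fragment below is a minimal sketch of the dependency entry for a Cargo-based project. The version string is a placeholder for whichever tch release the project already uses, not a recommendation.

```toml
[dependencies]
# Replace the placeholder with the tch version your project depends on.
tch = { version = "<your tch version>", features = ["python-extension"] }
```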
Usage Examples
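```rust
// Requires the python-extension feature.
use tch::python::{pyobject_check, CPyObject};
use tch::Tensor;

// Wrapping a tensor for return to Python. The GIL must be held; the returned
// pointer is a new reference that the caller takes ownership of.
fn tensor_to_python(t: &Tensor) -> *mut CPyObject {
    t.pyobject_wrap().expect("failed to wrap tensor")
}

// Unpacking a tensor received from Python: Some(tensor) if obj wraps a
// torch.Tensor, None otherwise.
// Safety: obj must be a valid PyObject pointer and the GIL must be held.
unsafe fn tensor_from_python(obj: *mut CPyObject) -> Option<Tensor> {
    unsafe { Tensor::pyobject_unpack(obj) }.expect("failed to unpack tensor")
}

// Checking if a Python object is a tensor; errors are treated as "not a tensor".
// Safety: same requirements as tensor_from_python.
unsafe fn is_tensor(obj: *mut CPyObject) -> bool {
    unsafe { pyobject_check(obj) }.unwrap_or(false)
}
```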