Implementation:LaurentMazare Tch rs Tensor Convert
| Knowledge Sources | |
|---|---|
| Domains | Tensor Operations, Type Conversion, Interoperability |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
The convert module implements TryFrom and TryInto conversion traits between Tensor and standard Rust types (scalars, Vec, nested Vec, and ndarray arrays).
Description
This module bridges the gap between PyTorch tensors and native Rust data structures through idiomatic trait implementations. It provides several categories of conversions:
- Tensor to Vec<T>: Converts a 1D tensor to a flat Vec<T> for any T implementing Element + Copy. Returns TchError::Convert if the tensor is not 1-dimensional. Internally, the tensor is cast to T's kind and its data is copied into the vector.
- Tensor to Vec<Vec<T>> and Vec<Vec<Vec<T>>>: Converts 2D and 3D tensors, respectively, to nested Vec structures. After validating the number of dimensions (TchError::Shape otherwise), it copies all elements into one flat buffer in a single bulk copy and then regroups them into rows by index.
- Tensor to scalar: A from_tensor! macro generates TryFrom<&Tensor> and TryFrom<Tensor> implementations for ten scalar types: f64, f32, f16, i64, i32, i16, i8, u8, bool, and bf16. Each requires the tensor to contain exactly one element (TchError::Convert otherwise) and moves it to the CPU before extracting the value.
- Tensor to ndarray::ArrayD<T>: Converts a tensor of any dimensionality to a dynamically-dimensioned ndarray, preserving the original shape. This direction is implemented as TryInto<ndarray::ArrayD<T>> for &Tensor, so it is called as (&tensor).try_into().
- ndarray to Tensor: Converts any ndarray::ArrayBase whose data as_slice can borrow as a single slice (which requires contiguous storage in standard, row-major order) into a tensor, then reshapes it to the ndarray's shape. Arrays for which as_slice fails, such as transposed views, return TchError::Convert.
- Vec<T> to Tensor: Converts Vec<T> and &Vec<T> to 1D tensors using Tensor::f_from_slice.
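The nested-Vec conversions depend on row-major (C-order) indexing: once all elements sit in one flat buffer, element (i1, i2) of a tensor with shape [s1, s2] is at offset i1 * s2 + i2, and element (i1, i2, i3) of a tensor with shape [s1, s2, s3] is at i1 * s2 * s3 + i2 * s3 + i3. The std-only sketch below models only that regrouping step on a plain slice. The helpers `regroup_2d` and `regroup_3d` are hypothetical, not tch functions, and the sketch assumes the flat buffer already holds the elements in row-major order.

```rust
/// Regroups a flat row-major buffer of `s1 * s2` elements into `s1` rows of
/// `s2` elements each. Hypothetical helper, not part of tch.
fn regroup_2d<T: Copy>(flat: &[T], s1: usize, s2: usize) -> Option<Vec<Vec<T>>> {
    // Reject buffers whose length does not match the requested shape.
    if flat.len() != s1 * s2 {
        return None;
    }
    // Element (i1, i2) lives at offset i1 * s2 + i2.
    Some(
        (0..s1)
            .map(|i1| (0..s2).map(|i2| flat[i1 * s2 + i2]).collect::<Vec<T>>())
            .collect(),
    )
}

/// Same idea for three dimensions. Hypothetical helper, not part of tch.
fn regroup_3d<T: Copy>(flat: &[T], s1: usize, s2: usize, s3: usize) -> Option<Vec<Vec<Vec<T>>>> {
    if flat.len() != s1 * s2 * s3 {
        return None;
    }
    // Element (i1, i2, i3) lives at offset i1 * s2 * s3 + i2 * s3 + i3.
    Some(
        (0..s1)
            .map(|i1| {
                (0..s2)
                    .map(|i2| {
                        (0..s3)
                            .map(|i3| flat[i1 * s2 * s3 + i2 * s3 + i3])
                            .collect::<Vec<T>>()
                    })
                    .collect::<Vec<Vec<T>>>()
            })
            .collect(),
    )
}

fn main() {
    let flat: &[i32] = &[1, 2, 3, 4, 5, 6];
    // Shape [2, 3]: two rows of three.
    assert_eq!(regroup_2d(flat, 2, 3), Some(vec![vec![1, 2, 3], vec![4, 5, 6]]));
    // Shape [3, 1, 2].
    assert_eq!(
        regroup_3d(flat, 3, 1, 2),
        Some(vec![vec![vec![1, 2]], vec![vec![3, 4]], vec![vec![5, 6]]])
    );
    // A shape whose element count differs from the buffer length is rejected.
    assert_eq!(regroup_2d(flat, 4, 2), None);
}
```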
Usage
Use these conversions when extracting tensor results into Rust-native types for post-processing, logging, or interfacing with non-tensor code. Also use them to create tensors from Rust collections or ndarray arrays.
Code Reference
Source Location
- Repository: LaurentMazare_Tch_rs
- File: src/tensor/convert.rs
Signature
```rust
// Tensor -> Vec<T> (1D)
impl<T: Element + Copy> TryFrom<&Tensor> for Vec<T> { ... }
impl<T: Element + Copy> TryFrom<Tensor> for Vec<T> { ... }
// Tensor -> Vec<Vec<T>> (2D)
impl<T: Element + Copy> TryFrom<&Tensor> for Vec<Vec<T>> { ... }
impl<T: Element + Copy> TryFrom<Tensor> for Vec<Vec<T>> { ... }
// Tensor -> Vec<Vec<Vec<T>>> (3D)
impl<T: Element + Copy> TryFrom<&Tensor> for Vec<Vec<Vec<T>>> { ... }
impl<T: Element + Copy> TryFrom<Tensor> for Vec<Vec<Vec<T>>> { ... }
// Tensor -> scalar: generated by from_tensor!, where $typ is one of
// f64, f32, f16, i64, i32, i16, i8, u8, bool, bf16
impl TryFrom<&Tensor> for $typ { ... }
impl TryFrom<Tensor> for $typ { ... }
// Tensor -> ndarray::ArrayD<T> (TryInto on &Tensor, not TryFrom)
impl<T: Element + Copy> TryInto<ndarray::ArrayD<T>> for &Tensor { ... }
// ndarray -> Tensor (trait bounds on T and D omitted in this summary)
impl<T, D> TryFrom<&ndarray::ArrayBase<T, D>> for Tensor { ... }
impl<T, D> TryFrom<ndarray::ArrayBase<T, D>> for Tensor { ... }
// Vec -> Tensor
impl<T: Element> TryFrom<&Vec<T>> for Tensor { ... }
impl<T: Element> TryFrom<Vec<T>> for Tensor { ... }
```
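The `$typ` placeholder in the scalar signatures is the parameter of the `from_tensor!` macro, which stamps out a borrowed and an owned `TryFrom` impl for each listed type. The std-only sketch below reproduces that pattern with a hypothetical `Buffer` type standing in for `Tensor` and a hypothetical `ConvertError` standing in for `TchError::Convert`. It keeps the exactly-one-element check but replaces the CPU transfer and kind cast with a plain `as` cast, and it omits `bool`, `f16`, and `bf16`, which that cast cannot target.

```rust
use std::convert::TryFrom;

/// Hypothetical stand-in for a tensor: a flat buffer of f64 values.
#[derive(Debug)]
struct Buffer(Vec<f64>);

/// Hypothetical stand-in for TchError::Convert.
#[derive(Debug, PartialEq)]
struct ConvertError(String);

/// Generates TryFrom<&Buffer> and TryFrom<Buffer> for each listed scalar type,
/// following the same shape as a from_tensor!-style macro.
macro_rules! from_buffer {
    ($($typ:ty),*) => {
        $(
            impl<'a> TryFrom<&'a Buffer> for $typ {
                type Error = ConvertError;
                fn try_from(b: &'a Buffer) -> Result<Self, Self::Error> {
                    // Only single-element buffers convert to a scalar.
                    if b.0.len() != 1 {
                        return Err(ConvertError(format!(
                            "expected exactly one element, got {}",
                            b.0.len()
                        )));
                    }
                    // The `as` cast stands in for the kind conversion.
                    Ok(b.0[0] as $typ)
                }
            }
            impl TryFrom<Buffer> for $typ {
                type Error = ConvertError;
                fn try_from(b: Buffer) -> Result<Self, Self::Error> {
                    // The owned impl delegates to the borrowed one.
                    <$typ>::try_from(&b)
                }
            }
        )*
    };
}

from_buffer!(f64, f32, i64, i32, u8);

fn main() {
    // A single-element buffer converts to any generated scalar type.
    let one = Buffer(vec![42.0]);
    assert_eq!(f32::try_from(&one).unwrap(), 42.0);
    assert_eq!(i64::try_from(&one).unwrap(), 42);
    // Anything else is a conversion error, mirroring the numel != 1 check.
    let many = Buffer(vec![1.0, 2.0]);
    assert!(i32::try_from(many).is_err());
}
```

Delegating the owned impl to the borrowed one keeps the element-count check in a single place.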
Import
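```rust
use tch::Tensor;
// TryFrom and TryInto are in the Rust 2021 prelude; these two imports are
// only required on the 2015 and 2018 editions.
use std::convert::TryFrom;
use std::convert::TryInto;
```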
I/O Contract
| Conversion | Input Requirement | Output | Error Condition |
|---|---|---|---|
| Tensor -> Vec<T> | 1D tensor | Vec<T> | Not 1D: TchError::Convert |
| Tensor -> Vec<Vec<T>> | 2D tensor | Vec<Vec<T>> | Not 2D: TchError::Shape |
| Tensor -> Vec<Vec<Vec<T>>> | 3D tensor | Vec<Vec<Vec<T>>> | Not 3D: TchError::Shape |
| Tensor -> scalar | Single-element tensor | Scalar value | numel != 1: TchError::Convert |
| Tensor -> ndarray::ArrayD<T> | Any dimensionality | ndarray::ArrayD<T> | Shape mismatch: TchError::NdArray |
| ndarray -> Tensor | Contiguous array (as_slice succeeds) | Tensor | Non-contiguous: TchError::Convert |
| Vec<T> -> Tensor | Any Vec<T> with T: Element | 1D Tensor | Errors propagated from Tensor::f_from_slice |
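The `as_slice` requirement in the ndarray row is stricter than having no gaps in memory: ndarray's `as_slice` returns `None` unless the data is both contiguous and in standard (row-major) order, so a transposed view of a packed array is still rejected. The std-only sketch below models that test on explicit shape and stride arrays (strides counted in elements). `is_standard_layout` here is a simplified, hypothetical helper written for illustration; ndarray has a method with this name, but this is not its source.

```rust
/// Returns true if an array with these `shape` and `strides` (in elements)
/// is packed in standard row-major order. Simplified model for illustration.
fn is_standard_layout(shape: &[usize], strides: &[isize]) -> bool {
    assert_eq!(shape.len(), strides.len());
    // An array with no elements trivially qualifies.
    if shape.iter().any(|&d| d == 0) {
        return true;
    }
    // Walk axes from last to first; the expected stride grows by each extent.
    let mut expected: isize = 1;
    for (&dim, &stride) in shape.iter().zip(strides).rev() {
        // Axes of length 1 are never stepped over, so their stride is irrelevant.
        if dim != 1 && stride != expected {
            return false;
        }
        expected *= dim as isize;
    }
    true
}

fn main() {
    // A 2x3 array in row-major order has strides [3, 1].
    assert!(is_standard_layout(&[2, 3], &[3, 1]));
    // Its transpose is a 3x2 view with strides [1, 3]: packed, but column-major.
    assert!(!is_standard_layout(&[3, 2], &[1, 3]));
}
```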
Usage Examples
```rust
use std::convert::{TryFrom, TryInto};
use tch::Tensor;

fn main() {
    // Vec<f32> from a 1D tensor
    let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0]);
    let v: Vec<f32> = Vec::try_from(&t).unwrap();
    assert_eq!(v, vec![1.0, 2.0, 3.0]);

    // Scalar extraction from a single-element tensor
    let scalar_t = Tensor::from(42.0f64);
    let val: f64 = f64::try_from(&scalar_t).unwrap();
    assert_eq!(val, 42.0);

    // Vec to Tensor: the result is 1D
    let v = vec![1i64, 2, 3, 4];
    let t = Tensor::try_from(v).unwrap();
    assert_eq!(t.size(), vec![4]);

    // ndarray round-trip (the array! macro needs ndarray as a direct dependency)
    let arr = ndarray::array![[1.0f32, 2.0], [3.0, 4.0]];
    let t = Tensor::try_from(&arr).unwrap();
    let arr2: ndarray::ArrayD<f32> = (&t).try_into().unwrap();
    assert_eq!(arr2.shape(), &[2, 2]);
}
```
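The I/O contract lists a shape mismatch (`TchError::NdArray`) as the failure mode for `Tensor -> ndarray::ArrayD<T>`, which fits a design where the tensor's data is copied into a flat buffer and then paired with its shape, as ndarray's `ArrayD::from_shape_vec` does. The std-only sketch below models that pairing and its length check. `DynArray`, `ShapeMismatch`, and their methods are hypothetical stand-ins, not ndarray or tch code.

```rust
/// Hypothetical dynamically-shaped array: a shape plus a flat row-major buffer.
#[derive(Debug, PartialEq)]
struct DynArray<T> {
    shape: Vec<usize>,
    data: Vec<T>,
}

/// Hypothetical error standing in for a shape mismatch.
#[derive(Debug, PartialEq)]
struct ShapeMismatch {
    expected: usize,
    got: usize,
}

impl<T> DynArray<T> {
    /// Attaches `shape` to `data`, failing if the element counts disagree.
    fn from_shape_vec(shape: Vec<usize>, data: Vec<T>) -> Result<Self, ShapeMismatch> {
        let expected: usize = shape.iter().product();
        if expected != data.len() {
            return Err(ShapeMismatch { expected, got: data.len() });
        }
        Ok(DynArray { shape, data })
    }

    /// Looks up a multi-dimensional index using row-major offsets.
    fn get(&self, index: &[usize]) -> Option<&T> {
        if index.len() != self.shape.len() || index.iter().zip(&self.shape).any(|(i, d)| i >= d) {
            return None;
        }
        // offset = (...((i0 * d1 + i1) * d2 + i2)...)
        let offset = index.iter().zip(&self.shape).fold(0, |acc, (&i, &d)| acc * d + i);
        self.data.get(offset)
    }
}

fn main() {
    // Shape [2, 2] with 4 elements is accepted and keeps its shape.
    let a = DynArray::from_shape_vec(vec![2, 2], vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
    assert_eq!(a.get(&[1, 0]), Some(&3.0));
    // A length mismatch is reported instead of silently truncating.
    let err = DynArray::from_shape_vec(vec![2, 3], vec![1.0f32; 4]).unwrap_err();
    assert_eq!(err, ShapeMismatch { expected: 6, got: 4 });
}
```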