Implementation:LaurentMazare Tch rs Tensor Iter
| Knowledge Sources | |
|---|---|
| Domains | Tensor Operations, Iterator, Rust Idioms |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
The iter module provides an Iterator implementation for extracting scalar values from 1D tensors and Sum implementations for accumulating collections of tensors.
Description
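This module defines the generic Iter<T> struct, which iterates element-wise over a 1D tensor. The struct holds the current index, the length (obtained from size1()), a shallow clone of the source tensor, and a PhantomData<T> marker that records the element type without storing a value of it. Because the iterator owns its shallow clone, which shares storage with the source instead of copying it, an Iter<T> does not borrow the original tensor. The Tensor::iter<T> method constructs an Iter<T> and returns a Result; it fails with a TchError if the tensor is not 1-dimensional, because size1() rejects other shapes. T is unconstrained in this method, but only Iter<i64> and Iter<f64> implement Iterator, so those are the only element types that can actually be iterated.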
Two Iterator trait implementations are provided:
- Iter<i64> extracts elements using int64_value(&[index])
- Iter<f64> extracts elements using double_value(&[index])
Both advance through the tensor element by element, returning None once the index reaches the length.
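The struct-plus-marker design can be sketched without libtorch. Below is a minimal, std-only model, assuming a hypothetical ToyTensor type (an Rc-shared Vec<f64> plus a shape) that stands in for Tensor; ToyTensor, its size1/shallow_clone methods, and the String error type are illustrative assumptions, not tch's API. It mirrors the pattern described above: iter::<T>() fails for non-1D shapes, the iterator owns a storage-sharing clone, and PhantomData<T> lets Iter<i64> and Iter<f64> carry separate Iterator impls.

```rust
use std::marker::PhantomData;
use std::rc::Rc;

/// Hypothetical stand-in for `tch::Tensor`: shared storage plus a shape.
#[derive(Clone)]
struct ToyTensor {
    data: Rc<Vec<f64>>,
    shape: Vec<usize>,
}

impl ToyTensor {
    /// Builds a 1-D toy tensor from a slice.
    fn from_slice(values: &[f64]) -> Self {
        ToyTensor { data: Rc::new(values.to_vec()), shape: vec![values.len()] }
    }

    /// Plays the role of `size1()`: the length for 1-D shapes, an error otherwise.
    fn size1(&self) -> Result<i64, String> {
        match self.shape.as_slice() {
            [n] => Ok(*n as i64),
            other => Err(format!("expected a 1-D tensor, got shape {:?}", other)),
        }
    }

    /// Clones the Rc, so the copy shares storage with `self` (a "shallow" clone).
    fn shallow_clone(&self) -> Self {
        self.clone()
    }

    /// Mirrors `Tensor::iter<T>`: no bound on T; fails for non-1-D shapes.
    fn iter<T>(&self) -> Result<Iter<T>, String> {
        Ok(Iter { index: 0, len: self.size1()?, content: self.shallow_clone(), phantom: PhantomData })
    }
}

/// Same fields as the documented `Iter<T>`; `T` only selects the Iterator impl.
struct Iter<T> {
    index: i64,
    len: i64,
    content: ToyTensor,
    phantom: PhantomData<T>,
}

impl Iterator for Iter<f64> {
    type Item = f64;
    fn next(&mut self) -> Option<f64> {
        if self.index >= self.len {
            return None; // exhausted: stays None on every later call
        }
        let v = self.content.data[self.index as usize];
        self.index += 1;
        Some(v)
    }
}

impl Iterator for Iter<i64> {
    type Item = i64;
    fn next(&mut self) -> Option<i64> {
        if self.index >= self.len {
            return None;
        }
        // Toy conversion via `as` (truncates toward zero); libtorch's conversion may differ.
        let v = self.content.data[self.index as usize] as i64;
        self.index += 1;
        Some(v)
    }
}

fn main() {
    let t = ToyTensor::from_slice(&[10.0, 20.0, 30.0]);
    let ints: Vec<i64> = t.iter::<i64>().unwrap().collect();
    let floats: Vec<f64> = t.iter::<f64>().unwrap().collect();
    println!("{:?} {:?}", ints, floats);

    // A 1x3 shape is not 1-D, so iter() reports an error instead of an iterator.
    let matrix = ToyTensor { data: Rc::clone(&t.data), shape: vec![1, 3] };
    println!("{:?}", matrix.iter::<f64>().err());
}
```

The Rc plays the part of shared tensor storage: cloning it is cheap and does not copy the data, which is why the returned iterator needs no lifetime tied to `t`.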
The module also implements the std::iter::Sum trait for Tensor in two forms:
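- Sum for Tensor (owned): if the iterator is empty, returns Tensor::from(0.), a 0-dimensional f64 scalar, regardless of the shape or kind of the tensors being summed; otherwise it takes the first tensor as the initial accumulator and folds the remaining tensors into it by addition.
- Sum<&'a Tensor> for Tensor (borrowed): the same logic over references, except that the initial accumulator is a shallow_clone of the first tensor, so no data is copied up front.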
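Both Sum impls follow the same "seed with the first element, fold in the rest" shape. The sketch below reproduces it with a hypothetical ToyTensor wrapper around Vec<f64>, assuming equal-length element-wise addition in place of libtorch's broadcasting add; the one-element vec![0.0] returned for an empty iterator stands in for the 0-dim Tensor::from(0.). Because ToyTensor has no shared storage, the borrowed impl uses clone() where tch uses shallow_clone().

```rust
use std::iter::Sum;
use std::ops::Add;

/// Hypothetical stand-in for a tensor: a flat vector added element-wise.
#[derive(Clone, Debug, PartialEq)]
struct ToyTensor(Vec<f64>);

impl Add for ToyTensor {
    type Output = ToyTensor;
    fn add(self, rhs: ToyTensor) -> ToyTensor {
        // Toy rule: no broadcasting, lengths must match.
        assert_eq!(self.0.len(), rhs.0.len(), "toy addition needs equal lengths");
        ToyTensor(self.0.iter().zip(&rhs.0).map(|(a, b)| a + b).collect())
    }
}

/// `&ToyTensor + ToyTensor`, so the borrowed fold can add a reference to the accumulator.
impl<'a> Add<ToyTensor> for &'a ToyTensor {
    type Output = ToyTensor;
    fn add(self, rhs: ToyTensor) -> ToyTensor {
        self.clone() + rhs
    }
}

/// Owned form: the first element seeds the accumulator; empty gives a zero scalar.
impl Sum for ToyTensor {
    fn sum<I: Iterator<Item = ToyTensor>>(mut iter: I) -> ToyTensor {
        match iter.next() {
            None => ToyTensor(vec![0.0]), // stand-in for Tensor::from(0.)
            Some(first) => iter.fold(first, |acc, x| acc + x),
        }
    }
}

/// Borrowed form: same logic, but the seed must be cloned out of the reference.
impl<'a> Sum<&'a ToyTensor> for ToyTensor {
    fn sum<I: Iterator<Item = &'a ToyTensor>>(mut iter: I) -> ToyTensor {
        match iter.next() {
            None => ToyTensor(vec![0.0]),
            Some(first) => iter.fold(first.clone(), |acc, x| x + acc),
        }
    }
}

fn main() {
    let owned = vec![ToyTensor(vec![1.0, 2.0]), ToyTensor(vec![3.0, 4.0]), ToyTensor(vec![5.0, 6.0])];
    let by_ref: ToyTensor = owned.iter().sum(); // Item = &ToyTensor
    let by_value: ToyTensor = owned.into_iter().sum(); // Item = ToyTensor
    println!("{:?} {:?}", by_ref, by_value);

    let empty: ToyTensor = Vec::<ToyTensor>::new().into_iter().sum();
    println!("{:?}", empty);
}
```

Seeding with the first element rather than with zero means a non-empty sum never has to combine the inputs with a scalar of a different shape; only the empty case falls back to the scalar.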
Usage
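Call t.iter::<i64>() or t.iter::<f64>() to iterate over the elements of a 1D tensor, for example in a for loop or with collect(). The turbofish selects the element type, and the call returns a Result that must be handled (with ? or unwrap()) before iterating. To add up tensors, call sum() on an iterator of owned tensors, such as vec.into_iter().sum(), or of references, such as slice.iter().sum(), and annotate the result as Tensor; the iterator's item type (Tensor or &Tensor) determines which Sum implementation is used.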
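Which Sum implementation runs is decided by the iterator's item type, which becomes visible in generic code. The two helpers below, total_borrowed and total_owned, are hypothetical names; they are instantiated with f64 (for which the standard library implements both Sum<f64> and Sum<&f64>) so the sketch runs without libtorch. Since tch provides Sum and Sum<&'a Tensor> for Tensor, the same bounds should also be satisfied by Tensor.

```rust
use std::iter::Sum;

/// Hypothetical helper: sums a borrowed slice without consuming it.
/// `items.iter()` yields `&T`, so the bound is `T: Sum<&'a T>`.
fn total_borrowed<'a, T: Sum<&'a T> + 'a>(items: &'a [T]) -> T {
    items.iter().sum()
}

/// Hypothetical helper: consumes the Vec; `into_iter()` yields `T`, so the bound is `T: Sum<T>`.
fn total_owned<T: Sum<T>>(items: Vec<T>) -> T {
    items.into_iter().sum()
}

fn main() {
    let xs = vec![1.0f64, 2.0, 3.0];
    let a = total_borrowed(&xs); // borrowed form: xs is still usable afterwards
    let b = total_owned(xs); // owned form: xs is moved here
    println!("{} {}", a, b);
}
```

The borrowed form is the natural choice when the tensors are still needed after summing; the owned form lets the summation take the inputs by value.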
Code Reference
Source Location
- Repository: LaurentMazare_Tch_rs
- File: src/tensor/iter.rs
Signature
pub struct Iter<T> {
    index: i64,
    len: i64,
    content: Tensor,
    phantom: std::marker::PhantomData<T>,
}
impl Tensor {
    pub fn iter<T>(&self) -> Result<Iter<T>, TchError>;
}
impl std::iter::Iterator for Iter<i64> {
    type Item = i64;
    fn next(&mut self) -> Option<Self::Item>;
}
impl std::iter::Iterator for Iter<f64> {
    type Item = f64;
    fn next(&mut self) -> Option<Self::Item>;
}
impl std::iter::Sum for Tensor {
    fn sum<I: Iterator<Item = Tensor>>(iter: I) -> Tensor;
}
impl<'a> std::iter::Sum<&'a Tensor> for Tensor {
    fn sum<I: Iterator<Item = &'a Tensor>>(iter: I) -> Tensor;
}
Import
use tch::Tensor;
I/O Contract
| Method / Trait | Input | Output | Error / Edge Case |
|---|---|---|---|
| Tensor::iter::<i64>() | &self, expected to be 1D | Result<Iter<i64>, TchError> | Err(TchError) if the tensor is not 1D |
| Tensor::iter::<f64>() | &self, expected to be 1D | Result<Iter<f64>, TchError> | Err(TchError) if the tensor is not 1D |
| Iter<i64>::next() | &mut self | Option<i64> | None once all elements have been read |
| Iter<f64>::next() | &mut self | Option<f64> | None once all elements have been read |
| Sum for Tensor | Iterator<Item = Tensor> | Tensor (sum of the inputs) | Empty iterator returns Tensor::from(0.), a 0-dim f64 scalar |
| Sum<&Tensor> for Tensor | Iterator<Item = &Tensor> | Tensor (sum of the inputs) | Empty iterator returns Tensor::from(0.), a 0-dim f64 scalar |
Usage Examples
use tch::Tensor;
fn main() {
    // Iterate over elements of a 1D tensor
    let t = Tensor::from_slice(&[10i64, 20, 30, 40]);
    let values: Vec<i64> = t.iter::<i64>().unwrap().collect();
    assert_eq!(values, vec![10, 20, 30, 40]);
    // Iterate as f64
    let t = Tensor::from_slice(&[1.5f64, 2.5, 3.5]);
    for val in t.iter::<f64>().unwrap() {
        println!("{}", val);
    }
    // Sum a collection of owned tensors (uses Sum for Tensor)
    let tensors = vec![
        Tensor::from_slice(&[1.0f32, 2.0]),
        Tensor::from_slice(&[3.0f32, 4.0]),
        Tensor::from_slice(&[5.0f32, 6.0]),
    ];
    let total: Tensor = tensors.into_iter().sum();
    // total contains [9.0, 12.0]
    // Sum borrowed tensors (uses Sum<&Tensor> for Tensor)
    let t1 = Tensor::from(1.0);
    let t2 = Tensor::from(2.0);
    let refs = vec![&t1, &t2];
    let total: Tensor = refs.into_iter().sum();
    // total is a 0-dim tensor holding 3.0
}