

Implementation:LaurentMazare Tch rs Torch Sys Lib

From Leeroopedia


Knowledge Sources
Domains FFI, Rust, Deep Learning
Last Updated 2026-02-08 00:00 GMT

Overview

Root module of the torch-sys crate defining all hand-written Rust FFI bindings to libtorch, including opaque type definitions, tensor operations, optimizer functions, JIT module bindings, IValue handling, and scalar operations.

Description

lib.rs is the 336-line root source file of the torch-sys crate. It serves as the foundational FFI layer between Rust and the PyTorch C++ library (libtorch). Unlike c_generated.rs, which is auto-generated, this file is hand-written: it defines the core opaque types and the bindings that are maintained by hand rather than generated.

The file is organized into these sections:

Submodule declarations:

  • pub mod cuda; - CUDA-specific bindings
  • pub mod io; - I/O and streaming utilities
  • #[cfg(feature = "python-extension")] pub mod python; - Optional Python extension bindings
  • mod traits; - Helper traits for list conversions
  • pub mod c_generated; - Public module containing the auto-generated bindings

Opaque C types:

  • C_scalar - Opaque FFI type representing torch::Scalar* on the C++ side, defined as #[repr(C)] pub struct C_scalar { _private: [u8; 0] }
  • C_tensor - Opaque FFI type representing torch::Tensor* on the C++ side
  • C_optimizer - Opaque FFI type representing torch::optim::Optimizer*
  • CIValue - Opaque FFI type representing torch::jit::IValue*
  • CModule_ - Opaque FFI type representing torch::jit::script::Module*

Extern "C" blocks organized by functionality:

  • Scalar operations: ats_int, ats_float, ats_to_int, ats_to_float, ats_to_string, ats_free
  • Tensor lifecycle: at_new_tensor, at_tensor_of_data, at_tensor_of_blob, at_shallow_clone, at_free
  • Tensor properties: at_defined, at_is_sparse, at_is_mkldnn, at_is_contiguous, at_dim, at_shape, at_stride, at_scalar_type, at_device
  • Autograd: at_backward, at_requires_grad, at_grad_set_enabled, at_run_backward
  • Serialization: at_save, at_load, at_save_multi, at_load_callback, at_loadz_callback, stream variants
  • Autocast: at_autocast_is_enabled, at_autocast_set_enabled, at_autocast_increment_nesting
  • Threading: at_get_num_threads, at_set_num_threads, at_get_num_interop_threads
  • Context: at_context_has_cuda, at_context_has_mkl, at_context_has_mps, at_context_version_cudnn, etc.
  • Error handling: get_and_reset_last_err
  • Optimizers: ato_adam, ato_adamw, ato_sgd, ato_rms_prop, ato_add_parameters, ato_step, ato_zero_grad, ato_free
  • Image I/O: at_save_image, at_load_image, at_load_image_from_memory, at_resize_image
  • IValue constructors: ati_none, ati_bool, ati_int, ati_double, ati_tensor, ati_string, ati_tuple, ati_generic_list, ati_generic_dict
  • IValue getters: ati_to_int, ati_to_double, ati_to_tensor, ati_to_string, ati_to_tuple, ati_tag, ati_length
  • JIT Modules: atm_load, atm_load_on_device, atm_load_str, atm_forward, atm_forward_, atm_method, atm_eval, atm_train, atm_save, atm_named_parameters, atm_create_for_tracing, atm_end_tracing

Usage

This crate is a dependency of the higher-level tch crate. The types and functions defined here are used by tch to build safe Rust abstractions over PyTorch. End users of tch-rs typically never interact with torch-sys directly. Developers extending the FFI layer add new extern declarations here and corresponding C++ implementations in torch_api.cpp.
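As an illustration of the kind of safe abstraction a higher-level crate can build on these raw handles, here is a minimal sketch of an owning wrapper that frees its handle exactly once, on drop. It is not tch's actual Tensor implementation: `Owned`, `mock_new`, and `mock_free` are hypothetical names, and the two mocks stand in for at_new_tensor and at_free so the snippet runs without linking libtorch.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

// Opaque handle, declared the same way torch-sys declares its opaque types.
#[allow(non_camel_case_types)]
#[repr(C)]
pub struct C_tensor {
    _private: [u8; 0],
}

// Hypothetical stand-ins for at_new_tensor / at_free so the sketch runs
// without libtorch. LIVE counts handles that have not been freed yet.
static LIVE: AtomicUsize = AtomicUsize::new(0);

fn mock_new() -> *mut C_tensor {
    LIVE.fetch_add(1, Ordering::SeqCst);
    Box::into_raw(Box::new(0u8)) as *mut C_tensor
}

unsafe fn mock_free(p: *mut C_tensor) {
    LIVE.fetch_sub(1, Ordering::SeqCst);
    // SAFETY: `p` came from mock_new and is freed at most once.
    unsafe { drop(Box::from_raw(p as *mut u8)) };
}

/// Hypothetical owning wrapper: holds a non-null handle and frees it exactly once.
pub struct Owned {
    ptr: *mut C_tensor,
}

impl Owned {
    /// Takes ownership of `ptr`; a null handle (the FFI error case) yields None.
    ///
    /// # Safety
    /// `ptr` must be null or a live handle that nothing else will free.
    pub unsafe fn from_raw(ptr: *mut C_tensor) -> Option<Self> {
        if ptr.is_null() { None } else { Some(Owned { ptr }) }
    }

    /// Borrows the raw handle for passing to other FFI calls.
    pub fn as_ptr(&self) -> *mut C_tensor {
        self.ptr
    }
}

impl Drop for Owned {
    fn drop(&mut self) {
        // SAFETY: `self.ptr` is non-null and uniquely owned by this wrapper.
        unsafe { mock_free(self.ptr) }
    }
}

fn main() {
    {
        let t = unsafe { Owned::from_raw(mock_new()) }.expect("null handle");
        println!("live handles: {}", LIVE.load(Ordering::SeqCst));
        let _raw = t.as_ptr(); // the pointer that would be passed to a binding
    } // `t` is dropped here and its handle freed
    println!("live handles after drop: {}", LIVE.load(Ordering::SeqCst));
}
```

Tying the free call to Drop means callers of the wrapper never call the raw free function themselves, which is what rules out double frees and most leaks on the safe side of the boundary.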

Code Reference

Source Location

Signature

use libc::{c_char, c_int, c_void, size_t};

// Opaque C types
#[repr(C)]
pub struct C_scalar {
    _private: [u8; 0],
}

#[repr(C)]
pub struct C_tensor {
    _private: [u8; 0],
}

#[repr(C)]
pub struct C_optimizer {
    _private: [u8; 0],
}

#[repr(C)]
pub struct CIValue {
    _private: [u8; 0],
}

#[repr(C)]
pub struct CModule_ {
    _private: [u8; 0],
}

// Representative extern declarations
extern "C" {
    pub fn at_new_tensor() -> *mut C_tensor;
    pub fn at_tensor_of_data(
        vs: *const c_void, dims: *const i64,
        ndims: size_t, elt_size_in_bytes: size_t, kind: c_int,
    ) -> *mut C_tensor;
    pub fn at_backward(arg: *mut C_tensor, keep_graph: c_int, create_graph: c_int);
    pub fn at_save(arg: *mut C_tensor, filename: *const c_char);
    pub fn at_load(filename: *const c_char) -> *mut C_tensor;
    pub fn at_free(arg: *mut C_tensor);
    pub fn get_and_reset_last_err() -> *mut c_char;
    pub fn ato_adam(lr: f64, beta1: f64, beta2: f64, wd: f64,
                   eps: f64, amsgrad: bool) -> *mut C_optimizer;
    pub fn atm_load(filename: *const c_char) -> *mut CModule_;
    pub fn atm_forward(m: *mut CModule_, args: *const *mut C_tensor,
                      n: c_int) -> *mut C_tensor;
    pub fn ati_tensor(v: *mut C_tensor) -> *mut CIValue;
    pub fn ati_tag(arg: *mut CIValue) -> c_int;
}
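atm_forward receives its inputs as a pointer to a contiguous array of tensor handles plus a c_int count (`args: *const *mut C_tensor, n: c_int`). A minimal sketch of marshalling borrowed wrappers into that shape, assuming a hypothetical `Tensor` wrapper and a Rust-defined `mock_forward` with the same argument types in place of the real libtorch symbol:

```rust
use std::os::raw::c_int;
use std::ptr::NonNull;

#[allow(non_camel_case_types)]
#[repr(C)]
pub struct C_tensor {
    _private: [u8; 0],
}

/// Hypothetical safe wrapper holding a raw handle (not tch's real Tensor).
pub struct Tensor {
    ptr: *mut C_tensor,
}

/// Hypothetical stand-in with atm_forward's argument shape
/// (`*const *mut C_tensor`, `c_int`); it only counts non-null inputs.
extern "C" fn mock_forward(args: *const *mut C_tensor, n: c_int) -> c_int {
    // SAFETY: the caller passes a pointer to `n` initialized elements.
    let args = unsafe { std::slice::from_raw_parts(args, n as usize) };
    args.iter().filter(|p| !p.is_null()).count() as c_int
}

/// Collects the raw handles into a contiguous array and a C-sized count.
/// The returned Vec must stay alive for the duration of the FFI call.
pub fn marshal(inputs: &[&Tensor]) -> (Vec<*mut C_tensor>, c_int) {
    let ptrs: Vec<*mut C_tensor> = inputs.iter().map(|t| t.ptr).collect();
    let n = c_int::try_from(ptrs.len()).expect("too many inputs for a c_int count");
    (ptrs, n)
}

fn main() {
    // Dangling but non-null handles suffice: the mock never dereferences them.
    let a = Tensor { ptr: NonNull::dangling().as_ptr() };
    let b = Tensor { ptr: NonNull::dangling().as_ptr() };
    let (ptrs, n) = marshal(&[&a, &b]);
    let seen = mock_forward(ptrs.as_ptr(), n);
    println!("mock_forward saw {seen} inputs");
}
```

The pointer array only borrows the handles, so ownership stays with the wrappers; the Vec itself has to outlive the call because the C side reads from it.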

Import

// In the tch crate or other consumers:
use torch_sys::{C_tensor, C_scalar, C_optimizer, CIValue, CModule_};
use torch_sys::c_generated;

I/O Contract

Category | Functions | Return type | Error handling
Tensor creation | at_new_tensor, at_tensor_of_data, at_tensor_of_blob | *mut C_tensor | Returns null on error; check get_and_reset_last_err()
Tensor queries | at_dim, at_scalar_type, at_device | size_t or c_int | Returns a sentinel value (such as -1) on error; check get_and_reset_last_err()
Tensor mutation and output buffers | at_backward, at_copy_, at_shape | void (at_shape writes into a caller-provided output buffer) | Check get_and_reset_last_err()
Serialization | at_save, at_load, at_save_multi | void or *mut C_tensor | Check get_and_reset_last_err()
Optimizers | ato_adam, ato_sgd, ato_step | *mut C_optimizer or void | Check get_and_reset_last_err()
JIT modules | atm_load, atm_forward, atm_method | *mut CModule_ or *mut C_tensor | Check get_and_reset_last_err()
IValues | ati_tensor, ati_int, ati_tag | *mut CIValue or c_int/i64/f64 | Check get_and_reset_last_err()
Image I/O | at_load_image, at_save_image, at_resize_image | *mut C_tensor or c_int | Check get_and_reset_last_err()
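Since every category ends in the same check, bindings typically funnel it through one helper. The sketch below mirrors that convention: read the pending message, copy it into an owned String, then release the C buffer. `mock_get_and_reset_last_err` and `mock_set_err` are hypothetical Rust stand-ins for the C side's last-error slot; because the mock allocates with CString, it is released with CString::from_raw, whereas the string returned by the real get_and_reset_last_err is released with libc::free.

```rust
use std::cell::RefCell;
use std::ffi::{CStr, CString};
use std::os::raw::c_char;

thread_local! {
    // Hypothetical per-thread slot standing in for the C side's last error.
    static LAST_ERR: RefCell<Option<CString>> = RefCell::new(None);
}

fn mock_set_err(msg: &str) {
    LAST_ERR.with(|e| *e.borrow_mut() = Some(CString::new(msg).unwrap()));
}

/// Hypothetical stand-in for get_and_reset_last_err: hands ownership of the
/// message to the caller (or returns null) and clears the slot.
fn mock_get_and_reset_last_err() -> *mut c_char {
    LAST_ERR.with(|e| match e.borrow_mut().take() {
        Some(s) => s.into_raw(),
        None => std::ptr::null_mut(),
    })
}

/// Converts the pending error, if any, into a Rust Result.
fn check_err() -> Result<(), String> {
    let err = mock_get_and_reset_last_err();
    if err.is_null() {
        return Ok(());
    }
    // Copy the message out *before* releasing the buffer; a borrowed
    // Cow from to_string_lossy would dangle after the release.
    let msg = unsafe { CStr::from_ptr(err) }.to_string_lossy().into_owned();
    // Mock-specific release; with the real binding this would be libc::free.
    drop(unsafe { CString::from_raw(err) });
    Err(msg)
}

fn main() {
    assert_eq!(check_err(), Ok(()));
    mock_set_err("example failure");
    assert_eq!(check_err(), Err("example failure".to_string()));
    // The slot was reset, so a second check succeeds.
    assert_eq!(check_err(), Ok(()));
}
```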

Usage Examples

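use std::ffi::{CStr, CString};
use torch_sys::*;

// Turns libtorch's pending error (if any) into a Rust Result.
fn check_err() -> Result<(), String> {
    unsafe {
        let err = get_and_reset_last_err();
        if err.is_null() {
            return Ok(());
        }
        // Copy the message out *before* freeing the C string;
        // a borrowed to_string_lossy() result would dangle after free.
        let msg = CStr::from_ptr(err).to_string_lossy().into_owned();
        libc::free(err as *mut libc::c_void);
        Err(msg)
    }
}

fn main() -> Result<(), String> {
    unsafe {
        // Create a new empty tensor, check for errors, then release it.
        let empty = at_new_tensor();
        check_err()?;
        at_free(empty);

        // Create a 2x2 tensor from data (the values are copied into the tensor).
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let dims: Vec<i64> = vec![2, 2];
        let t = at_tensor_of_data(
            data.as_ptr() as *const libc::c_void,
            dims.as_ptr(),
            dims.len(),                 // ndims
            std::mem::size_of::<f32>(), // element size in bytes
            6,                          // Float scalar type
        );
        check_err()?;

        // Save and load, checking for errors after each call
        let filename = CString::new("tensor.pt").unwrap();
        at_save(t, filename.as_ptr());
        check_err()?;
        let loaded = at_load(filename.as_ptr());
        check_err()?;

        // Clean up
        at_free(t);
        at_free(loaded);
    }
    Ok(())
}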
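The usage example passes the dtype to at_tensor_of_data as a bare integer (6 for Float) next to a separately computed element size. A minimal sketch of deriving both from the Rust element type, so they cannot drift apart, and of checking dims against the data length (at_tensor_of_data receives no separate buffer length, so the two must agree). The trait `Element`, the struct `TensorArgs`, and the helper `tensor_args` are hypothetical (tch has its own Kind type for this); the codes follow the c10::ScalarType ordering Byte = 0, Char = 1, Short = 2, Int = 3, Long = 4, Half = 5, Float = 6, Double = 7.

```rust
use std::os::raw::c_int;

/// Hypothetical trait tying a Rust element type to its c10::ScalarType code.
pub trait Element: Copy {
    const KIND: c_int;
}

impl Element for u8 { const KIND: c_int = 0; }  // Byte
impl Element for i8 { const KIND: c_int = 1; }  // Char
impl Element for i16 { const KIND: c_int = 2; } // Short
impl Element for i32 { const KIND: c_int = 3; } // Int
impl Element for i64 { const KIND: c_int = 4; } // Long
impl Element for f32 { const KIND: c_int = 6; } // Float
impl Element for f64 { const KIND: c_int = 7; } // Double

/// Arguments for at_tensor_of_data, apart from the data and dims pointers.
#[derive(Debug, PartialEq)]
pub struct TensorArgs {
    pub ndims: usize,
    pub elt_size_in_bytes: usize,
    pub kind: c_int,
}

/// Checks that `dims` describes exactly `data.len()` elements and derives the rest.
pub fn tensor_args<T: Element>(data: &[T], dims: &[i64]) -> Result<TensorArgs, String> {
    let mut numel: i64 = 1;
    for &d in dims {
        if d < 0 {
            return Err(format!("negative dimension {d}"));
        }
        numel = numel.checked_mul(d).ok_or("dimension product overflows i64")?;
    }
    if numel as usize != data.len() {
        return Err(format!("dims describe {numel} elements but data has {}", data.len()));
    }
    Ok(TensorArgs {
        ndims: dims.len(),
        elt_size_in_bytes: std::mem::size_of::<T>(),
        kind: T::KIND,
    })
}

fn main() {
    let data = vec![1.0f32, 2.0, 3.0, 4.0];
    let args = tensor_args(&data, &[2, 2]).unwrap();
    println!("{args:?}"); // TensorArgs { ndims: 2, elt_size_in_bytes: 4, kind: 6 }
}
```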
