Implementation:LaurentMazare Tch rs Torch Api Cpp

From Leeroopedia


Knowledge Sources
Domains: FFI, Deep Learning, Systems Programming
Last Updated: 2026-02-08 00:00 GMT

Overview

Core C++ FFI implementation file that bridges Rust to the PyTorch C++ API, providing the hand-written C-callable functions for tensor creation, manipulation, serialization, device management, optimizer construction, JIT module interaction, IValue handling, and image I/O.

Description

torch_api.cpp is the central hand-written C++ implementation file in the torch-sys crate. It contains roughly 1,670 lines of C++ code that implement the functions declared in torch_api.h. Exported functions that call into PyTorch follow a consistent pattern: they use the PROTECT macro (defined in torch_api.h) to wrap the PyTorch C++ API calls in a try/catch block, storing any exception message in the thread-local torch_last_err variable. Rust retrieves and clears that message by calling get_and_reset_last_err().
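The error-capture pattern described above can be sketched without libtorch. This is a minimal illustration, not the file's actual code: the macro body is an assumed shape based on the description, and dup_message and example_checked_div are hypothetical names introduced only for this example.

```cpp
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <stdexcept>

// Thread-local slot holding the most recent error message, or nullptr.
thread_local char *torch_last_err = nullptr;

// Hypothetical helper: heap-allocated copy of a C string (caller frees it).
static char *dup_message(const char *msg) {
  assert(msg != nullptr);
  std::size_t len = std::strlen(msg) + 1;
  char *copy = static_cast<char *>(std::malloc(len));
  if (copy != nullptr) std::memcpy(copy, msg, len);
  return copy;
}

// Assumed shape of the macro: run the body and, if it throws a
// std::exception, keep a copy of the message instead of propagating it.
#define PROTECT(x)                                \
  try {                                           \
    x                                             \
  } catch (const std::exception &e) {             \
    torch_last_err = dup_message(e.what());       \
  }

// Hands the pending message to the caller and clears the slot.
char *get_and_reset_last_err() {
  char *tmp = torch_last_err;
  torch_last_err = nullptr;
  return tmp;
}

// Hypothetical exported function standing in for a call into libtorch.
int example_checked_div(int a, int b) {
  PROTECT(
    if (b == 0) throw std::invalid_argument("division by zero");
    return a / b;
  )
  return -1;  // reached only when the body threw
}
```

A caller invokes example_checked_div, then calls get_and_reset_last_err(); a non-null result means the call failed, and the caller owns the returned string and must free it. Because the slot is thread-local, an error raised on one thread is not visible from another.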

The file is organized into several logical sections:

  • Tensor lifecycle: at_new_tensor, at_tensor_of_data, at_tensor_of_blob, at_shallow_clone, at_free
  • Tensor properties: at_defined, at_is_sparse, at_is_mkldnn, at_is_contiguous, at_dim, at_shape, at_stride, at_scalar_type, at_device
  • Data access: at_copy_data, at_data_ptr, at_double_value_at_indexes, at_int64_value_at_indexes, at_fill_double, at_fill_int64
  • Autograd: at_backward, at_run_backward, at_requires_grad, at_grad_set_enabled
  • Serialization: at_save, at_load, at_save_multi, at_load_multi, at_save_to_stream, at_load_from_stream, at_load_callback, at_loadz_callback with stream adapter classes (WriteStreamAdapter, ReadStreamAdapter)
  • Optimizers: ato_adam, ato_adamw, ato_sgd, ato_rms_prop, ato_add_parameters, ato_step, ato_zero_grad, ato_set_learning_rate
  • JIT Modules: atm_load, atm_forward, atm_method, atm_eval, atm_train, atm_save, atm_named_parameters, atm_create_for_tracing, atm_end_tracing
  • IValues: ati_none, ati_tensor, ati_int, ati_double, ati_tuple, ati_generic_list, ati_generic_dict, ati_tag, ati_to_tensor, ati_to_int, ati_length, ati_object_method_
  • Scalars: ats_int, ats_float, ats_to_int, ats_to_float, ats_free
  • Image I/O: at_load_image, at_load_image_from_memory, at_save_image, at_resize_image (using vendored stb_image libraries)
  • CUDA: atc_cuda_device_count, atc_cuda_is_available, atc_cudnn_is_available, atc_manual_seed, atc_synchronize
  • Context queries: at_context_has_openmp, at_context_has_cuda, at_context_has_mkl, at_context_has_mps, etc.

The file also includes the stb_image implementation defines (#define STB_IMAGE_IMPLEMENTATION, #define STB_IMAGE_WRITE_IMPLEMENTATION, #define STB_IMAGE_RESIZE_IMPLEMENTATION) which compile the vendored single-header image libraries into this translation unit.

Usage

This file is compiled as part of the torch-sys crate build process and linked against libtorch. Application code does not normally call these functions directly: the Rust FFI declarations in lib.rs bind to these C-linkage symbols, and the higher-level tch crate wraps them in safe Rust APIs. Developers extending tch-rs with a new hand-written binding add the C declaration to torch_api.h and the C++ implementation to this file.
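As a rough sketch of that two-step workflow, the example below pairs C-linkage declarations with their implementations. It is illustrative only: FakeTensor stands in for torch::Tensor so the code builds without libtorch, and the at_example_* functions are hypothetical names that do not exist in torch_api.h.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Stand-in for torch::Tensor (assumption made for this sketch only).
struct FakeTensor {
  std::vector<int64_t> shape;
};
// Opaque handle type passed across the FFI boundary.
typedef FakeTensor *tensor;

// Step 1: declarations of the kind that would go in the header.
// C linkage keeps the symbol names unmangled so Rust can bind to them.
extern "C" {
tensor at_example_new(int64_t *dims, size_t ndims);  // hypothetical
int64_t at_example_numel(tensor t);                  // hypothetical
void at_example_free(tensor t);                      // hypothetical
}

// Step 2: implementations of the kind that would go in the .cpp file.
tensor at_example_new(int64_t *dims, size_t ndims) {
  // The handle is heap-allocated; the caller releases it with the free function.
  return new FakeTensor{std::vector<int64_t>(dims, dims + ndims)};
}

int64_t at_example_numel(tensor t) {
  assert(t != nullptr);
  int64_t n = 1;
  for (int64_t d : t->shape) n *= d;
  return n;
}

void at_example_free(tensor t) { delete t; }
```

On the Rust side, a matching extern "C" declaration would then be needed for each new symbol. In the real file, the body of each implementation would additionally be wrapped in PROTECT so that exceptions do not cross the FFI boundary.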

Code Reference

Source Location

Signature

// Error handling pattern used throughout
thread_local char *torch_last_err = nullptr;

char *get_and_reset_last_err() {
    char *tmp = torch_last_err;
    torch_last_err = nullptr;
    return tmp;
}

// Device mapping helper
at::Device device_of_int(int d) {
    if (d == -3) return at::Device(at::kVulkan);
    if (d == -2) return at::Device(at::kMPS);
    if (d < 0) return at::Device(at::kCPU);
    return at::Device(at::kCUDA, /*index=*/d);
}

// Representative tensor functions
tensor at_new_tensor();
tensor at_tensor_of_data(void *vs, int64_t *dims, size_t ndims,
                         size_t element_size_in_bytes, int type);
void at_copy_data(tensor tensor, void *vs, size_t numel,
                  size_t elt_size_in_bytes);
tensor at_shallow_clone(tensor t);
void at_backward(tensor t, int keep_graph, int create_graph);
void at_save(tensor t, char *filename);
tensor at_load(char *filename);
void at_run_backward(tensor *tensors, int ntensors, tensor *inputs,
                     int ninputs, tensor *outputs,
                     int keep_graph, int create_graph);

// Optimizer constructors
optimizer ato_adam(double learning_rate, double beta1, double beta2,
                  double weight_decay, double eps, bool amsgrad);

// JIT Module functions
module atm_load(char *filename);
tensor atm_forward(module m, tensor *tensors, int ntensors);

// IValue functions
ivalue ati_none();
ivalue ati_tensor(tensor t);
int ati_tag(ivalue i);
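The integer device encoding used by device_of_int can be checked in isolation. The sketch below mirrors the branch order of the helper shown above but replaces at::Device with a hypothetical DeviceSketch struct so it builds without libtorch; the -1 index for non-CUDA devices reflects at::Device's default index.

```cpp
#include <cassert>
#include <string>

// Stand-in for at::Device (assumption made for this sketch only).
struct DeviceSketch {
  std::string type;
  int index;  // -1 means "no explicit index"
};

// Same branch order as device_of_int: the special negative codes are
// tested first, every other negative value falls back to CPU, and
// non-negative values are CUDA device indices.
DeviceSketch device_of_int_sketch(int d) {
  if (d == -3) return {"vulkan", -1};
  if (d == -2) return {"mps", -1};
  if (d < 0) return {"cpu", -1};
  return {"cuda", d};
}

// Small self-check covering each branch.
void check_device_mapping() {
  assert(device_of_int_sketch(-3).type == "vulkan");
  assert(device_of_int_sketch(-2).type == "mps");
  assert(device_of_int_sketch(-1).type == "cpu");
  assert(device_of_int_sketch(-7).type == "cpu");
  assert(device_of_int_sketch(0).type == "cuda");
  assert(device_of_int_sketch(1).index == 1);
}
```

One consequence of the ordering is that any negative value other than -2 and -3 maps to the CPU, so an unrecognized negative code is not reported as an error.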

Import

#include<torch/torch.h>
#include<torch/script.h>
#include<torch/csrc/autograd/engine.h>
#include<torch/csrc/jit/frontend/tracer.h>
#include "torch_api.h"
#include "stb_image.h"
#include "stb_image_write.h"
#include "stb_image_resize.h"

I/O Contract

Function Category | Input | Output | Error Handling
Tensor Creation (at_new_tensor, at_tensor_of_data) | Raw data pointer, dimensions array, scalar type | Heap-allocated torch::Tensor* (returned as tensor) | Returns nullptr on error; sets torch_last_err
Tensor Properties (at_dim, at_shape, at_scalar_type) | tensor pointer | Integer value or writes to output buffer | Returns -1 on error; sets torch_last_err
Data Copy (at_copy_data) | Tensor pointer, destination buffer, numel, element size | Copies tensor data to caller-provided buffer | Sets torch_last_err if sizes mismatch or copy fails
Serialization (at_save, at_load) | Filename string or stream pointer | File written / tensor returned | Sets torch_last_err on I/O or format errors
Backward (at_run_backward) | Arrays of tensors and inputs, keep_graph/create_graph flags | Gradient tensors written to outputs array | Sets torch_last_err if inputs lack requires_grad
Image I/O (at_load_image, at_save_image) | Filename string or image data buffer | Tensor (HWC, uint8, 3 channels) or status int | Sets torch_last_err; uses stbi_failure_reason() for detail
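The Data Copy row implies that the caller's numel and element size are validated before any bytes are copied. The following is a hedged sketch of that kind of check, not the library's code: FloatBuffer is a hypothetical stand-in for a contiguous float tensor, copy_data_sketch is an invented name, and the error messages are illustrative.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <stdexcept>
#include <vector>

// Stand-in for a contiguous float tensor (assumption for this sketch).
struct FloatBuffer {
  std::vector<float> values;
};

// Validates the caller-supplied sizes, then copies numel elements into dst.
// Throws std::invalid_argument on a mismatch; in an FFI layer such an
// exception would be caught and stored rather than propagated.
void copy_data_sketch(const FloatBuffer &src, void *dst, std::size_t numel,
                      std::size_t elt_size_in_bytes) {
  assert(dst != nullptr);
  if (elt_size_in_bytes != sizeof(float))
    throw std::invalid_argument("element size mismatch");
  if (numel > src.values.size())
    throw std::invalid_argument("requested more elements than available");
  std::memcpy(dst, src.values.data(), numel * elt_size_in_bytes);
}
```

Checking sizes up front matters because the destination is a raw pointer: the callee cannot tell how large the caller's buffer is, so the two size arguments are the only guard against reading or writing out of bounds.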

Usage Examples

// Creating a tensor from data (called from Rust via FFI)
float data[] = {1.0, 2.0, 3.0, 4.0};
int64_t dims[] = {2, 2};
tensor t = at_tensor_of_data(data, dims, 2, sizeof(float), /* Float */ 6);

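// Performing a backward pass. In practice t must be a scalar tensor that
// requires grad; a tensor freshly built from raw data, as above, would fail.
at_backward(t, /*keep_graph=*/0, /*create_graph=*/0);

// Saving and loading (the API takes char *, so pass a mutable buffer)
char path[] = "model.pt";
at_save(t, path);
tensor loaded = at_load(path);

// Loading an image as a tensor
char image_path[] = "photo.png";
tensor img = at_load_image(image_path);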
// img is shape [H, W, 3] with dtype Byte

// Error checking pattern (mirrors what the Rust side does after each call)
char *err = get_and_reset_last_err();
if (err != nullptr) {
    // handle the error message, then release it
    free(err);
}
