Implementation:LaurentMazare Tch rs Torch Api Cpp
| Knowledge Sources | |
|---|---|
| Domains | FFI, Deep Learning, Systems Programming |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Core C++ FFI implementation file that bridges Rust to the PyTorch C++ API, providing the hand-written C-callable functions for tensor creation, manipulation, serialization, device management, optimizer construction, JIT module interaction, IValue handling, and image I/O.
Description
torch_api.cpp is the central hand-written C++ implementation file in the torch-sys crate. Its roughly 1,670 lines of C++ implement the hand-written functions declared in torch_api.h; the generated tensor operations live in a separate companion file. Exported functions follow a consistent pattern. Each wraps its PyTorch C++ API calls in the PROTECT macro (defined in torch_api.h), a try/catch block that stores a heap-allocated copy of any exception message in the thread-local torch_last_err variable. Rust then retrieves that message via get_and_reset_last_err().
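The error-propagation pattern can be illustrated without libtorch. The sketch below is a self-contained stand-in, not the library's code. Its PROTECT macro mirrors the try/catch idea described above, and `checked_div` is a hypothetical exported function invented only for illustration. A failing call returns a sentinel value (-1 here) and leaves the message for the caller to drain and free.

```cpp
#include <cstdlib>
#include <cstring>
#include <stdexcept>

// Thread-local slot holding the last error message (nullptr when no error).
thread_local char *torch_last_err = nullptr;

// Stand-in for the PROTECT idea: run the body, and on a std::exception store
// a heap-allocated copy of the message instead of letting it cross the C ABI.
#define PROTECT(x)                     \
  try {                                \
    x                                  \
  } catch (const std::exception &e) {  \
    torch_last_err = strdup(e.what()); \
  }

// Hand the pending message to the caller and clear the slot.
extern "C" char *get_and_reset_last_err() {
  char *tmp = torch_last_err;
  torch_last_err = nullptr;
  return tmp;  // caller owns the string and must free() it
}

// Hypothetical exported function: returns -1 and records an error on failure.
extern "C" long long checked_div(long long a, long long b) {
  PROTECT(
    if (b == 0) throw std::invalid_argument("division by zero");
    return a / b;
  )
  return -1;  // reached only when the body threw
}
```

A caller checks the sentinel, calls `get_and_reset_last_err()`, and frees the returned string. This keeps C++ exceptions from unwinding across the FFI boundary.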
The file is organized into several logical sections:
- Tensor lifecycle: at_new_tensor, at_tensor_of_data, at_tensor_of_blob, at_shallow_clone, at_free
- Tensor properties: at_defined, at_is_sparse, at_is_mkldnn, at_is_contiguous, at_dim, at_shape, at_stride, at_scalar_type, at_device
- Data access: at_copy_data, at_data_ptr, at_double_value_at_indexes, at_int64_value_at_indexes, at_fill_double, at_fill_int64
- Autograd: at_backward, at_run_backward, at_requires_grad, at_grad_set_enabled
- Serialization: at_save, at_load, at_save_multi, at_load_multi, at_save_to_stream, at_load_from_stream, at_load_callback, at_loadz_callback, with stream adapter classes (WriteStreamAdapter, ReadStreamAdapter)
- Optimizers: ato_adam, ato_adamw, ato_sgd, ato_rms_prop, ato_add_parameters, ato_step, ato_zero_grad, ato_set_learning_rate
- JIT modules: atm_load, atm_forward, atm_method, atm_eval, atm_train, atm_save, atm_named_parameters, atm_create_for_tracing, atm_end_tracing
- IValues: ati_none, ati_tensor, ati_int, ati_double, ati_tuple, ati_generic_list, ati_generic_dict, ati_tag, ati_to_tensor, ati_to_int, ati_length, ati_object_method_
- Scalars: ats_int, ats_float, ats_to_int, ats_to_float, ats_free
- Image I/O: at_load_image, at_load_image_from_memory, at_save_image, at_resize_image (using the vendored stb_image libraries)
- CUDA: atc_cuda_device_count, atc_cuda_is_available, atc_cudnn_is_available, atc_manual_seed, atc_synchronize
- Context queries: at_context_has_openmp, at_context_has_cuda, at_context_has_mkl, at_context_has_mps, etc.
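Most of these functions exchange opaque handles. A `tensor` is a pointer to a heap-allocated `torch::Tensor`. A shallow clone creates a second handle to the same underlying data, and at_free releases one handle. The sketch below mimics that lifecycle without libtorch. It assumes a hypothetical stand-in type `FakeTensor` whose copies share a reference-counted buffer, similar to how `torch::Tensor` copies share storage. The `fake_*` functions are illustrative names and are not part of torch_api.cpp.

```cpp
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical stand-in for torch::Tensor: copies share one refcounted buffer.
struct FakeTensor {
  std::shared_ptr<std::vector<float>> storage;
};

// Opaque handle type handed across the C boundary (cf. `tensor`).
typedef FakeTensor *fake_tensor;

// Create a handle that owns a copy of the caller's data.
extern "C" fake_tensor fake_tensor_of_data(const float *vs, std::size_t numel) {
  auto buf = std::make_shared<std::vector<float>>(vs, vs + numel);
  return new FakeTensor{buf};  // caller owns the handle
}

// Shallow clone: a new handle aliasing the same storage (no data copy).
extern "C" fake_tensor fake_shallow_clone(fake_tensor t) {
  return new FakeTensor(*t);
}

// Release one handle; the storage survives while other handles reference it.
extern "C" void fake_free(fake_tensor t) { delete t; }
```

Because every handle is an independent heap object, the Rust side can free each one exactly once (for example in a `Drop` implementation) without tracking which handles alias each other.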
The file also includes the stb_image implementation defines (#define STB_IMAGE_IMPLEMENTATION, #define STB_IMAGE_WRITE_IMPLEMENTATION, #define STB_IMAGE_RESIZE_IMPLEMENTATION) which compile the vendored single-header image libraries into this translation unit.
Usage
This file is compiled as part of the torch-sys crate build and linked against libtorch. The torch-sys crate declares these C-linkage symbols as extern "C" functions in lib.rs, so the only Rust code that calls them directly is that layer of raw, unsafe bindings. Application code uses the safe wrappers in the higher-level tch crate instead. Developers extending tch-rs with a new hand-written binding add the C declaration to torch_api.h, the C++ implementation here, and the matching extern declaration in lib.rs.
Code Reference
Source Location
- Repository: LaurentMazare_Tch_rs
- File: torch-sys/libtch/torch_api.cpp
- Lines: 1-1671
Signature
```cpp
// Error handling pattern used throughout
thread_local char *torch_last_err = nullptr;

char *get_and_reset_last_err() {
  char *tmp = torch_last_err;
  torch_last_err = nullptr;
  return tmp;
}

// Device mapping helper
at::Device device_of_int(int d) {
  if (d == -3) return at::Device(at::kVulkan);
  if (d == -2) return at::Device(at::kMPS);
  if (d < 0) return at::Device(at::kCPU);
  return at::Device(at::kCUDA, /*index=*/d);
}

// Representative tensor functions
tensor at_new_tensor();
tensor at_tensor_of_data(void *vs, int64_t *dims, size_t ndims,
                         size_t element_size_in_bytes, int type);
void at_copy_data(tensor tensor, void *vs, size_t numel,
                  size_t elt_size_in_bytes);
tensor at_shallow_clone(tensor t);
void at_backward(tensor t, int keep_graph, int create_graph);
void at_save(tensor t, char *filename);
tensor at_load(char *filename);
void at_run_backward(tensor *tensors, int ntensors, tensor *inputs,
                     int ninputs, tensor *outputs,
                     int keep_graph, int create_graph);

// Optimizer constructors
optimizer ato_adam(double learning_rate, double beta1, double beta2,
                   double weight_decay, double eps, bool amsgrad);

// JIT Module functions
module atm_load(char *filename);
tensor atm_forward(module m, tensor *tensors, int ntensors);

// IValue functions
ivalue ati_none();
ivalue ati_tensor(tensor t);
int ati_tag(ivalue i);
```
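device_of_int fixes an integer encoding for devices: -3 selects Vulkan, -2 selects MPS, any other negative value selects the CPU, and a non-negative value is a CUDA device index. The sketch below restates that contract without libtorch so that it can be checked in isolation. It uses a hypothetical `DeviceKind` enum in place of at::DeviceType. The inverse `encode_device` assumes callers encode the CPU as -1, which is an assumption about the Rust side rather than something this file defines.

```cpp
// Hypothetical stand-in for at::DeviceType.
enum class DeviceKind { CPU, CUDA, MPS, Vulkan };

struct DeviceSpec {
  DeviceKind kind;
  int index;  // CUDA ordinal; -1 when not applicable
};

// Mirrors the branch order of device_of_int.
DeviceSpec decode_device(int d) {
  if (d == -3) return {DeviceKind::Vulkan, -1};
  if (d == -2) return {DeviceKind::MPS, -1};
  if (d < 0) return {DeviceKind::CPU, -1};
  return {DeviceKind::CUDA, d};
}

// Inverse mapping; CPU is assumed to be encoded as -1.
int encode_device(DeviceSpec s) {
  switch (s.kind) {
    case DeviceKind::Vulkan: return -3;
    case DeviceKind::MPS: return -2;
    case DeviceKind::CPU: return -1;
    case DeviceKind::CUDA: return s.index;
  }
  return -1;  // unreachable for valid enum values
}
```

Checking the two explicit sentinels before the generic `d < 0` branch matters. Otherwise -2 and -3 would be routed to the CPU.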
Import
```cpp
#include <torch/torch.h>
#include <torch/script.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include "torch_api.h"
#include "stb_image.h"
#include "stb_image_write.h"
#include "stb_image_resize.h"
```
I/O Contract
| Function Category | Input | Output | Error Handling |
|---|---|---|---|
| Tensor Creation (at_new_tensor, at_tensor_of_data) | Raw data pointer, dimensions array, scalar type | Heap-allocated torch::Tensor* (returned as tensor) | Returns nullptr on error; sets torch_last_err |
| Tensor Properties (at_dim, at_shape, at_scalar_type) | Tensor pointer | Integer value or writes to output buffer | Returns -1 on error; sets torch_last_err |
| Data Copy (at_copy_data) | Tensor pointer, destination buffer, numel, element size | Copies tensor data to caller-provided buffer | Sets torch_last_err if sizes mismatch or copy fails |
| Serialization (at_save, at_load) | Filename string or stream pointer | File written / tensor returned | Sets torch_last_err on I/O or format errors |
| Backward (at_run_backward) | Arrays of tensors and inputs, keep_graph/create_graph flags | Gradient tensors written to outputs array | Sets torch_last_err if inputs lack requires_grad |
| Image I/O (at_load_image, at_save_image) | Filename string or image data buffer | Tensor (HWC, uint8, 3 channels) or status int | Sets torch_last_err; uses stbi_failure_reason() for detail |
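at_load_image returns pixels in HWC order: height, then width, then 3 channels, stored contiguously as bytes. Channel c of pixel (y, x) therefore sits at byte offset (y * W + x) * 3 + c. Many models expect CHW instead. The sketch below works through that index arithmetic on a plain byte buffer. The helper names are hypothetical and no libtorch call is involved; with an actual tensor one would typically permute the dimensions rather than copy by hand.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Offset of channel c of pixel (y, x) in an H x W x C (HWC) byte buffer.
inline std::size_t hwc_offset(std::size_t y, std::size_t x, std::size_t c,
                              std::size_t w, std::size_t channels) {
  return (y * w + x) * channels + c;
}

// Offset of the same element in a C x H x W (CHW) buffer.
inline std::size_t chw_offset(std::size_t y, std::size_t x, std::size_t c,
                              std::size_t h, std::size_t w) {
  return (c * h + y) * w + x;
}

// Reorder an HWC buffer (the layout at_load_image produces) into CHW.
std::vector<std::uint8_t> hwc_to_chw(const std::vector<std::uint8_t> &src,
                                     std::size_t h, std::size_t w,
                                     std::size_t channels) {
  std::vector<std::uint8_t> dst(src.size());
  for (std::size_t y = 0; y < h; ++y)
    for (std::size_t x = 0; x < w; ++x)
      for (std::size_t c = 0; c < channels; ++c)
        dst[chw_offset(y, x, c, h, w)] = src[hwc_offset(y, x, c, w, channels)];
  return dst;
}
```

For a 2x2 image whose HWC bytes are 0 through 11, channel 0 of the CHW result holds 0, 3, 6, 9, the first byte of each pixel.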
Usage Examples
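```cpp
// Illustrative calls through the C API; tch-rs makes the equivalent calls from Rust via FFI.
#include <cstdio>
#include <cstdlib>
#include "torch_api.h"

int main() {
  // Creating a 2x2 Float tensor from data (ScalarType 6 = Float)
  float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t dims[] = {2, 2};
  tensor t = at_tensor_of_data(data, dims, 2, sizeof(float), /* Float */ 6);
  // Performing a backward pass: t must be a scalar that requires grad
  // (e.g. a loss value); otherwise the call sets torch_last_err
  at_backward(t, /*keep_graph=*/0, /*create_graph=*/0);
  // Saving and loading (the API takes char *, so pass writable buffers)
  char model_path[] = "model.pt";
  at_save(t, model_path);
  tensor loaded = at_load(model_path);
  // Loading an image as a tensor: shape [H, W, 3], dtype Byte
  char image_path[] = "photo.png";
  tensor img = at_load_image(image_path);
  // Error checking (shown once here; the Rust side checks after every call)
  char *err = get_and_reset_last_err();
  if (err != nullptr) {
    fprintf(stderr, "torch error: %s\n", err);
    free(err);  // the message is heap-allocated; the caller frees it
  }
  // Release each handle once (deleting a nullptr handle is harmless)
  at_free(img);
  at_free(loaded);
  at_free(t);
  return 0;
}
```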
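The `/* Float */ 6` argument above is the integer value of PyTorch's c10::ScalarType enum, and at_tensor_of_data checks element_size_in_bytes against that type before copying numel times element size bytes. The sketch below tabulates the first eight codes with their element sizes and computes the buffer size a caller must supply. The enum values follow c10::ScalarType for these entries. `ScalarCode`, `element_size_of`, and `buffer_bytes` are hypothetical helpers, not part of the API.

```cpp
#include <cstddef>
#include <cstdint>

// First entries of c10::ScalarType, by integer value.
enum ScalarCode : int {
  kByte = 0, kChar = 1, kShort = 2, kInt = 3,
  kLong = 4, kHalf = 5, kFloat = 6, kDouble = 7
};

// Element size in bytes for the codes above; 0 for codes not covered here.
std::size_t element_size_of(int code) {
  switch (code) {
    case kByte: case kChar: return 1;
    case kShort: case kHalf: return 2;
    case kInt: case kFloat: return 4;
    case kLong: case kDouble: return 8;
    default: return 0;
  }
}

// Bytes spanned by a contiguous tensor of the given shape and type.
std::size_t buffer_bytes(const std::int64_t *dims, std::size_t ndims, int code) {
  std::size_t numel = 1;  // a 0-d tensor (ndims == 0) holds one element
  for (std::size_t i = 0; i < ndims; ++i)
    numel *= static_cast<std::size_t>(dims[i]);
  return numel * element_size_of(code);
}
```

For the 2x2 Float example, numel is 4 and the buffer spans 4 * 4 = 16 bytes. Passing `sizeof(float)` as the element size keeps the caller's view consistent with code 6.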
Related Pages
- Principle:LaurentMazare_Tch_rs_C_FFI_Bridge
- Implementation:LaurentMazare_Tch_rs_Torch_Api_H - C header declaring these functions
- Implementation:LaurentMazare_Tch_rs_Torch_Sys_Lib - Rust FFI bindings that call these functions
- Implementation:LaurentMazare_Tch_rs_Torch_Api_Generated_Cpp - Auto-generated companion for tensor operations
- Implementation:LaurentMazare_Tch_rs_Stb_Image - Vendored image libraries used for image I/O