Implementation:LaurentMazare Tch rs Torch Api H
| Knowledge Sources | |
|---|---|
| Domains | FFI, C API, Deep Learning |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Master C FFI header defining all hand-written bindings between Rust and PyTorch C++, including opaque type definitions, the PROTECT error-handling macro, and declarations for tensor, optimizer, module, IValue, scalar, CUDA, and image operations.
Description
torch_api.h is a 295-line header file that serves as the central contract between the Rust FFI layer and the C++ implementation. It is the most architecturally significant header in the torch-sys crate because it:
- Defines opaque pointer types using a conditional compilation pattern:
  - When compiled as C++ (__cplusplus defined): the types are real C++ pointer types such as torch::Tensor*, torch::Scalar*, torch::optim::Optimizer*, etc.
  - When compiled as C: all types are void*.
- Defines the PROTECT macro, the universal error-handling mechanism for all FFI functions:
#define PROTECT(x) \
try { \
x \
} catch (const exception& e) { \
torch_last_err = strdup(e.what()); \
}
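This macro wraps each FFI function body, catching any std::exception (the unqualified exception resolves to std::exception through the header's using namespace std;) and storing a strdup'd copy of its message in the thread-local char* torch_last_err. The Rust side retrieves the message via get_and_reset_last_err(). Exceptions that do not derive from std::exception are not caught by this clause.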
- Declares all hand-written C functions organized by category:
Tensor operations:
at_new_tensor, at_tensor_of_data, at_tensor_of_blob, at_shallow_clone, at_copy_data, at_data_ptr, at_defined, at_is_sparse, at_is_mkldnn, at_is_contiguous, at_device, at_dim, at_shape, at_stride, at_scalar_type, at_backward, at_requires_grad, at_grad_set_enabled, at_get, at_fill_double, at_fill_int64, at_double_value_at_indexes, at_int64_value_at_indexes, at_copy_, at_print, at_to_string, at_free, at_run_backward
Autocast:
at__amp_non_finite_check_and_unscale, at_autocast_clear_cache, at_autocast_decrement_nesting, at_autocast_increment_nesting, at_autocast_is_enabled, at_autocast_set_enabled
Serialization:
at_save, at_save_to_stream, at_load, at_load_from_stream, at_save_multi, at_save_multi_to_stream, at_load_multi, at_load_multi_, at_loadz_callback, at_loadz_callback_with_device, at_load_callback, at_load_callback_with_device, at_load_from_stream_callback
Image I/O:
at_load_image, at_load_image_from_memory, at_save_image, at_resize_image
Optimizers:
ato_adam, ato_adamw, ato_rms_prop, ato_sgd, ato_add_parameters, ato_set_learning_rate, ato_set_momentum, ato_set_weight_decay, ato_zero_grad, ato_step, ato_free (with per-group variants)
Scalars:
ats_int, ats_float, ats_to_int, ats_to_float, ats_to_string, ats_free
Context queries:
at_context_has_openmp, at_context_has_mkl, at_context_has_lapack, at_context_has_mkldnn, at_context_has_magma, at_context_has_cuda, at_context_has_cudart, at_context_has_cudnn, at_context_has_cusolver, at_context_has_hip, at_context_has_ipu, at_context_has_xla, at_context_has_lazy, at_context_has_mps, at_context_version_cudnn, at_context_version_cudart
CUDA:
atc_cuda_device_count, atc_cuda_is_available, atc_cudnn_is_available, atc_manual_seed, atc_manual_seed_all, atc_synchronize, atc_user_enabled_cudnn, atc_set_user_enabled_cudnn, atc_set_benchmark_cudnn
JIT Modules:
atm_load, atm_load_on_device, atm_load_str, atm_load_str_on_device, atm_forward, atm_forward_, atm_method, atm_method_, atm_create_class_, atm_eval, atm_train, atm_free, atm_to, atm_save, atm_named_parameters, atm_create_for_tracing, atm_end_tracing
IValues:
ati_none, ati_tensor, ati_int, ati_double, ati_bool, ati_string, ati_tuple, ati_generic_list, ati_generic_dict, ati_int_list, ati_double_list, ati_bool_list, ati_string_list, ati_tensor_list, ati_device, ati_to_tensor, ati_to_int, ati_to_double, ati_to_string, ati_to_bool, ati_length, ati_tuple_length, ati_to_tuple, ati_to_generic_list, ati_to_generic_dict, ati_tag, ati_object_method_, ati_object_getattr_, ati_clone, ati_free
Stream I/O (internal):
tch_write_stream_destructor, tch_write_stream_write, tch_read_stream_destructor, tch_read_stream_stream_position, tch_read_stream_seek_start, tch_read_stream_seek_end, tch_read_stream_read
C++-only helpers (outside extern "C"):
of_carray_tensor, device_of_int, of_carray_tensor_opt
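The error channel described above can be exercised without libtorch. The following is a minimal, self-contained sketch under stated assumptions: torch_last_err, PROTECT and get_and_reset_last_err mirror the names in this header (get_and_reset_last_err is assumed to return the stored pointer and reset the slot to null), while example_checked_div is a hypothetical FFI function whose throw stands in for a failing libtorch call.

```cpp
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <stdexcept>

using namespace std;

// Thread-local slot for the last error message (mirrors torch_last_err).
thread_local char *torch_last_err = nullptr;

// Same shape as PROTECT in torch_api.h: run the body and, on any
// std::exception, store a heap copy of its message instead of letting the
// exception cross the extern "C" boundary.
#define PROTECT(x) \
  try { \
    x \
  } catch (const exception& e) { \
    torch_last_err = strdup(e.what()); \
  }

extern "C" {

// Assumed behaviour: hand the message to the caller and clear the slot.
char *get_and_reset_last_err() {
  char *tmp = torch_last_err;
  torch_last_err = nullptr;
  return tmp;
}

// Hypothetical FFI function; the throw stands in for a libtorch failure.
int64_t example_checked_div(int64_t a, int64_t b) {
  PROTECT(
    if (b == 0) throw invalid_argument("division by zero");
    return a / b;
  )
  return -1;  // sentinel, reached only when an exception was caught
}

}  // extern "C"
```

The sentinel only tells the caller to look; the message itself is fetched with get_and_reset_last_err() and released with free(), since strdup allocated it. Because PROTECT takes a single macro parameter, a body containing a top-level comma (for example a declaration such as std::map<int, int> m;) would be split into two macro arguments and fail to compile; commas inside parentheses are safe.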
Usage
This header is included by torch_api.cpp, torch_api_generated.cpp (via torch_api_generated.h), and any other C/C++ files that need the FFI type definitions. It is the single source of truth for the C ABI contract. Adding a new hand-written FFI function requires declaring it here first.
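Adding a function follows the split the header already uses: the prototype goes inside the extern "C" block using only the opaque typedefs, and the body in the implementation file wraps its work in PROTECT. The sketch below shows that split for a hypothetical function at_example_numel. mock::Tensor is a hypothetical stand-in for torch::Tensor so the example compiles without libtorch, and the "header" and "source" parts are collapsed into one file.

```cpp
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <stdexcept>
#include <vector>

using namespace std;

// Hypothetical stand-in for torch::Tensor (only what the example needs).
namespace mock {
struct Tensor {
  vector<int64_t> sizes;
};
}  // namespace mock

thread_local char *torch_last_err = nullptr;

#define PROTECT(x) \
  try { \
    x \
  } catch (const exception& e) { \
    torch_last_err = strdup(e.what()); \
  }

// "Header" part: declaration inside extern "C", opaque typedef only.
extern "C" {
typedef mock::Tensor *tensor;
int64_t at_example_numel(tensor t);  // hypothetical new FFI function
}

// "Source" part: the definition inherits C linkage from the declaration
// above and wraps its body in PROTECT.
int64_t at_example_numel(tensor t) {
  PROTECT(
    if (t == nullptr) throw invalid_argument("null tensor handle");
    int64_t n = 1;
    for (int64_t s : t->sizes) n *= s;
    return n;
  )
  return -1;  // error sentinel
}
```

Declaring the prototype inside extern "C" keeps the symbol name unmangled, and a later definition needs no linkage specification of its own. A matching declaration must then be added to the Rust bindings in torch-sys.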
Code Reference
Source Location
- Repository: LaurentMazare_Tch_rs
- File: torch-sys/libtch/torch_api.h
- Lines: 1-295
Signature
#ifndef __TORCH_API_H__
#define __TORCH_API_H__
#include<stdint.h>
#ifdef __cplusplus
#include<torch/torch.h>
#include<stdexcept>
using namespace std;
extern thread_local char *torch_last_err;
extern "C" {
typedef torch::Tensor *tensor;
typedef torch::Scalar *scalar;
typedef torch::optim::Optimizer *optimizer;
typedef torch::jit::script::Module *module;
typedef torch::jit::IValue *ivalue;
#define PROTECT(x) \
try { \
x \
} catch (const exception& e) { \
torch_last_err = strdup(e.what()); \
}
#else
typedef void *tensor;
typedef void *optimizer;
typedef void *scalar;
typedef void *module;
typedef void *ivalue;
#endif
char *get_and_reset_last_err(); // thread-local
void at_manual_seed(int64_t);
tensor at_new_tensor();
tensor at_tensor_of_data(void *vs, int64_t *dims, size_t ndims,
size_t element_size_in_bytes, int type);
void at_backward(tensor, int, int);
void at_save(tensor, char *filename);
tensor at_load(char *filename);
void at_free(tensor);
// ... additional declarations ...
#ifdef __cplusplus
};
std::vector<torch::Tensor> of_carray_tensor(torch::Tensor **vs, int len);
at::Device device_of_int(int d);
#endif
#endif
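The C++-only helpers at the end of the header convert raw FFI arguments into libtorch types. A plausible shape for of_carray_tensor, assuming it dereferences each handle in the C array and copies the reference-counted tensor into a std::vector, is sketched below. mock::Tensor is a hypothetical stand-in for torch::Tensor, and the body is illustrative rather than the crate's implementation.

```cpp
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical stand-in for torch::Tensor: copying it shares the underlying
// storage, loosely mimicking torch::Tensor's reference-counted handle.
namespace mock {
struct Tensor {
  std::shared_ptr<std::vector<float>> storage;
};
}  // namespace mock

// Sketch in the spirit of of_carray_tensor: turn a C array of `len`
// tensor handles into a vector of tensor values.
std::vector<mock::Tensor> of_carray_tensor(mock::Tensor **vs, int len) {
  std::vector<mock::Tensor> result;
  if (len > 0) result.reserve(static_cast<std::size_t>(len));
  for (int i = 0; i < len; ++i) {
    result.push_back(*vs[i]);  // copy the handle, not the data
  }
  return result;
}
```

In this sketch the returned vector holds its own copies of the handles, so the pointers in vs stay owned by the caller and are still released through at_free.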
Import
// Included by C++ implementation files:
#include "torch_api.h"
// Also included transitively by torch_api_generated.h:
// #include "torch_api.h"
I/O Contract
| Type | C++ Definition | C Definition | Rust Equivalent |
|---|---|---|---|
| tensor | torch::Tensor* | void* | *mut C_tensor |
| scalar | torch::Scalar* | void* | *mut C_scalar |
| optimizer | torch::optim::Optimizer* | void* | *mut C_optimizer |
| module | torch::jit::script::Module* | void* | *mut CModule_ |
| ivalue | torch::jit::IValue* | void* | *mut CIValue |
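The table works because every handle crosses the boundary as an object pointer, and converting an object pointer to void* and back yields the original pointer. The sketch below mimics both views in one C++ file: tensor is the typed pointer the C++ side sees, tensor_c_view is a hypothetical alias for the void* a C translation unit sees, and example_new, example_numel and example_free are hypothetical functions in the style of at_new_tensor and at_free. The explicit static_cast stands in for the reinterpretation that happens implicitly when a C caller passes void* to a function compiled with a typed parameter.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for torch::Tensor.
namespace mock {
struct Tensor {
  std::vector<float> data;
};
}  // namespace mock

// What C++ translation units see (mirrors `typedef torch::Tensor *tensor;`).
typedef mock::Tensor *tensor;
// What C translation units (and Rust, as *mut C_tensor) effectively see.
typedef void *tensor_c_view;

static_assert(sizeof(tensor) == sizeof(tensor_c_view),
              "handles must have the same size in both views");

extern "C" {
// Allocate on the C++ heap and hand back an opaque handle (assumes n >= 0).
tensor_c_view example_new(int64_t n) {
  return new mock::Tensor{std::vector<float>(static_cast<std::size_t>(n), 0.0f)};
}

int64_t example_numel(tensor_c_view h) {
  tensor t = static_cast<tensor>(h);  // void* -> typed pointer round-trip
  return static_cast<int64_t>(t->data.size());
}

// Ownership returns to C++ here; the caller must not use h afterwards.
void example_free(tensor_c_view h) { delete static_cast<tensor>(h); }
}
```

Which side frees a handle is part of the contract as well: each pointer returned by an allocating function is released exactly once through the matching free call (at_free, ato_free, atm_free, ati_free, ats_free).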
| Device Encoding | Meaning |
|---|---|
| -3 | Vulkan |
| -2 | MPS (Apple Metal) |
| -1 (or any negative except -2, -3) | CPU |
| >= 0 | CUDA device index |
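The encoding reads as a short decision chain. The sketch below restates the table as code; DeviceKind, ExampleDevice and example_device_of_int are hypothetical stand-ins for at::Device and device_of_int so the example compiles on its own, and the order of checks is inferred from the table.

```cpp
// Hypothetical stand-ins for at::DeviceType / at::Device.
enum class DeviceKind { Cpu, Cuda, Mps, Vulkan };

struct ExampleDevice {
  DeviceKind kind;
  int index;  // meaningful only for CUDA; -1 otherwise
};

// Decode the integer device code used across the FFI boundary.
ExampleDevice example_device_of_int(int d) {
  if (d == -3) return {DeviceKind::Vulkan, -1};
  if (d == -2) return {DeviceKind::Mps, -1};
  if (d < 0) return {DeviceKind::Cpu, -1};  // -1 and any other negative value
  return {DeviceKind::Cuda, d};             // d is the CUDA device index
}
```

Checking -3 and -2 before the generic d < 0 branch matters: reversing the order would silently map Vulkan and MPS to CPU.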
Usage Examples
// The header defines the ABI contract.
// C++ implementation (torch_api.cpp) uses real types:
tensor at_new_tensor() {
PROTECT(
return new torch::Tensor(); // tensor = torch::Tensor*
)
return nullptr;
}
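// Error handling pattern (caller side, plain C; needs <stdio.h> and <stdlib.h>):
// 1. Call any at_*/ato_*/atm_*/ati_* function.
// 2. Check the thread-local error before using the result.
void error_check_example(void) {  // illustrative function name
  tensor t = at_new_tensor();
  char *err = get_and_reset_last_err();
  if (err != NULL) {
    fprintf(stderr, "Error: %s\n", err);
    free(err);  // the message was allocated by strdup() in PROTECT
    return;
  }
  at_free(t);
}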
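On the calling side, the check after every call is usually wrapped in a helper; the Rust bindings turn the message into an error value, and a C++ caller could do the equivalent by rethrowing. The sketch below assumes get_and_reset_last_err returns the stored strdup'd message (or null) and clears the slot; the stand-in definitions and the helper name check_last_err are hypothetical.

```cpp
#include <cstdlib>
#include <cstring>
#include <stdexcept>
#include <string>

// Stand-in for the thread-local slot and accessor declared in torch_api.h.
thread_local char *torch_last_err = nullptr;

extern "C" char *get_and_reset_last_err() {
  char *tmp = torch_last_err;
  torch_last_err = nullptr;
  return tmp;
}

// Hypothetical helper: turn a pending FFI error into a C++ exception.
// The message is copied into a std::string before free(), so nothing leaks.
void check_last_err() {
  char *err = get_and_reset_last_err();
  if (err == nullptr) return;
  std::string msg(err);
  std::free(err);
  throw std::runtime_error(msg);
}
```

Copying the message before throwing keeps ownership simple: the C string is freed exactly once, and the exception carries its own copy.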
Related Pages
- Principle:LaurentMazare_Tch_rs_C_FFI_Bridge
- Implementation:LaurentMazare_Tch_rs_Torch_Api_Cpp - C++ implementations of functions declared here
- Implementation:LaurentMazare_Tch_rs_Torch_Sys_Lib - Rust FFI bindings matching these declarations
- Implementation:LaurentMazare_Tch_rs_Torch_Api_Generated_H - Auto-generated companion header that includes this file
- Implementation:LaurentMazare_Tch_rs_Stb_Image - Image libraries used by image I/O functions declared here