

Implementation:LaurentMazare Tch rs Torch Api H

From Leeroopedia


Knowledge Sources
Domains: FFI, C API, Deep Learning
Last Updated: 2026-02-08 00:00 GMT

Overview

Master C FFI header defining all hand-written bindings between Rust and PyTorch C++, including opaque type definitions, the PROTECT error-handling macro, and declarations for tensor, optimizer, module, IValue, scalar, CUDA, and image operations.

Description

torch_api.h is a 295-line header file that serves as the central contract between the Rust FFI layer and the C++ implementation. It is the most architecturally significant header in the torch-sys crate because it:

  1. Defines opaque pointer types using a conditional compilation pattern:
    • When compiled as C++ (__cplusplus defined): each typedef is a pointer to a real C++ type, e.g. torch::Tensor*, torch::Scalar*, torch::optim::Optimizer*
    • When compiled as C: all types are void*
  2. Defines the PROTECT macro, the universal error-handling mechanism for all FFI functions:
#define PROTECT(x) \
  try { \
    x \
  } catch (const exception& e) { \
    torch_last_err = strdup(e.what()); \
  }

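This macro wraps every FFI function body. It catches any exception derived from std::exception and stores a strdup-allocated copy of the message in the thread-local char* variable torch_last_err; execution then falls through to the function's error return value (for example nullptr). Exceptions not derived from std::exception are not caught. The Rust side retrieves and clears the message via get_and_reset_last_err().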
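The round trip can be sketched in plain C++ without libtorch. This is a minimal sketch modeled on the description above: torch_last_err, PROTECT and get_and_reset_last_err follow the names used in the header, while at_might_fail and its int error value are hypothetical stand-ins for a real FFI function.

```cpp
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <stdexcept>

// Thread-local error slot: at most one pending message per thread.
thread_local char *torch_last_err = nullptr;

// Same shape as the header's macro: run the body, and on a std::exception
// store a heap-allocated copy of the message instead of propagating it.
#define PROTECT(x) \
  try { \
    x \
  } catch (const std::exception& e) { \
    torch_last_err = strdup(e.what()); \
  }

extern "C" {

// Hand the pending message to the caller (who must free() it) and clear the slot.
char *get_and_reset_last_err() {
  char *tmp = torch_last_err;
  torch_last_err = nullptr;
  return tmp;
}

// Hypothetical FFI function: throws on negative input.
// Falls through to -1, its error value, when PROTECT catches an exception.
int at_might_fail(int n) {
  PROTECT(
    if (n < 0) throw std::invalid_argument("negative input");
    return n * 2;
  )
  return -1;
}

}  // extern "C"
```

Because the exception never crosses the extern "C" boundary, the caller sees only an ordinary return value plus a message it must fetch and free.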

  3. Declares all hand-written C functions, organized by category:

Tensor operations: at_new_tensor, at_tensor_of_data, at_tensor_of_blob, at_shallow_clone, at_copy_data, at_data_ptr, at_defined, at_is_sparse, at_is_mkldnn, at_is_contiguous, at_device, at_dim, at_shape, at_stride, at_scalar_type, at_backward, at_requires_grad, at_grad_set_enabled, at_get, at_fill_double, at_fill_int64, at_double_value_at_indexes, at_int64_value_at_indexes, at_copy_, at_print, at_to_string, at_free, at_run_backward

Autocast: at__amp_non_finite_check_and_unscale, at_autocast_clear_cache, at_autocast_decrement_nesting, at_autocast_increment_nesting, at_autocast_is_enabled, at_autocast_set_enabled

Serialization: at_save, at_save_to_stream, at_load, at_load_from_stream, at_save_multi, at_save_multi_to_stream, at_load_multi, at_load_multi_, at_loadz_callback, at_loadz_callback_with_device, at_load_callback, at_load_callback_with_device, at_load_from_stream_callback

Image I/O: at_load_image, at_load_image_from_memory, at_save_image, at_resize_image

Optimizers: ato_adam, ato_adamw, ato_rms_prop, ato_sgd, ato_add_parameters, ato_set_learning_rate, ato_set_momentum, ato_set_weight_decay, ato_zero_grad, ato_step, ato_free (with per-group variants)

Scalars: ats_int, ats_float, ats_to_int, ats_to_float, ats_to_string, ats_free

Context queries: at_context_has_openmp, at_context_has_mkl, at_context_has_lapack, at_context_has_mkldnn, at_context_has_magma, at_context_has_cuda, at_context_has_cudart, at_context_has_cudnn, at_context_has_cusolver, at_context_has_hip, at_context_has_ipu, at_context_has_xla, at_context_has_lazy, at_context_has_mps, at_context_version_cudnn, at_context_version_cudart

CUDA: atc_cuda_device_count, atc_cuda_is_available, atc_cudnn_is_available, atc_manual_seed, atc_manual_seed_all, atc_synchronize, atc_user_enabled_cudnn, atc_set_user_enabled_cudnn, atc_set_benchmark_cudnn

JIT Modules: atm_load, atm_load_on_device, atm_load_str, atm_load_str_on_device, atm_forward, atm_forward_, atm_method, atm_method_, atm_create_class_, atm_eval, atm_train, atm_free, atm_to, atm_save, atm_named_parameters, atm_create_for_tracing, atm_end_tracing

IValues: ati_none, ati_tensor, ati_int, ati_double, ati_bool, ati_string, ati_tuple, ati_generic_list, ati_generic_dict, ati_int_list, ati_double_list, ati_bool_list, ati_string_list, ati_tensor_list, ati_device, ati_to_tensor, ati_to_int, ati_to_double, ati_to_string, ati_to_bool, ati_length, ati_tuple_length, ati_to_tuple, ati_to_generic_list, ati_to_generic_dict, ati_tag, ati_object_method_, ati_object_getattr_, ati_clone, ati_free

Stream I/O (internal): tch_write_stream_destructor, tch_write_stream_write, tch_read_stream_destructor, tch_read_stream_stream_position, tch_read_stream_seek_start, tch_read_stream_seek_end, tch_read_stream_read

C++-only helpers (outside extern "C"): of_carray_tensor, device_of_int, of_carray_tensor_opt
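The C++-only helpers turn C-side pointer arrays back into C++ containers. The sketch below assumes that the helpers dereference each pointer and copy the object into a std::vector, with the _opt variant mapping a null pointer to an empty optional. FakeTensor, of_carray and of_carray_opt are hypothetical stand-ins so the example builds without libtorch; the real helpers operate on torch::Tensor.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Hypothetical stand-in for torch::Tensor, so the sketch compiles without libtorch.
struct FakeTensor {
  int id;
};

// Copy each pointed-to object into a vector. (Copying a torch::Tensor only
// copies a reference-counted handle; here it is a plain struct copy.)
std::vector<FakeTensor> of_carray(FakeTensor **vs, int len) {
  std::vector<FakeTensor> result;
  for (int i = 0; i < len; ++i) {
    result.push_back(*vs[i]);
  }
  return result;
}

// Variant that accepts null entries and maps them to std::nullopt.
std::vector<std::optional<FakeTensor>> of_carray_opt(FakeTensor **vs, int len) {
  std::vector<std::optional<FakeTensor>> result;
  for (int i = 0; i < len; ++i) {
    if (vs[i] == nullptr) {
      result.push_back(std::nullopt);
    } else {
      result.push_back(*vs[i]);
    }
  }
  return result;
}
```

Keeping these helpers outside extern "C" is what allows them to return C++ types such as std::vector, which have no C ABI representation.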

Usage

This header is included by torch_api.cpp, torch_api_generated.cpp (via torch_api_generated.h), and any other C/C++ files that need the FFI type definitions. It is the single source of truth for the C ABI contract. Adding a new hand-written FFI function requires declaring it here first.
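Adding a hand-written function can be sketched end to end in one file. Everything here is hypothetical: Counter stands in for a libtorch class, and atx_counter_new, atx_counter_add and atx_counter_free are illustrative names, not part of torch_api.h. The declaration part mirrors the header's pattern (a real pointer type under __cplusplus, void* otherwise, all inside extern "C"), and the implementation bodies use PROTECT.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <stdexcept>

thread_local char *torch_last_err = nullptr;

// --- "header" part: same conditional-typedef pattern as torch_api.h ---
#ifdef __cplusplus
struct Counter {  // hypothetical stand-in for a libtorch class
  int64_t value = 0;
};
extern "C" {
typedef Counter *counter;  // C++ sees the real type
#define PROTECT(x) \
  try { \
    x \
  } catch (const std::exception& e) { \
    torch_last_err = strdup(e.what()); \
  }
#else
typedef void *counter;  // C sees an opaque pointer
#endif

counter atx_counter_new();
int64_t atx_counter_add(counter c, int64_t delta);
void atx_counter_free(counter c);

#ifdef __cplusplus
}  // extern "C"
#endif

// --- "implementation" part (a .cpp file in practice) ---
// These definitions inherit C linkage from the declarations above.
counter atx_counter_new() {
  PROTECT(
    return new Counter();
  )
  return nullptr;
}

int64_t atx_counter_add(counter c, int64_t delta) {
  PROTECT(
    if (delta < 0) throw std::invalid_argument("delta must be non-negative");
    c->value += delta;
    return c->value;
  )
  return -1;
}

void atx_counter_free(counter c) {
  delete c;
}

// The C++ handle type is pointer-sized, like the void* seen from C.
static_assert(sizeof(counter) == sizeof(void *), "handle must be pointer-sized");
```

The Rust side would then need a matching extern declaration taking and returning raw pointers; that part is outside this C++ sketch.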

Code Reference

Source Location

Signature

#ifndef __TORCH_API_H__
#define __TORCH_API_H__
#include<stdint.h>

#ifdef __cplusplus
#include<torch/torch.h>
#include<stdexcept>
using namespace std;
extern thread_local char *torch_last_err;

extern "C" {
typedef torch::Tensor *tensor;
typedef torch::Scalar *scalar;
typedef torch::optim::Optimizer *optimizer;
typedef torch::jit::script::Module *module;
typedef torch::jit::IValue *ivalue;
#define PROTECT(x) \
  try { \
    x \
  } catch (const exception& e) { \
      torch_last_err = strdup(e.what()); \
  }
#else
typedef void *tensor;
typedef void *optimizer;
typedef void *scalar;
typedef void *module;
typedef void *ivalue;
#endif

char *get_and_reset_last_err(); // thread-local
void at_manual_seed(int64_t);
tensor at_new_tensor();
tensor at_tensor_of_data(void *vs, int64_t *dims, size_t ndims,
                         size_t element_size_in_bytes, int type);
void at_backward(tensor, int, int);
void at_save(tensor, char *filename);
tensor at_load(char *filename);
void at_free(tensor);
// ... additional declarations ...

#ifdef __cplusplus
};
std::vector<torch::Tensor> of_carray_tensor(torch::Tensor **vs, int len);
at::Device device_of_int(int d);
#endif
#endif

Import

// Included by C++ implementation files:
#include "torch_api.h"

// Also included transitively by torch_api_generated.h:
// #include "torch_api.h"

I/O Contract

Type        C++ Definition                 C Definition   Rust Equivalent
tensor      torch::Tensor*                 void*          *mut C_tensor
scalar      torch::Scalar*                 void*          *mut C_scalar
optimizer   torch::optim::Optimizer*       void*          *mut C_optimizer
module      torch::jit::script::Module*    void*          *mut CModule_
ivalue      torch::jit::IValue*            void*          *mut CIValue

Device Encoding                      Meaning
-3                                   Vulkan
-2                                   MPS (Apple Metal)
-1 (or any other negative value)     CPU
>= 0                                 CUDA device index
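The device encoding reads as a short decision chain. The sketch below follows the table above; DeviceKind, DecodedDevice and decode_device are hypothetical stand-ins (the header's own C++-only helper, device_of_int, returns an at::Device instead). Testing -3 and -2 before the general negative case is what makes every other negative value mean CPU.

```cpp
#include <cassert>

// Hypothetical stand-ins for at::DeviceType / at::Device.
enum class DeviceKind { Cpu, Cuda, Mps, Vulkan };

struct DecodedDevice {
  DeviceKind kind;
  int index;  // CUDA ordinal; -1 when not applicable
};

// Decode the integer device code passed across the FFI boundary.
// The specific negative codes must be checked before the generic d < 0 case.
DecodedDevice decode_device(int d) {
  if (d == -3) return {DeviceKind::Vulkan, -1};
  if (d == -2) return {DeviceKind::Mps, -1};
  if (d < 0) return {DeviceKind::Cpu, -1};
  return {DeviceKind::Cuda, d};
}
```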

Usage Examples

// The header defines the ABI contract.
// C++ implementation (torch_api.cpp) uses real types:
tensor at_new_tensor() {
  PROTECT(
    return new torch::Tensor();  // tensor = torch::Tensor*
  )
  return nullptr;  // reached only if PROTECT caught an exception
}

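// Error handling pattern (caller side; needs <stdio.h> and <stdlib.h>):
// 1. Call any at_*/ato_*/atm_*/ati_* function.
// 2. Check, and thereby clear, the thread-local error slot.
// report_last_error is an illustrative helper name, not part of the header.
void report_last_error(void) {
  char *err = get_and_reset_last_err();
  if (err != NULL) {
    fprintf(stderr, "Error: %s\n", err);
    free(err);  // the message was allocated by strdup() inside PROTECT
  }
}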
