Implementation:Hpcaitech ColossalAI Cpp Extension
| Knowledge Sources | |
|---|---|
| Domains | Kernel Extensions, Build System, C++ Integration |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
_CppExtension is the abstract base class for CPU-only C++ kernel extensions in ColossalAI. It provides concrete ahead-of-time (AOT) and just-in-time (JIT) build paths on top of PyTorch's torch.utils.cpp_extension.CppExtension and torch.utils.cpp_extension.load utilities.
Description
The _CppExtension class extends _Extension and implements the AOT and JIT build workflows for pure C++ (non-CUDA) kernel extensions. It manages prebuilt module paths under the colossalai._C namespace, provides helper methods that resolve absolute paths for files in the csrc and pybind directories, and includes a caching mechanism to avoid redundant compilation. Its load method first tries to import the prebuilt module and falls back to JIT compilation if that import fails.
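The prebuilt-first loading strategy described above can be sketched with importlib alone. PrebuiltOrJitLoader is a hypothetical stand-in, not the ColossalAI class: the build_jit callable replaces the real PyTorch JIT build, and memoizing the result in cached_op is an assumption about how repeated work is avoided.

```python
import importlib
from typing import Callable, Optional


class PrebuiltOrJitLoader:
    """Hypothetical stand-in illustrating prebuilt-import-then-JIT loading."""

    def __init__(self, import_path: str, build_jit: Callable[[], object]):
        self.import_path = import_path  # e.g. "colossalai._C.<name>" (assumed layout)
        self.build_jit = build_jit      # only called when the prebuilt import fails
        self.cached_op: Optional[object] = None

    def load(self) -> object:
        # Reuse an op that was already imported or built (assumed caching policy).
        if self.cached_op is not None:
            return self.cached_op
        try:
            op = importlib.import_module(self.import_path)
        except ImportError:
            # ModuleNotFoundError subclasses ImportError, so both are handled here.
            op = self.build_jit()
        self.cached_op = op
        return op


def _fake_jit_build() -> str:
    # Placeholder for an expensive compile step.
    return "jit-built-op"


loader = PrebuiltOrJitLoader("no_such_pkg._C.demo_kernel", _fake_jit_build)
op = loader.load()  # the import fails, so the JIT stand-in runs
```

Catching ImportError rather than a bare except keeps genuine errors raised inside a successfully imported module from being silently replaced by a JIT build.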
Usage
Use _CppExtension as the base class for a new ColossalAI C++ kernel extension that does not require CUDA. Subclasses must implement sources_files, include_dirs, and cxx_flags to specify the compilation inputs. The usage example also overrides is_available and assert_compatible, which concrete extensions are expected to provide.
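Because sources_files, include_dirs, and cxx_flags are decorated with @abstractmethod, Python refuses to instantiate a subclass that leaves any of them unimplemented, assuming the base class ultimately derives from abc.ABC. The hypothetical classes below reproduce that check in isolation; they are not ColossalAI code.

```python
from abc import ABC, abstractmethod
from typing import List


class CppExtensionContract(ABC):
    """Hypothetical minimal base mirroring the three required hooks."""

    @abstractmethod
    def sources_files(self) -> List[str]: ...

    @abstractmethod
    def include_dirs(self) -> List[str]: ...

    @abstractmethod
    def cxx_flags(self) -> List[str]: ...


class Incomplete(CppExtensionContract):
    # Only one hook implemented; include_dirs and cxx_flags are still abstract.
    def sources_files(self) -> List[str]:
        return ["kernel.cpp"]


class Complete(Incomplete):
    # Fills in the remaining hooks, so no abstract methods are left.
    def include_dirs(self) -> List[str]:
        return []

    def cxx_flags(self) -> List[str]:
        return ["-O3"]


def can_instantiate(cls: type) -> bool:
    # ABCMeta raises TypeError when abstract methods remain unimplemented.
    try:
        cls()
    except TypeError:
        return False
    return True
```

Here can_instantiate(Incomplete) is False and can_instantiate(Complete) is True, which is why a missing hook surfaces at construction time rather than during the build.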
Code Reference
Source Location
- Repository: Hpcaitech_ColossalAI
- File: extensions/cpp_extension.py
- Lines: 1-138
Signature
```python
class _CppExtension(_Extension):
    def __init__(self, name: str, priority: int = 1):
        ...
    def csrc_abs_path(self, path):
        ...
    def pybind_abs_path(self, path):
        ...
    def relative_to_abs_path(self, code_path: str) -> str:
        ...
    def strip_empty_entries(self, args):
        ...
    def import_op(self):
        ...
    def build_aot(self) -> "CppExtension":
        ...
    def build_jit(self) -> None:
        ...
    @abstractmethod
    def sources_files(self) -> List[str]:
        ...
    @abstractmethod
    def include_dirs(self) -> List[str]:
        ...
    @abstractmethod
    def cxx_flags(self) -> List[str]:
        ...
    def load(self):
        ...
```
Import
```python
from extensions.cpp_extension import _CppExtension
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| name | str | Yes | The name identifier for the C++ extension |
| priority | int | No | Priority level for extension selection (default: 1) |
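The name argument also determines where a prebuilt module is looked up. A minimal sketch, assuming the prebuilt import path is the colossalai._C namespace joined with the extension name. ExtensionIdentity and pick_highest_priority are hypothetical, and the rule that a larger priority value wins is an assumption rather than documented behaviour.

```python
from dataclasses import dataclass
from typing import List

# Namespace named in the Description section.
PREBUILT_MODULE_PATH = "colossalai._C"


@dataclass
class ExtensionIdentity:
    """Hypothetical holder for the two constructor inputs."""

    name: str
    priority: int = 1  # same default as _CppExtension.__init__

    @property
    def prebuilt_import_path(self) -> str:
        # Assumed layout: "<namespace>.<name>", e.g. "colossalai._C.my_cpp_kernel".
        return f"{PREBUILT_MODULE_PATH}.{self.name}"


def pick_highest_priority(candidates: List[ExtensionIdentity]) -> ExtensionIdentity:
    # Hypothetical selection rule: assumes a larger priority value is preferred.
    # max() returns the first candidate among ties.
    return max(candidates, key=lambda ext: ext.priority)
```

For ExtensionIdentity("my_cpp_kernel"), prebuilt_import_path evaluates to "colossalai._C.my_cpp_kernel".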
Outputs
| Name | Type | Description |
|---|---|---|
| build_aot return | CppExtension | A PyTorch CppExtension object configured for ahead-of-time compilation |
| build_jit return | module | The JIT-compiled and loaded kernel module (returned even though the signature annotates build_jit as -> None) |
| load return | module | The loaded kernel module (prebuilt import or JIT fallback) |
| csrc_abs_path return | str | Absolute path to a file within the csrc directory |
| pybind_abs_path return | str | Absolute path to a file within the pybind directory |
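csrc_abs_path and pybind_abs_path return absolute paths inside the csrc and pybind directories. The sketch below assumes both directories sit directly under a directory named extensions, located by walking upward from a starting file. The real methods start from the module's own location; this version takes the start path as an argument (and uses PurePosixPath) so it runs anywhere. Raising FileNotFoundError when no such directory exists is a choice made for this sketch.

```python
from pathlib import PurePosixPath
from typing import Union

PathLike = Union[str, PurePosixPath]


def find_extensions_root(start: PathLike) -> PurePosixPath:
    # Walk from the start path up through its parents until a directory
    # literally named "extensions" is found (assumed repository layout).
    start = PurePosixPath(start)
    for candidate in (start, *start.parents):
        if candidate.name == "extensions":
            return candidate
    raise FileNotFoundError(f"no 'extensions' directory above {start}")


def _abs_under(start: PathLike, subdir: str, path: str) -> str:
    # Join <extensions root>/<subdir>/<path> and return it as a string.
    return str(find_extensions_root(start) / subdir / path)


def csrc_abs_path(start: PathLike, path: str) -> str:
    return _abs_under(start, "csrc", path)


def pybind_abs_path(start: PathLike, path: str) -> str:
    return _abs_under(start, "pybind", path)
```

Given a start of "/repo/extensions/cpp_extension.py", csrc_abs_path(start, "my_kernel.cpp") yields "/repo/extensions/csrc/my_kernel.cpp".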
Abstract Methods (Must Override)
| Method | Return Type | Description |
|---|---|---|
| sources_files() | List[str] | Must return a list of C++ source file paths for compilation |
| include_dirs() | List[str] | Must return a list of include directory paths |
| cxx_flags() | List[str] | Must return a list of C++ compiler flags |
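The three abstract hooks feed both build paths. The sketch below only assembles keyword arguments, without importing PyTorch, to show which torch.utils.cpp_extension parameter each hook is assumed to populate: sources, include_dirs, and extra_compile_args for CppExtension (AOT), and sources, extra_include_paths, and extra_cflags for load (JIT). The helper names aot_kwargs and jit_kwargs are hypothetical, the exact mapping ColossalAI uses is an assumption, and the filtering in strip_empty_entries is inferred from the method's name.

```python
from typing import Dict, List


def strip_empty_entries(args: List[str]) -> List[str]:
    # Drop empty strings so they are never forwarded to the compiler.
    return [arg for arg in args if len(arg) > 0]


def aot_kwargs(
    import_path: str, sources: List[str], include_dirs: List[str], cxx_flags: List[str]
) -> Dict[str, object]:
    # Keyword arguments one could pass to torch.utils.cpp_extension.CppExtension.
    return {
        "name": import_path,
        "sources": strip_empty_entries(sources),
        "include_dirs": strip_empty_entries(include_dirs),
        "extra_compile_args": strip_empty_entries(cxx_flags),
    }


def jit_kwargs(
    name: str,
    sources: List[str],
    include_dirs: List[str],
    cxx_flags: List[str],
    build_directory: str,
) -> Dict[str, object]:
    # Keyword arguments one could pass to torch.utils.cpp_extension.load.
    return {
        "name": name,
        "sources": strip_empty_entries(sources),
        "extra_include_paths": strip_empty_entries(include_dirs),
        "extra_cflags": strip_empty_entries(cxx_flags),
        "build_directory": build_directory,
    }
```

Keeping both paths driven by the same three hooks means a subclass describes its inputs once, and the base class decides whether they are compiled at install time or at first use.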
Usage Examples
```python
from typing import List

from extensions.cpp_extension import _CppExtension


class MyCppKernel(_CppExtension):
    def __init__(self):
        super().__init__(name="my_cpp_kernel", priority=1)

    def sources_files(self) -> List[str]:
        # Kernel implementation plus its pybind binding file
        return [
            self.csrc_abs_path("my_kernel.cpp"),
            self.pybind_abs_path("my_kernel_pybind.cpp"),
        ]

    def include_dirs(self) -> List[str]:
        return [self.csrc_abs_path("")]

    def cxx_flags(self) -> List[str]:
        # version_dependent_macros is inherited from the base class
        return ["-O3", "-std=c++17"] + self.version_dependent_macros

    def is_available(self) -> bool:
        return True

    def assert_compatible(self) -> None:
        pass


# Load the extension (tries the prebuilt module, falls back to JIT)
kernel = MyCppKernel()
op = kernel.load()
```