Implementation:Hpcaitech ColossalAI Cpp Extension

From Leeroopedia


Knowledge Sources
Domains Kernel Extensions, Build System, C++ Integration
Last Updated 2026-02-09 00:00 GMT

Overview

_CppExtension is an abstract base class for C++ kernel extensions in ColossalAI. It provides concrete ahead-of-time (AOT) and just-in-time (JIT) build implementations on top of PyTorch's torch.utils.cpp_extension.CppExtension and torch.utils.cpp_extension.load utilities.

Description

The _CppExtension class extends _Extension and implements the AOT and JIT build workflows for pure C++ (non-CUDA) kernel extensions. It manages the prebuilt module paths under the colossalai._C namespace and provides helper methods that resolve absolute paths to source files in the csrc and pybind directories. It also includes a caching mechanism to avoid redundant compilation. The load method first tries to import the prebuilt module and falls back to JIT compilation if that import fails.
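The load flow described above (prebuilt import first, JIT build as a fallback) can be sketched in plain Python. This is a minimal illustration, not ColossalAI's code. FallbackLoader and its build_jit callable are hypothetical stand-ins. The memoization through a cached_op attribute is our own illustration of "avoid redundant compilation" and may not match where the real class applies its caching.

```python
import importlib
from types import ModuleType
from typing import Callable, Optional


class FallbackLoader:
    """Hypothetical stand-in for the described load() flow:
    try a prebuilt module first, build it JIT on import failure, then cache it."""

    def __init__(self, import_path: str, build_jit: Callable[[], ModuleType]):
        self.import_path = import_path  # e.g. "colossalai._C.<name>"
        self.build_jit = build_jit      # hypothetical callable standing in for a JIT compile
        self.cached_op: Optional[ModuleType] = None

    def load(self) -> ModuleType:
        if self.cached_op is not None:
            return self.cached_op  # reuse the module loaded on a previous call
        try:
            op = importlib.import_module(self.import_path)
        except ImportError:  # ModuleNotFoundError is a subclass of ImportError
            op = self.build_jit()  # no prebuilt module: compile at runtime
        self.cached_op = op
        return op
```

Because of the cache, the second and later calls never retry the import or the build.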

Usage

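Use _CppExtension as the base class for a new ColossalAI C++ kernel extension that does not require CUDA. Subclasses must implement sources_files, include_dirs, and cxx_flags to specify the compilation inputs. The usage example on this page also implements is_available and assert_compatible, which are declared on the _Extension base class.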
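The "must implement" requirement is enforced at instantiation time when abstract methods are declared with @abstractmethod under abc.ABC. The sketch below assumes the base class derives from abc.ABC, because @abstractmethod has no effect otherwise. CppExtensionContract, Incomplete, Complete, and can_instantiate are hypothetical names used only for illustration.

```python
from abc import ABC, abstractmethod
from typing import List


class CppExtensionContract(ABC):
    """Hypothetical stand-in for the subclass contract (not the real _CppExtension)."""

    @abstractmethod
    def sources_files(self) -> List[str]: ...

    @abstractmethod
    def include_dirs(self) -> List[str]: ...

    @abstractmethod
    def cxx_flags(self) -> List[str]: ...


class Incomplete(CppExtensionContract):
    def sources_files(self) -> List[str]:
        return ["kernel.cpp"]  # include_dirs and cxx_flags are still abstract


class Complete(Incomplete):
    def include_dirs(self) -> List[str]:
        return []

    def cxx_flags(self) -> List[str]:
        return ["-O3"]


def can_instantiate(cls) -> bool:
    try:
        cls()
        return True
    except TypeError:  # ABCMeta refuses to instantiate classes with abstract methods left
        return False
```

can_instantiate(Incomplete) is False, and can_instantiate(Complete) is True. A subclass that forgets an override therefore fails when it is created rather than later, during a build.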

Code Reference

Source Location

Signature

class _CppExtension(_Extension):
    def __init__(self, name: str, priority: int = 1):
        ...

    def csrc_abs_path(self, path):
        ...

    def pybind_abs_path(self, path):
        ...

    def relative_to_abs_path(self, code_path: str) -> str:
        ...

    def strip_empty_entries(self, args):
        ...

    def import_op(self):
        ...

    def build_aot(self) -> "CppExtension":
        ...

    def build_jit(self) -> None:
        ...

    @abstractmethod
    def sources_files(self) -> List[str]:
        ...

    @abstractmethod
    def include_dirs(self) -> List[str]:
        ...

    @abstractmethod
    def cxx_flags(self) -> List[str]:
        ...

    def load(self):
        ...
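The path helpers in the signature resolve files relative to the extensions package. The following is a minimal sketch of that lookup. It assumes, as we understand the implementation, that relative_to_abs_path starts from the module's own file and climbs parent directories until it reaches one named extensions. Unlike the real methods, these are free functions that take a hypothetical module_file argument, so they can run without the ColossalAI source tree. The root-directory guard is our own addition.

```python
import os
from pathlib import Path
from typing import List


def relative_to_abs_path(code_path: str, module_file: str) -> str:
    # Walk up from module_file until a directory named "extensions" is reached.
    current = Path(module_file)
    while current.name != "extensions":
        if current.parent == current:
            # Root guard (our addition): fail loudly instead of looping forever.
            raise FileNotFoundError(f"no 'extensions' directory above {module_file}")
        current = current.parent
    return str(current.joinpath(code_path))


def csrc_abs_path(path: str, module_file: str) -> str:
    # <extensions>/csrc/<path>
    return os.path.join(relative_to_abs_path("csrc", module_file), path)


def pybind_abs_path(path: str, module_file: str) -> str:
    # <extensions>/pybind/<path>
    return os.path.join(relative_to_abs_path("pybind", module_file), path)


def strip_empty_entries(args: List[str]) -> List[str]:
    # Drop empty strings so they never reach the compiler command line.
    return [x for x in args if len(x) > 0]
```

For example, csrc_abs_path("my_kernel.cpp", "/repo/extensions/cpp_extension.py") resolves to /repo/extensions/csrc/my_kernel.cpp.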

Import

from extensions.cpp_extension import _CppExtension

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| name | str | Yes | The name identifier for the C++ extension |
| priority | int | No | Priority level for extension selection (default: 1) |
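The priority input suggests a selection step among several candidate extensions. The sketch below assumes a loader keeps only the extensions available in the current environment and then picks the one with the highest priority. That rule is an assumption for illustration, not something this page documents. Candidate and pick_extension are hypothetical names.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Candidate:
    """Hypothetical record standing in for an extension instance."""
    name: str
    priority: int
    available: bool


def pick_extension(candidates: List[Candidate]) -> Candidate:
    usable = [c for c in candidates if c.available]
    if not usable:
        raise RuntimeError("no usable extension for this environment")
    # Assumed rule: higher priority wins; max() keeps the first of equal-priority ties.
    return max(usable, key=lambda c: c.priority)
```

Under this assumption, a C++ extension with the default priority of 1 is chosen only when no available alternative declares a higher priority.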

Outputs

| Method | Return type | Description |
|---|---|---|
| build_aot | CppExtension (a setuptools.Extension) | An extension object created with torch.utils.cpp_extension.CppExtension and configured for ahead-of-time compilation |
| build_jit | module | The JIT-compiled and loaded kernel module (the signature above annotates this method as returning None) |
| load | module | The loaded kernel module (prebuilt import, or JIT fallback) |
| csrc_abs_path | str | Absolute path to a file within the csrc directory |
| pybind_abs_path | str | Absolute path to a file within the pybind directory |
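The AOT and JIT paths consume the same three subclass hooks but pass them to different PyTorch entry points under different keyword names. The sketch below builds those keyword-argument dictionaries without importing torch. The keyword names (sources, include_dirs, extra_compile_args for CppExtension; sources, extra_include_paths, extra_cflags, build_directory for load) are real torch.utils.cpp_extension parameters. The helper functions themselves are hypothetical, and this sketch strips empty entries on both paths.

```python
from typing import Dict, List


def strip_empty_entries(args: List[str]) -> List[str]:
    return [x for x in args if len(x) > 0]


def aot_kwargs(import_path: str, sources: List[str],
               include_dirs: List[str], cxx_flags: List[str]) -> Dict:
    # Keyword arguments for torch.utils.cpp_extension.CppExtension(...)
    return dict(
        name=import_path,  # e.g. "colossalai._C.my_cpp_kernel"
        sources=strip_empty_entries(sources),
        include_dirs=strip_empty_entries(include_dirs),
        extra_compile_args=strip_empty_entries(cxx_flags),
    )


def jit_kwargs(name: str, sources: List[str], include_dirs: List[str],
               cxx_flags: List[str], build_directory: str) -> Dict:
    # Keyword arguments for torch.utils.cpp_extension.load(...)
    return dict(
        name=name,
        sources=strip_empty_entries(sources),
        extra_include_paths=strip_empty_entries(include_dirs),  # not "include_dirs"
        extra_cflags=strip_empty_entries(cxx_flags),            # not "extra_compile_args"
        build_directory=build_directory,
    )
```

With torch installed, CppExtension(**aot_kwargs(...)) would produce a setuptools.Extension for a setup.py ext_modules list, and load(**jit_kwargs(...)) would compile and import the module at runtime.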

Abstract Methods (Must Override)

| Method | Return type | Description |
|---|---|---|
| sources_files() | List[str] | Must return a list of C++ source file paths for compilation |
| include_dirs() | List[str] | Must return a list of include directory paths |
| cxx_flags() | List[str] | Must return a list of C++ compiler flags |

Usage Examples

from extensions.cpp_extension import _CppExtension
from typing import List

class MyCppKernel(_CppExtension):
    def __init__(self):
        super().__init__(name="my_cpp_kernel", priority=1)

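    def sources_files(self) -> List[str]:
        # Resolved against the csrc/ and pybind/ directories of the extensions package
        return [
            self.csrc_abs_path("my_kernel.cpp"),
            self.pybind_abs_path("my_kernel_pybind.cpp"),
        ]

    def include_dirs(self) -> List[str]:
        return [self.csrc_abs_path("")]  # the csrc/ directory itself

    def cxx_flags(self) -> List[str]:
        # version_dependent_macros is provided by the base class
        return ["-O3", "-std=c++17"] + self.version_dependent_macros

    # The two methods below are declared on the _Extension base class
    def is_available(self) -> bool:
        return True

    def assert_compatible(self) -> None:
        pass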

# Load the extension (tries prebuilt, falls back to JIT)
kernel = MyCppKernel()
op = kernel.load()
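Whether kernel.load() above takes the fast path depends on whether a prebuilt module exists under colossalai._C. A small helper can check this in advance with importlib.util.find_spec. The helper is hypothetical, and the "colossalai._C.<name>" path follows the namespace described on this page. Note that find_spec still imports the parent packages, and it raises ModuleNotFoundError when a parent is missing, which the helper treats as "not prebuilt".

```python
import importlib.util


def prebuilt_available(import_path: str) -> bool:
    """Hypothetical helper: True if a prebuilt module such as
    "colossalai._C.my_cpp_kernel" can be located, i.e. load() would skip JIT."""
    try:
        return importlib.util.find_spec(import_path) is not None
    except ImportError:  # includes ModuleNotFoundError for a missing parent package
        return False
```

For example, prebuilt_available("colossalai._C.my_cpp_kernel") returning False means the first load() call will trigger a JIT compile.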
