Implementation:Hpcaitech ColossalAI Cuda Extension

From Leeroopedia


Knowledge Sources
Domains: Kernel Extensions, Build System, CUDA
Last Updated: 2026-02-09 00:00 GMT

Overview

_CudaExtension is an abstract base class for CUDA kernel extensions in ColossalAI, extending _CppExtension with CUDA-specific compilation support including nvcc flags, CUDA availability checks, and GPU architecture configuration.

Description

The _CudaExtension class builds upon _CppExtension to add CUDA-specific compilation workflows. It verifies CUDA availability (via torch.cuda.is_available() or the FORCE_CUDA environment variable), checks compatibility by validating CUDA_HOME and matching PyTorch/system CUDA versions, and includes the CUDA home include directory in the compilation paths. Both build_aot and build_jit methods set the CUDA architecture list before compilation and pass nvcc flags alongside C++ flags.
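The step of setting the CUDA architecture list before compilation can be illustrated with a minimal sketch. PyTorch's extension builder reads the TORCH_CUDA_ARCH_LIST environment variable; the helper name ensure_arch_list and the default architectures below are hypothetical and are not taken from ColossalAI's own implementation.

```python
from typing import Iterable, MutableMapping


def ensure_arch_list(environ: MutableMapping[str, str],
                     default_archs: Iterable[str]) -> str:
    """Set TORCH_CUDA_ARCH_LIST unless a value is already present.

    Hypothetical helper: the real extension may choose architectures differently.
    """
    current = environ.get("TORCH_CUDA_ARCH_LIST")
    if current:
        # Respect an explicit choice made by the user.
        return current
    # PyTorch accepts a semicolon-separated list such as "7.0;8.0".
    value = ";".join(default_archs)
    environ["TORCH_CUDA_ARCH_LIST"] = value
    return value


# Demonstration on a plain dict instead of the real process environment.
demo_env = {}
print(ensure_arch_list(demo_env, ["7.0", "8.0"]))
```

In a real build, os.environ would be passed as environ before the extension is compiled.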

Usage

Use _CudaExtension as the base class when creating new CUDA kernel extensions for ColossalAI. Subclasses must implement nvcc_flags, sources_files, and cxx_flags to specify the CUDA and C++ compilation inputs; include_dirs already has an implementation in _CudaExtension and can be extended through super().include_dirs() when extra include paths are needed. The system must have the CUDA Toolkit installed with CUDA_HOME set, and PyTorch 1.10 or newer.
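The version requirements can be pictured as two small comparisons: one against the minimum PyTorch version and one between CUDA major versions. The sketch below works on version strings only; the function names are hypothetical and do not correspond to ColossalAI's own helpers.

```python
from typing import Tuple


def parse_major_minor(version: str) -> Tuple[int, int]:
    """Extract (major, minor) from strings such as '2.1.0+cu121' or '11.8'."""
    core = version.split("+")[0]          # drop a local build suffix, if any
    major, minor = core.split(".")[:2]
    return int(major), int(minor)


def meets_min_torch(version: str, min_major: int = 1, min_minor: int = 10) -> bool:
    """True when the PyTorch version is at least min_major.min_minor."""
    return parse_major_minor(version) >= (min_major, min_minor)


def cuda_major_matches(system_cuda: str, torch_cuda: str) -> bool:
    """True when both CUDA versions share the same major number."""
    return parse_major_minor(system_cuda)[0] == parse_major_minor(torch_cuda)[0]


print(meets_min_torch("1.9.1"), meets_min_torch("2.0.1+cu118"))
print(cuda_major_matches("11.8", "11.7"), cuda_major_matches("12.1", "11.8"))
```

Tuple comparison keeps the minimum-version test correct for cases like 1.9 versus 1.10, where a plain string or float comparison would give the wrong answer.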

Code Reference

Source Location

Signature

class _CudaExtension(_CppExtension):
    @abstractmethod
    def nvcc_flags(self) -> List[str]:
        ...

    def is_available(self) -> bool:
        ...

    def assert_compatible(self) -> None:
        ...

    def get_cuda_home_include(self):
        ...

    def include_dirs(self) -> List[str]:
        ...

    def build_jit(self) -> None:
        ...

    def build_aot(self) -> "CUDAExtension":
        ...

Import

from extensions.cuda_extension import _CudaExtension

I/O Contract

Inputs

Name | Type | Required | Description
name | str | Yes | The name identifier for the CUDA extension (inherited from _CppExtension)
priority | int | No | Priority level for extension selection (default: 1, inherited from _CppExtension)

Outputs

Name | Type | Description
is_available (return) | bool | True if CUDA is available on the system (or FORCE_CUDA is set)
build_aot (return) | CUDAExtension | A PyTorch CUDAExtension object configured for ahead-of-time compilation with both cxx and nvcc flags
build_jit (return) | module | The JIT-compiled and loaded CUDA kernel module (note that the signature above annotates this method as -> None)
get_cuda_home_include (return) | str | The path to the CUDA include directory (e.g., /usr/local/cuda/include)
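As a rough illustration of the get_cuda_home_include output, the include path is the CUDA home directory with an include component appended. The function below is a stand-alone sketch that takes the CUDA home as an argument; its name and error message are assumptions and not the class's actual code.

```python
import os
from typing import Optional


def cuda_home_include(cuda_home: Optional[str]) -> str:
    """Return the include directory under a CUDA installation root.

    Hypothetical stand-in for get_cuda_home_include; fails loudly when
    no CUDA home is known.
    """
    if not cuda_home:
        raise RuntimeError("CUDA_HOME is not set; cannot locate CUDA headers.")
    return os.path.join(cuda_home, "include")


print(cuda_home_include("/usr/local/cuda"))
```

With PyTorch installed, the value of torch.utils.cpp_extension.CUDA_HOME could be passed in as cuda_home.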

Requirements

Requirement | Details
CUDA Toolkit | Must be installed with CUDA_HOME environment variable set
PyTorch Version | Minimum 1.10 (major=1, minor=10)
CUDA Version Match | System CUDA major version must match PyTorch's compiled CUDA version
FORCE_CUDA | Optional environment variable to force CUDA extension building on systems without a GPU device
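The interplay between device detection and FORCE_CUDA can be sketched as follows. The device probe is injected so the example runs without PyTorch; the function name is hypothetical, and treating any non-empty FORCE_CUDA value as "set" is an assumption of this sketch.

```python
from typing import Callable, Mapping


def cuda_build_enabled(device_probe: Callable[[], bool],
                       environ: Mapping[str, str]) -> bool:
    """Decide whether a CUDA extension should be built.

    device_probe stands in for a check such as torch.cuda.is_available.
    """
    try:
        has_device = bool(device_probe())
    except Exception:
        # Assumption: a failing probe (e.g. PyTorch not importable) means no CUDA.
        return False
    # Assumption: any non-empty FORCE_CUDA value counts as set, including "0".
    return has_device or bool(environ.get("FORCE_CUDA"))


def failing_probe() -> bool:
    """Simulates an environment where the probe itself cannot run."""
    raise ImportError("torch is not installed")


print(cuda_build_enabled(lambda: False, {"FORCE_CUDA": "1"}))
print(cuda_build_enabled(failing_probe, {"FORCE_CUDA": "1"}))
```

With PyTorch installed, torch.cuda.is_available and os.environ would be the natural arguments. This mirrors the build-node case from the table: no GPU is attached, yet the toolkit is present and the build is forced.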

Abstract Methods (Must Override)

Method | Return Type | Description
nvcc_flags() | List[str] | Must return a list of NVCC compiler flags (base returns ["-DCOLOSSAL_WITH_CUDA"])
sources_files() | List[str] | Must return a list of source file paths (inherited from _CppExtension)
cxx_flags() | List[str] | Must return a list of C++ compiler flags (inherited from _CppExtension)
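The "must override" rule follows from Python's abc module: an abstract method can still carry a default body that subclasses reach through super(). The classes below are hypothetical stand-ins and not the real _CudaExtension; they only demonstrate the mechanism.

```python
from abc import ABC, abstractmethod
from typing import List


class BaseCudaExt(ABC):
    """Hypothetical stand-in for an abstract extension base class."""

    @abstractmethod
    def nvcc_flags(self) -> List[str]:
        # Default body: reachable only via super() from an override.
        return ["-DCOLOSSAL_WITH_CUDA"]


class Incomplete(BaseCudaExt):
    """Does not override nvcc_flags, so it stays abstract."""


class Complete(BaseCudaExt):
    def nvcc_flags(self) -> List[str]:
        # Extend the base flags instead of replacing them.
        return super().nvcc_flags() + ["-O3"]


try:
    Incomplete()
    instantiated = True
except TypeError:
    # abc refuses to instantiate classes with unimplemented abstract methods.
    instantiated = False

print(instantiated, Complete().nvcc_flags())
```

Extending the base flags this way keeps the base macro in every subclass's flag list.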

Usage Examples

from extensions.cuda_extension import _CudaExtension
from typing import List

class MyCudaKernel(_CudaExtension):
    def __init__(self):
        super().__init__(name="my_cuda_kernel", priority=1)

    def sources_files(self) -> List[str]:
        return [
            self.csrc_abs_path("my_kernel.cu"),
            self.csrc_abs_path("my_kernel_wrapper.cpp"),
            self.pybind_abs_path("my_kernel_pybind.cpp"),
        ]

    def include_dirs(self) -> List[str]:
        return super().include_dirs()

    def cxx_flags(self) -> List[str]:
        return ["-O3", "-std=c++17"] + self.version_dependent_macros

    def nvcc_flags(self) -> List[str]:
        return super().nvcc_flags() + ["-O3", "--use_fast_math"]

# Check availability and load
kernel = MyCudaKernel()
if kernel.is_available():
    kernel.assert_compatible()
    op = kernel.load()
