Overview
_CudaExtension is an abstract base class for CUDA kernel extensions in ColossalAI, extending _CppExtension with CUDA-specific compilation support including nvcc flags, CUDA availability checks, and GPU architecture configuration.
Description
The _CudaExtension class builds upon _CppExtension to add CUDA-specific compilation workflows. It verifies CUDA availability (via torch.cuda.is_available() or the FORCE_CUDA environment variable), checks compatibility by validating CUDA_HOME and matching PyTorch/system CUDA versions, and includes the CUDA home include directory in the compilation paths. Both build_aot and build_jit methods set the CUDA architecture list before compilation and pass nvcc flags alongside C++ flags.
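
The availability rule described above can be sketched as follows. This is a minimal illustration rather than the ColossalAI implementation: `cuda_build_enabled` and `_probe_torch_cuda` are hypothetical names, and treating any non-empty `FORCE_CUDA` value as "set" is an assumption.

```python
import os
from typing import Callable, Mapping, Optional


def _probe_torch_cuda() -> bool:
    """Return True if PyTorch is importable and reports a usable CUDA device."""
    try:
        import torch  # imported lazily so the check also works without PyTorch

        return bool(torch.cuda.is_available())
    except Exception:
        return False


def cuda_build_enabled(
    env: Optional[Mapping[str, str]] = None,
    probe: Callable[[], bool] = _probe_torch_cuda,
) -> bool:
    """Hypothetical helper: CUDA counts as available if a device is visible
    or FORCE_CUDA is set (assumption: any non-empty value counts as set)."""
    env = os.environ if env is None else env
    return probe() or bool(env.get("FORCE_CUDA", ""))
```

Passing `env` and `probe` explicitly keeps the check testable on machines without a GPU, which is the same situation `FORCE_CUDA` is meant for.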
Usage
Use _CudaExtension as the base class when creating new CUDA kernel extensions for ColossalAI. Subclasses must implement nvcc_flags, sources_files, and cxx_flags to specify the CUDA and C++ compilation inputs; include_dirs already appends the CUDA home include directory and can be overridden to add further paths. The system must have the CUDA Toolkit installed with CUDA_HOME set, and PyTorch 1.10 or later.
Code Reference
Source Location
Signature
class _CudaExtension(_CppExtension):
    @abstractmethod
    def nvcc_flags(self) -> List[str]:
        ...

    def is_available(self) -> bool:
        ...

    def assert_compatible(self) -> None:
        ...

    def get_cuda_home_include(self):
        ...

    def include_dirs(self) -> List[str]:
        ...

    def build_jit(self) -> None:
        ...

    def build_aot(self) -> "CUDAExtension":
        ...
Import
from extensions.cuda_extension import _CudaExtension
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| name | str | Yes | The name identifier for the CUDA extension (inherited from _CppExtension) |
| priority | int | No | Priority level for extension selection (default: 1, inherited from _CppExtension) |
Outputs
| Name | Type | Description |
|---|---|---|
| is_available return | bool | True if CUDA is available on the system (or FORCE_CUDA is set) |
| build_aot return | CUDAExtension | A PyTorch CUDAExtension object configured for ahead-of-time compilation with both cxx and nvcc flags |
| build_jit return | module | The JIT-compiled and loaded CUDA kernel module |
| get_cuda_home_include return | str | The path to the CUDA include directory (e.g., /usr/local/cuda/include) |
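
As a rough sketch of how the CUDA include path in the last row could be derived and appended to a base include list: the function names below are hypothetical, and the error raised when CUDA_HOME is missing is an assumption about reasonable behaviour, not a statement of what ColossalAI does.

```python
import os
from typing import List, Optional


def cuda_home_include(cuda_home: Optional[str]) -> str:
    """Hypothetical helper: return '<CUDA_HOME>/include', failing loudly if unset."""
    if not cuda_home:
        raise RuntimeError("CUDA_HOME is not set; cannot locate CUDA headers.")
    return os.path.join(cuda_home, "include")


def with_cuda_include(base_dirs: List[str], cuda_home: Optional[str]) -> List[str]:
    """Append the CUDA include directory to an existing include list (copy, not in place)."""
    return list(base_dirs) + [cuda_home_include(cuda_home)]
```

In a real subclass the base list would come from `super().include_dirs()`, as in the usage example on this page.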
Requirements
| Requirement | Details |
|---|---|
| CUDA Toolkit | Must be installed with CUDA_HOME environment variable set |
| PyTorch Version | Minimum 1.10 (major=1, minor=10) |
| CUDA Version Match | System CUDA major version must match PyTorch's compiled CUDA version |
| FORCE_CUDA | Optional environment variable to force CUDA extension building on systems without a GPU device |
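
The two version requirements above can be illustrated with a small, self-contained check. This is a sketch under stated assumptions: `parse_major_minor` and `check_compatible` are hypothetical names, versions are passed in as plain strings, and the exception type is chosen for illustration only.

```python
from typing import Tuple


def parse_major_minor(version: str) -> Tuple[int, int]:
    """Extract (major, minor) from strings such as '11.8' or '2.1.0+cu118'.

    Assumes the string contains at least 'major.minor'.
    """
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def check_compatible(
    system_cuda: str,
    torch_cuda: str,
    torch_version: str,
    min_torch: Tuple[int, int] = (1, 10),
) -> None:
    """Raise RuntimeError if either requirement from the table is violated."""
    sys_major, _ = parse_major_minor(system_cuda)
    torch_cuda_major, _ = parse_major_minor(torch_cuda)
    # Only the major CUDA version has to match.
    if sys_major != torch_cuda_major:
        raise RuntimeError(
            f"System CUDA {system_cuda} does not match PyTorch CUDA {torch_cuda}."
        )
    # Tuple comparison handles (1, 9) < (1, 10) correctly, unlike string comparison.
    if parse_major_minor(torch_version) < min_torch:
        raise RuntimeError(f"PyTorch {torch_version} is older than the minimum 1.10.")
```

Comparing `(major, minor)` tuples rather than raw strings matters here, because `"1.9" < "1.10"` is false as a string comparison.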
Abstract Methods (Must Override)
| Method | Return Type | Description |
|---|---|---|
| nvcc_flags() | List[str] | Must return a list of NVCC compiler flags (base returns ["-DCOLOSSAL_WITH_CUDA"]) |
| sources_files() | List[str] | Must return a list of source file paths (inherited from _CppExtension) |
| cxx_flags() | List[str] | Must return a list of C++ compiler flags (inherited from _CppExtension) |
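
The nvcc_flags row describes a method that is abstract yet still has a base return value. The pattern can be shown with plain `abc`, independent of ColossalAI; `BaseCudaExt` and `FastMathExt` are hypothetical stand-ins used only for this illustration.

```python
from abc import ABC, abstractmethod
from typing import List


class BaseCudaExt(ABC):
    """Stand-in for a base class whose abstract method carries a default body."""

    @abstractmethod
    def nvcc_flags(self) -> List[str]:
        # Abstract methods may still have a body; subclasses reach it via super().
        return ["-DCOLOSSAL_WITH_CUDA"]


class FastMathExt(BaseCudaExt):
    def nvcc_flags(self) -> List[str]:
        # Extend the base flags instead of replacing them.
        return super().nvcc_flags() + ["-O3", "--use_fast_math"]
```

The base class cannot be instantiated directly, so every concrete extension has to state its flags explicitly while still inheriting the shared define.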
Usage Examples
from extensions.cuda_extension import _CudaExtension
from typing import List


class MyCudaKernel(_CudaExtension):
    def __init__(self):
        super().__init__(name="my_cuda_kernel", priority=1)

    def sources_files(self) -> List[str]:
        return [
            self.csrc_abs_path("my_kernel.cu"),
            self.csrc_abs_path("my_kernel_wrapper.cpp"),
            self.pybind_abs_path("my_kernel_pybind.cpp"),
        ]

    def include_dirs(self) -> List[str]:
        return super().include_dirs()

    def cxx_flags(self) -> List[str]:
        return ["-O3", "-std=c++17"] + self.version_dependent_macros

    def nvcc_flags(self) -> List[str]:
        return super().nvcc_flags() + ["-O3", "--use_fast_math"]


# Check availability and load
kernel = MyCudaKernel()
if kernel.is_available():
    kernel.assert_compatible()
    op = kernel.load()
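
The Description notes that both build methods set the CUDA architecture list before compilation. One way this could look is sketched below, assuming the list is communicated through the `TORCH_CUDA_ARCH_LIST` environment variable that PyTorch's extension builder reads. The helper names are hypothetical, and leaving an already-set value untouched is an assumption, not confirmed behaviour of ColossalAI.

```python
import os
from typing import Iterable, MutableMapping, Optional, Tuple


def format_arch_list(capabilities: Iterable[Tuple[int, int]]) -> str:
    """Turn compute capabilities such as (8, 0) into a 'major.minor;...' string."""
    unique = sorted(set(capabilities))
    return ";".join(f"{major}.{minor}" for major, minor in unique)


def set_arch_list_if_unset(
    capabilities: Iterable[Tuple[int, int]],
    env: Optional[MutableMapping[str, str]] = None,
) -> str:
    """Hypothetical helper: set TORCH_CUDA_ARCH_LIST unless the user already did.

    Returns the value that is in effect afterwards.
    """
    env = os.environ if env is None else env
    return env.setdefault("TORCH_CUDA_ARCH_LIST", format_arch_list(capabilities))
```

Deduplicating and sorting the capabilities keeps the resulting string stable across runs, which helps when build caches are keyed on the environment.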