Overview
Extension Utils is a collection of utility functions for validating CUDA environments, checking PyTorch version compatibility, generating CUDA compute capability flags, and configuring the build system for ColossalAI kernel extensions.
Description
This module provides essential build-system utilities for ColossalAI's extension compilation pipeline. It includes functions for retrieving CUDA versions from both the system (via nvcc) and the PyTorch build, validating that these versions match, checking minimum PyTorch version requirements, setting CUDA architecture lists for cross-compilation, generating GPU compute capability flags, and appending nvcc thread parallelism flags. A distributed-aware print_rank_0 function is also included to prevent duplicate output in multi-process environments.
Usage
Use these utility functions when building or configuring ColossalAI C++/CUDA kernel extensions. They are called internally by _CudaExtension and its subclasses to validate the build environment before compilation, but can also be used directly for custom build scripts or environment diagnostics.
Code Reference
Source Location
Signature
```python
from typing import List

def print_rank_0(message: str) -> None:
    ...

def get_cuda_version_in_pytorch() -> List[int]:  # returns (major, minor) as str
    ...

def get_cuda_bare_metal_version(cuda_dir) -> List[int]:  # returns (major, minor) as str
    ...

def check_system_pytorch_cuda_match(cuda_dir):  # returns True or raises
    ...

def get_pytorch_version() -> List[int]:  # returns (major, minor, patch) as int
    ...

def check_pytorch_version(min_major_version, min_minor_version) -> bool:
    ...

def check_cuda_availability():  # returns bool
    ...

def set_cuda_arch_list(cuda_dir):  # returns bool
    ...

def get_cuda_cc_flag() -> List[str]:
    ...

def append_nvcc_threads(nvcc_extra_args: List[str]) -> List[str]:
    ...
```
Import
```python
from extensions.utils import (
    print_rank_0,
    get_cuda_version_in_pytorch,
    get_cuda_bare_metal_version,
    check_system_pytorch_cuda_match,
    get_pytorch_version,
    check_pytorch_version,
    check_cuda_availability,
    set_cuda_arch_list,
    get_cuda_cc_flag,
    append_nvcc_threads,
)
```
I/O Contract
Function Details
print_rank_0
| Name | Type | Required | Description |
|---|---|---|---|
| message | str | Yes | Message to print only on rank 0 (or when torch.distributed is not initialized) |
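A minimal sketch of the rank-0 gating described above, assuming only the public `torch.distributed` calls `is_available`, `is_initialized` and `get_rank`. The name `print_rank_0_sketch` is hypothetical, and treating a missing PyTorch install as a single-process run is an assumption.

```python
def print_rank_0_sketch(message: str) -> None:
    """Print `message` only on rank 0, or always when not running distributed."""
    try:
        import torch.distributed as dist

        # Without an initialized process group there is only one process to print from.
        distributed = dist.is_available() and dist.is_initialized()
        is_main_rank = (not distributed) or dist.get_rank() == 0
    except ImportError:
        # Assumption: no PyTorch means a plain single-process run.
        is_main_rank = True

    if is_main_rank:
        print(message)


print_rank_0_sketch("[extension] building kernels")
```

Checking `is_available()` first matters because PyTorch builds without distributed support do not provide a usable process group.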
get_cuda_version_in_pytorch
| Name | Type | Description |
|---|---|---|
| return | Tuple[str, str] | The (major, minor) CUDA version that PyTorch was compiled with |
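`torch.version.cuda` holds the CUDA version string PyTorch was built against (for example `"12.1"`), or `None` on CPU-only builds. The hypothetical helper below sketches how such a string can be split into the `(major, minor)` string tuple listed above; raising `ValueError` when no version is present is an assumption.

```python
from typing import Optional, Tuple


def split_cuda_version(version: Optional[str]) -> Tuple[str, str]:
    """Split a CUDA version string such as "12.1" into ("12", "1")."""
    if not version:
        # CPU-only PyTorch builds report torch.version.cuda as None.
        raise ValueError("PyTorch was not built with CUDA")
    parts = version.split(".")
    if len(parts) < 2:
        raise ValueError(f"unexpected CUDA version string: {version!r}")
    return parts[0], parts[1]


# With a live PyTorch install (not run here):
#   import torch
#   major, minor = split_cuda_version(torch.version.cuda)
print(split_cuda_version("11.8"))
```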
get_cuda_bare_metal_version
| Name | Type | Required | Description |
|---|---|---|---|
| cuda_dir | str | Yes | Path to the CUDA Toolkit directory (CUDA_HOME) |
| return | Tuple[str, str] | -- | The (major, minor) system CUDA version from nvcc |
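`nvcc -V` prints a line such as `Cuda compilation tools, release 12.1, V12.1.105`. A minimal sketch of extracting `(major, minor)` from that output, assuming the `release X.Y,` token format and that nvcc lives at `<cuda_dir>/bin/nvcc`; both function names are hypothetical.

```python
import os
import subprocess
from typing import Tuple


def parse_nvcc_release(nvcc_output: str) -> Tuple[str, str]:
    """Return (major, minor) from the 'release X.Y,' token of `nvcc -V` output."""
    tokens = nvcc_output.split()
    try:
        release = tokens[tokens.index("release") + 1]  # e.g. "12.1,"
    except (ValueError, IndexError):
        raise ValueError("no 'release' token found in nvcc output") from None
    major, minor = release.rstrip(",").split(".")[:2]
    return major, minor


def nvcc_version(cuda_dir: str) -> Tuple[str, str]:
    """Run <cuda_dir>/bin/nvcc -V and parse its release version."""
    nvcc = os.path.join(cuda_dir, "bin", "nvcc")
    if not os.path.exists(nvcc):
        raise FileNotFoundError(f"nvcc not found at {nvcc}; check CUDA_HOME")
    output = subprocess.check_output([nvcc, "-V"], text=True)
    return parse_nvcc_release(output)


sample = (
    "nvcc: NVIDIA (R) Cuda compiler driver\n"
    "Cuda compilation tools, release 12.1, V12.1.105\n"
)
print(parse_nvcc_release(sample))
```

Splitting parsing from the subprocess call lets the parser be exercised on captured text without a CUDA Toolkit installed.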
check_system_pytorch_cuda_match
| Name | Type | Required | Description |
|---|---|---|---|
| cuda_dir | str | Yes | Path to the CUDA Toolkit directory |
| return | bool | -- | True if the system and PyTorch CUDA versions are compatible; raises Exception if their major versions differ |
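The compatibility rule above compares major versions strictly. A sketch of that rule on already-parsed version tuples; reporting a minor-version mismatch through `warnings.warn` is an assumption about how the tolerated case is surfaced, and `check_cuda_match` is a hypothetical name.

```python
import warnings
from typing import Tuple


def check_cuda_match(system: Tuple[str, str], torch_build: Tuple[str, str]) -> bool:
    """Raise on a major-version mismatch; tolerate (but warn about) a minor one."""
    if system[0] != torch_build[0]:
        raise Exception(
            f"system CUDA {'.'.join(system)} does not match the CUDA "
            f"{'.'.join(torch_build)} that PyTorch was built with"
        )
    if system[1] != torch_build[1]:
        # Assumption: minor-version differences are reported but not fatal.
        warnings.warn(
            f"CUDA minor version differs: system {'.'.join(system)}, "
            f"PyTorch {'.'.join(torch_build)}"
        )
    return True


print(check_cuda_match(("12", "1"), ("12", "1")))
```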
get_pytorch_version
| Name | Type | Description |
|---|---|---|
| return | Tuple[int, int, int] | The (major, minor, patch) PyTorch version |
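`torch.__version__` strings can carry a local build suffix (`2.1.0+cu121`) or a pre-release tag (`2.3.0a0`). A sketch of turning such a string into the integer triple described above; keeping only the leading digits of the patch component is an assumption, and `parse_torch_version` is hypothetical.

```python
import re
from typing import Tuple


def parse_torch_version(version: str) -> Tuple[int, int, int]:
    """Convert e.g. "2.1.0+cu121" into (2, 1, 0)."""
    public = version.split("+")[0]  # drop the local build suffix, e.g. "+cu121"
    major, minor, patch = public.split(".")[:3]
    # Keep only the leading digits of the patch, e.g. "0a0" -> 0 (assumption).
    digits = re.match(r"\d+", patch)
    patch_num = int(digits.group()) if digits else 0
    return int(major), int(minor), patch_num


print(parse_torch_version("2.1.0+cu121"))
```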
check_pytorch_version
| Name | Type | Required | Description |
|---|---|---|---|
| min_major_version | int | Yes | Minimum required major PyTorch version |
| min_minor_version | int | Yes | Minimum required minor PyTorch version |
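A minimum-version check is a lexicographic comparison of `(major, minor)` pairs, which Python tuples provide directly. A sketch, assuming the current version has already been parsed into integers; the helper name `meets_min_version` is hypothetical.

```python
from typing import Tuple


def meets_min_version(current: Tuple[int, int], min_major: int, min_minor: int) -> bool:
    """True if `current` (major, minor) is at least min_major.min_minor."""
    # Tuple comparison checks major first and only consults minor on a tie,
    # so (2, 0) >= (1, 10) holds even though 0 < 10.
    return tuple(current) >= (min_major, min_minor)


print(meets_min_version((1, 13), 1, 10))  # True
print(meets_min_version((1, 9), 1, 10))   # False
```

Comparing the fields separately (`major >= 1 and minor >= 10`) would wrongly reject 2.0; the tuple form avoids that.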
check_cuda_availability
| Name | Type | Description |
|---|---|---|
| return | bool | True if CUDA is available via torch.cuda.is_available() |
set_cuda_arch_list
| Name | Type | Required | Description |
|---|---|---|---|
| cuda_dir | str | Yes | Path to the CUDA Toolkit directory |
| return | bool | -- | True if CUDA is available on the system; False if cross-compiling (the architecture list is then set automatically) |
get_cuda_cc_flag
| Name | Type | Description |
|---|---|---|
| return | List[str] | A list of -gencode flags for the available GPU architectures (compute capability >= 6.0) |
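A sketch of how `-gencode` flags of the form shown in the usage example can be derived from a list of architecture names. In PyTorch, `torch.cuda.get_arch_list()` returns names like `sm_80`; the optional `max_cc` cap (skipping architectures newer than the installed GPU) is an assumption, and `gencode_flags` is hypothetical.

```python
import re
from typing import Iterable, List, Optional


def gencode_flags(arch_list: Iterable[str], max_cc: Optional[int] = None) -> List[str]:
    """Build nvcc -gencode flags for sm_XX entries with compute capability >= 6.0."""
    flags: List[str] = []
    for arch in arch_list:
        match = re.search(r"sm_(\d+)", arch)
        if not match:
            continue  # skip entries without an sm_ prefix, e.g. "compute_90"
        cc = int(match.group(1))
        if cc < 60 or (max_cc is not None and cc > max_cc):
            continue
        flags.extend(["-gencode", f"arch=compute_{match.group(1)},code={arch}"])
    return flags


print(gencode_flags(["sm_50", "sm_60", "sm_70", "sm_80", "sm_90"], max_cc=80))
```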
append_nvcc_threads
| Name | Type | Required | Description |
|---|---|---|---|
| nvcc_extra_args | List[str] | Yes | Existing nvcc arguments to extend |
| return | List[str] | -- | The input args with "--threads", "4" appended if the system CUDA version is >= 11.2; otherwise the args unchanged |
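nvcc gained the `--threads` option in CUDA 11.2. A sketch of the conditional append, assuming the system CUDA version is already available as integers; `with_nvcc_threads` is hypothetical.

```python
from typing import List, Tuple


def with_nvcc_threads(nvcc_args: List[str], cuda_version: Tuple[int, int],
                      threads: int = 4) -> List[str]:
    """Return a copy of nvcc_args with '--threads N' added when CUDA >= 11.2."""
    if tuple(cuda_version) >= (11, 2):
        return nvcc_args + ["--threads", str(threads)]
    return list(nvcc_args)  # copy, so the caller's list is never aliased


print(with_nvcc_threads(["-O3", "--use_fast_math"], (12, 0)))
```

Comparing `(major, minor)` as a tuple keeps CUDA 12.0 and later on the >= 11.2 side; a field-wise check such as `major >= 11 and minor >= 2` would skip the flag for 12.0 and 12.1.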
Supported CUDA Architectures (Cross-Compilation)
When cross-compiling (no GPU available), the following architectures are automatically configured:
| Architecture | Compute Capability |
|---|---|
| Pascal | 6.0, 6.1, 6.2 |
| Volta | 7.0 |
| Turing | 7.5 |
| Ampere | 8.0, 8.6 (CUDA 11.x only) |
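PyTorch's extension builder reads the target architectures from the `TORCH_CUDA_ARCH_LIST` environment variable. A sketch of building the default list in the table above for a given system CUDA version; omitting 8.6 on CUDA 11.0 (sm_86 support arrived in CUDA 11.1) and leaving a user-set value untouched are assumptions, and both helper names are hypothetical.

```python
import os
from typing import List, MutableMapping, Optional


def default_arch_list(cuda_major: int, cuda_minor: int) -> List[str]:
    """Default cross-compilation targets per the architecture table."""
    archs = ["6.0", "6.1", "6.2", "7.0", "7.5"]  # Pascal, Volta, Turing
    if cuda_major == 11:  # Ampere targets only for CUDA 11.x
        archs.append("8.0")
        if cuda_minor >= 1:  # assumption: sm_86 needs CUDA 11.1+
            archs.append("8.6")
    return archs


def apply_arch_list(cuda_major: int, cuda_minor: int,
                    env: Optional[MutableMapping[str, str]] = None) -> str:
    """Set TORCH_CUDA_ARCH_LIST unless already set; return the effective value."""
    env = os.environ if env is None else env
    if "TORCH_CUDA_ARCH_LIST" not in env:
        env["TORCH_CUDA_ARCH_LIST"] = ";".join(default_arch_list(cuda_major, cuda_minor))
    return env["TORCH_CUDA_ARCH_LIST"]


print(apply_arch_list(11, 8, env={}))
```

Passing a plain dict as `env` keeps the sketch testable without touching the real process environment.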
Usage Examples
```python
from extensions.utils import (
    check_system_pytorch_cuda_match,
    check_pytorch_version,
    get_cuda_cc_flag,
    append_nvcc_threads,
)

# Validate the environment before building
check_pytorch_version(min_major_version=1, min_minor_version=10)
check_system_pytorch_cuda_match("/usr/local/cuda")

# Get GPU architecture flags for compilation
cc_flags = get_cuda_cc_flag()
print(cc_flags)
# Example output (depends on the installed GPU and the PyTorch build):
# ['-gencode', 'arch=compute_70,code=sm_70', '-gencode', 'arch=compute_80,code=sm_80']

# Add thread parallelism to the nvcc args
nvcc_args = ["-O3", "--use_fast_math"]
nvcc_args = append_nvcc_threads(nvcc_args)
# ["-O3", "--use_fast_math", "--threads", "4"] if the system CUDA is >= 11.2
```
Related Pages