Implementation:Hpcaitech ColossalAI Extension Utils

From Leeroopedia


Knowledge Sources
Domains: Kernel Extensions, Build System, CUDA, Utilities
Last Updated: 2026-02-09 00:00 GMT

Overview

Extension Utils is a collection of utility functions for validating CUDA environments, checking PyTorch version compatibility, generating CUDA compute capability flags, and configuring the build system for ColossalAI kernel extensions.

Description

This module provides essential build-system utilities for ColossalAI's extension compilation pipeline. It includes functions for retrieving CUDA versions from both the system (via nvcc) and the PyTorch build, validating that these versions match, checking minimum PyTorch version requirements, setting CUDA architecture lists for cross-compilation, generating GPU compute capability flags, and appending nvcc thread parallelism flags. A distributed-aware print_rank_0 function is also included to prevent duplicate output in multi-process environments.

Usage

Use these utility functions when building or configuring ColossalAI C++/CUDA kernel extensions. They are called internally by _CudaExtension and its subclasses to validate the build environment before compilation, but can also be used directly for custom build scripts or environment diagnostics.

Code Reference

Source Location

Signature

def print_rank_0(message: str) -> None:
    ...

def get_cuda_version_in_pytorch() -> List[int]:
    ...

def get_cuda_bare_metal_version(cuda_dir) -> List[int]:
    ...

def check_system_pytorch_cuda_match(cuda_dir):
    ...

def get_pytorch_version() -> List[int]:
    ...

def check_pytorch_version(min_major_version, min_minor_version) -> bool:
    ...

def check_cuda_availability():
    ...

def set_cuda_arch_list(cuda_dir):
    ...

def get_cuda_cc_flag() -> List[str]:
    ...

def append_nvcc_threads(nvcc_extra_args: List[str]) -> List[str]:
    ...

Import

from extensions.utils import (
    print_rank_0,
    get_cuda_version_in_pytorch,
    get_cuda_bare_metal_version,
    check_system_pytorch_cuda_match,
    get_pytorch_version,
    check_pytorch_version,
    check_cuda_availability,
    set_cuda_arch_list,
    get_cuda_cc_flag,
    append_nvcc_threads,
)

I/O Contract

Function Details

print_rank_0

| Name | Type | Required | Description |
|---|---|---|---|
| message | str | Yes | Message to print only on rank 0 (or when torch.distributed is not initialized) |
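
The rank-0 behaviour described above can be sketched as follows. This is a minimal illustration, not the module's actual code: the name `print_rank_0_sketch` is hypothetical, the guarded import is only there so the snippet also runs where PyTorch is not installed, and the boolean return value is added purely so the behaviour can be checked (the documented function returns None).

```python
def print_rank_0_sketch(message: str) -> bool:
    """Print `message` only on rank 0, or when torch.distributed is not initialized.

    Returns True if the message was printed (an addition for illustration only).
    """
    try:
        import torch.distributed as dist
        initialized = dist.is_available() and dist.is_initialized()
    except ImportError:
        # Assumption: without PyTorch there is a single process, so always print.
        dist = None
        initialized = False

    # Single-process runs always print; distributed runs print on rank 0 only.
    if not initialized or dist.get_rank() == 0:
        print(message)
        return True
    return False
```

In a plain single-process script the function simply prints, so build logs look the same with or without a distributed launcher.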

get_cuda_version_in_pytorch

| Name | Type | Description |
|---|---|---|
| return | Tuple[str, str] | The (major, minor) CUDA version that PyTorch was compiled with |

get_cuda_bare_metal_version

| Name | Type | Required | Description |
|---|---|---|---|
| cuda_dir | str | Yes | Path to the CUDA Toolkit directory (CUDA_HOME) |
| return | Tuple[str, str] | -- | The (major, minor) system CUDA version from nvcc |
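
To show how a (major, minor) pair can be recovered from nvcc, here is a small sketch that parses text in the style of `nvcc -V` output. The helper name `parse_nvcc_release` and the sample text are illustrative assumptions; the real function's parsing may differ in detail.

```python
from typing import Tuple

# Sample text in the style of `nvcc -V` output (illustrative, not from a real run).
SAMPLE_NVCC_OUTPUT = """nvcc: NVIDIA (R) Cuda compiler driver
Cuda compilation tools, release 11.8, V11.8.89
"""


def parse_nvcc_release(nvcc_output: str) -> Tuple[str, str]:
    """Return the (major, minor) version strings that follow the word 'release'."""
    tokens = nvcc_output.split()
    # The token after "release" looks like "11.8," so strip the trailing comma.
    release = tokens[tokens.index("release") + 1].rstrip(",")
    major, minor = release.split(".")[:2]
    return major, minor


# In a real build script the text would come from running nvcc, for example:
#   subprocess.check_output([os.path.join(cuda_dir, "bin", "nvcc"), "-V"], text=True)
print(parse_nvcc_release(SAMPLE_NVCC_OUTPUT))
```

The result is kept as strings to match the Tuple[str, str] return type listed in the table.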

check_system_pytorch_cuda_match

| Name | Type | Required | Description |
|---|---|---|---|
| cuda_dir | str | Yes | Path to the CUDA Toolkit directory |
| return | bool | -- | True if versions are compatible; raises Exception on major version mismatch |
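
The comparison itself can be sketched as a pure function over two (major, minor) pairs. The name `check_cuda_versions_match` is hypothetical, and emitting a warning on a minor-version difference is an assumption made for this sketch; the table above only states that a major mismatch raises.

```python
import warnings
from typing import Tuple


def check_cuda_versions_match(system: Tuple[str, str], torch_build: Tuple[str, str]) -> bool:
    """Compare the system CUDA version with the one PyTorch was built against."""
    sys_major, sys_minor = system
    torch_major, torch_minor = torch_build

    # A different major version is treated as incompatible.
    if sys_major != torch_major:
        raise Exception(
            f"System CUDA {sys_major}.{sys_minor} does not match "
            f"PyTorch CUDA {torch_major}.{torch_minor}"
        )

    # Assumption: a minor difference is tolerated but worth flagging.
    if sys_minor != torch_minor:
        warnings.warn(
            f"Minor CUDA version differs: system {sys_major}.{sys_minor}, "
            f"PyTorch {torch_major}.{torch_minor}"
        )
    return True
```

Taking the two versions as arguments keeps the check testable without a CUDA toolkit installed.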

get_pytorch_version

| Name | Type | Description |
|---|---|---|
| return | Tuple[int, int, int] | The (major, minor, patch) PyTorch version |
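
As an illustration of how a (major, minor, patch) triple can be derived, the sketch below parses a version string of the kind exposed by `torch.__version__`. The helper `parse_torch_version` is hypothetical and uses a regular expression so that local build suffixes such as `+cu118` are ignored; the module's own parsing may differ.

```python
import re
from typing import Tuple


def parse_torch_version(version: str) -> Tuple[int, int, int]:
    """Extract (major, minor, patch) from strings such as '2.1.0+cu118'."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    major, minor, patch = (int(part) for part in match.groups())
    return major, minor, patch


print(parse_torch_version("2.1.0+cu118"))
```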

check_pytorch_version

| Name | Type | Required | Description |
|---|---|---|---|
| min_major_version | int | Yes | Minimum required major PyTorch version |
| min_minor_version | int | Yes | Minimum required minor PyTorch version |
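
The minimum-version test reduces to a tuple comparison. The sketch below uses a hypothetical helper, `meets_min_version`, that takes the current version explicitly so it can be tried without PyTorch installed; how the real function reports a failure is not shown here.

```python
from typing import Tuple


def meets_min_version(current: Tuple[int, int, int], min_major: int, min_minor: int) -> bool:
    """Return True if the (major, minor) part of `current` is at least the minimum."""
    # Tuples compare element by element, so (2, 0) >= (1, 10) is True.
    return (current[0], current[1]) >= (min_major, min_minor)


print(meets_min_version((1, 13, 1), 1, 10))
```

Comparing tuples avoids the common mistake of checking major and minor independently, which would wrongly reject 2.0 against a minimum of 1.10.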

check_cuda_availability

| Name | Type | Description |
|---|---|---|
| return | bool | True if CUDA is available via torch.cuda.is_available() |

set_cuda_arch_list

| Name | Type | Required | Description |
|---|---|---|---|
| cuda_dir | str | Yes | Path to the CUDA Toolkit directory |
| return | bool | -- | True if CUDA is available on the system; False if cross-compiling (arch list is set automatically) |

get_cuda_cc_flag

| Name | Type | Description |
|---|---|---|
| return | List[str] | A list of -gencode flags for the available GPU architectures (compute capability >= 6.0) |
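
The flag generation can be sketched as a filter over architecture names. The helper `build_cc_flags` is hypothetical and takes its inputs as arguments; in a real environment they would typically come from `torch.cuda.get_arch_list()` and `torch.cuda.get_device_capability()`. Capping the list at the local device's capability is an assumption of this sketch; the table above only states the 6.0 lower bound.

```python
import re
from typing import List, Tuple


def build_cc_flags(arch_list: List[str], device_capability: Tuple[int, int]) -> List[str]:
    """Turn names like 'sm_80' into -gencode flags for capabilities in [6.0, device max]."""
    max_cap = device_capability[0] * 10 + device_capability[1]
    flags: List[str] = []
    for arch in arch_list:
        match = re.fullmatch(r"sm_(\d+)", arch)
        if match is None:
            continue  # skip PTX-only entries such as 'compute_80'
        cap = int(match.group(1))
        if 60 <= cap <= max_cap:
            flags.extend(["-gencode", f"arch=compute_{cap},code={arch}"])
    return flags


print(build_cc_flags(["sm_50", "sm_70", "sm_80", "compute_80", "sm_90"], (8, 0)))
```

Each architecture contributes two list items, the `-gencode` switch and its value, which is the form expected when the list is passed on as nvcc arguments.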

append_nvcc_threads

| Name | Type | Required | Description |
|---|---|---|---|
| nvcc_extra_args | List[str] | Yes | Existing nvcc arguments to extend |
| return | List[str] | -- | The input args with the two items "--threads" and "4" appended if CUDA >= 11.2 |
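
A minimal sketch of the version gate described above. The helper `append_threads_flag` is hypothetical and takes the CUDA version as an integer pair, whereas the documented function takes only the argument list and determines the version itself.

```python
from typing import List, Tuple


def append_threads_flag(nvcc_args: List[str], cuda_version: Tuple[int, int]) -> List[str]:
    """Return a new list with '--threads 4' added when CUDA is 11.2 or newer."""
    # Tuple comparison handles both 11.x (x >= 2) and any later major version.
    if cuda_version >= (11, 2):
        return nvcc_args + ["--threads", "4"]
    return list(nvcc_args)


print(append_threads_flag(["-O3", "--use_fast_math"], (11, 8)))
```

Returning a new list rather than mutating the input keeps the caller's original arguments intact, which is convenient when the same base list is reused for several extensions.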

Supported CUDA Architectures (Cross-Compilation)

When cross-compiling (no GPU available), the following architectures are automatically configured:

| Architecture | Compute Capability |
|---|---|
| Pascal | 6.0, 6.1, 6.2 |
| Volta | 7.0 |
| Turing | 7.5 |
| Ampere | 8.0, 8.6 (CUDA 11.x only) |
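
The table can be expressed as a small function that builds the semicolon-separated value PyTorch reads from the `TORCH_CUDA_ARCH_LIST` environment variable. The helper `cross_compile_arch_list` is hypothetical and follows the table literally; the module may apply finer-grained rules per CUDA minor version.

```python
def cross_compile_arch_list(cuda_major: int) -> str:
    """Build a TORCH_CUDA_ARCH_LIST value for builds on machines without a GPU."""
    archs = ["6.0", "6.1", "6.2", "7.0", "7.5"]  # Pascal, Volta, Turing
    if cuda_major == 11:
        archs += ["8.0", "8.6"]  # Ampere, per the table: CUDA 11.x only
    return ";".join(archs)


# A build script might apply it only when the user has not set the variable:
#   os.environ.setdefault("TORCH_CUDA_ARCH_LIST", cross_compile_arch_list(11))
print(cross_compile_arch_list(11))
```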

Usage Examples

from extensions.utils import (
    check_system_pytorch_cuda_match,
    check_pytorch_version,
    get_cuda_cc_flag,
    append_nvcc_threads,
)

# Validate environment before building
check_pytorch_version(min_major_version=1, min_minor_version=10)
check_system_pytorch_cuda_match("/usr/local/cuda")

# Get GPU architecture flags for compilation
cc_flags = get_cuda_cc_flag()
print(cc_flags)
# Example output; the exact flags depend on the installed GPU and PyTorch build:
# ['-gencode', 'arch=compute_70,code=sm_70', '-gencode', 'arch=compute_80,code=sm_80']

# Add thread parallelism to nvcc args
nvcc_args = ["-O3", "--use_fast_math"]
nvcc_args = append_nvcc_threads(nvcc_args)
# ["-O3", "--use_fast_math", "--threads", "4"]  (if CUDA >= 11.2)
