Implementation:Scikit learn Scikit learn ArrayAPI
| Knowledge Sources | |
|---|---|
| Domains | Machine Learning, Array Interoperability |
| Last Updated | 2026-02-08 15:00 GMT |
Overview
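Private (underscore-prefixed) utility module that provides scikit-learn's Array API compatibility helpers. Because the module is private, its functions may change between releases without a deprecation cycle.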
Description
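The _array_api module provides tools that support the Python Array API standard across scikit-learn. It has two groups of functions:
- Test helpers that yield the names of the supported array namespaces and their device and dtype combinations.
- Functions that detect the namespace (and device) of input arrays.

Together they give interoperability between the NumPy, CuPy, PyTorch, and array_api_strict backends. Estimators that support Array API dispatch can accept inputs from these libraries once array_api_dispatch is enabled, for example via sklearn.set_config(array_api_dispatch=True). Dispatch is disabled by default.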
Usage
Use these utilities when developing or testing scikit-learn estimators that need to operate across multiple array backends (e.g., NumPy, CuPy, PyTorch) via the Array API standard.
Code Reference
Source Location
- Repository: scikit-learn
- File: sklearn/utils/_array_api.py
Signature
def yield_namespaces(include_numpy_namespaces=True):
...
def yield_namespace_device_dtype_combinations(include_numpy_namespaces=True):
...
def get_namespace(*arrays, remove_none=True, remove_types=REMOVE_TYPES_DEFAULT, xp=None):
...
def get_namespace_and_device(*arrays, remove_none=True, remove_types=REMOVE_TYPES_DEFAULT):
...
Import
from sklearn.utils._array_api import get_namespace, get_namespace_and_device
I/O Contract
Inputs
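| Name | Type | Required | Description |
|---|---|---|---|
| arrays | array-like | Yes | One or more arrays whose namespace (and, for get_namespace_and_device, device) is detected |
| include_numpy_namespaces | bool | No | For yield_namespaces and yield_namespace_device_dtype_combinations: whether to include the NumPy namespaces in the yielded test combinations |
| remove_none | bool | No | Whether to drop None values from arrays before detection |
| remove_types | tuple of types | No | Instances of these types are dropped from arrays before detection |
| xp | module | No | For get_namespace: explicitly specifies the namespace to use instead of detecting it |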
Outputs
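| Name | Type | Description |
|---|---|---|
| xp | module or module-like wrapper | The detected or explicitly specified array namespace |
| is_array_api | bool | Whether Array API dispatch applies to this call; False whenever array_api_dispatch is disabled (the default) |
| device | device | Device of the input arrays; returned only by get_namespace_and_device, as its third value |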
Usage Examples
Basic Usage
import numpy as np
from sklearn.utils._array_api import get_namespace
X = np.array([[1, 2], [3, 4]])
xp, is_array_api = get_namespace(X)
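print(xp) # a NumPy-backed namespace (possibly a compatibility wrapper, not the numpy module itself)
print(is_array_api) # False: array_api_dispatch is disabled by default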