Implementation:Scikit learn Scikit learn Make Column Selector
Overview
A concrete tool, provided by scikit-learn, for creating callable column selectors to use with ColumnTransformer.
Code Reference
Class: make_column_selector
Module: sklearn/compose/_column_transformer.py (lines 1516-1602)
Constructor signature:
class make_column_selector:
    def __init__(self, pattern=None, *, dtype_include=None, dtype_exclude=None):
Callable signature:
def __call__(self, df):
Despite being named in lowercase (following the convention of factory functions), make_column_selector is actually a class. Instantiating it returns a callable object that can be passed directly to ColumnTransformer in place of an explicit column list.
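A minimal sketch of that distinction, assuming pandas and scikit-learn are installed. The DataFrame and its column names (feat_a, feat_b, target) are made up for illustration.

```python
import inspect

import pandas as pd
from sklearn.compose import make_column_selector

# make_column_selector is a class; calling it builds a selector object.
is_class = inspect.isclass(make_column_selector)

# Hypothetical pattern: keep columns whose names start with "feat_".
selector = make_column_selector(pattern="^feat_")
is_callable = callable(selector)

# Illustrative data, not taken from the scikit-learn documentation.
df = pd.DataFrame({"feat_a": [1.0, 2.0], "feat_b": [3.0, 4.0], "target": [0, 1]})

# The selector is applied to a DataFrame like an ordinary function.
selected = selector(df)
print(is_class, is_callable, selected)
```

Because the selector stores only its configuration, the same instance can be reused on any DataFrame that has compatible column names and dtypes.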
I/O Contract
Constructor parameters:
- pattern: str, default=None -- A regex pattern. Columns whose names contain a match for this pattern are included. If None, no name-based filtering is applied.
- dtype_include: column dtype or list of column dtypes, default=None -- Dtypes to include in the selection. Passed to pandas.DataFrame.select_dtypes(include=...).
- dtype_exclude: column dtype or list of column dtypes, default=None -- Dtypes to exclude from the selection. Passed to pandas.DataFrame.select_dtypes(exclude=...).
Call input:
- df: pandas DataFrame -- The input DataFrame to select columns from. Raises ValueError if the input does not have an iloc attribute (i.e., is not a DataFrame).
Call output:
- list of column names -- The selected column names as a Python list.
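The call contract can be checked with a short sketch like the one below. The column names are illustrative, and a NumPy array stands in for "an input without an iloc attribute".

```python
import numpy as np
import pandas as pd
from sklearn.compose import make_column_selector

selector = make_column_selector(dtype_include=np.number)

# Illustrative DataFrame with one numeric and one text column.
df = pd.DataFrame({"height": [1.7, 1.8], "name": ["a", "b"]})
result = selector(df)  # a plain Python list of column names

# A NumPy array has no `iloc` attribute, so the selector rejects it.
try:
    selector(np.array([[1.0, 2.0]]))
    raised = False
except ValueError:
    raised = True

print(type(result).__name__, result, raised)
```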
Selection logic: When multiple criteria are specified, all criteria must match for a column to be selected. The function first applies dtype filtering, then applies the regex pattern to the surviving columns.
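To illustrate how the criteria combine, here is a small sketch with made-up column names. It compares each criterion on its own with both applied together.

```python
import numpy as np
import pandas as pd
from sklearn.compose import make_column_selector

# Illustrative data: two numeric columns and one text column.
df = pd.DataFrame({
    "score_math": [90.0, 85.5],
    "score_label": ["high", "mid"],
    "age": [30, 41],
})

# dtype criterion only: every numeric column.
numeric = make_column_selector(dtype_include=np.number)(df)

# name criterion only: every column whose name starts with "score_".
by_name = make_column_selector(pattern="^score_")(df)

# Both criteria: a column must be numeric AND match the pattern.
both = make_column_selector(pattern="^score_", dtype_include=np.number)(df)

print(numeric, by_name, both)
```

Only score_math satisfies both conditions, so the combined selector returns the intersection of the two individual selections.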
Implementation Details
The __call__ method performs column selection in the following steps:
- Takes a single-row slice of the DataFrame (df.iloc[:1]) to minimize memory usage during dtype inspection.
- If dtype_include or dtype_exclude is set, calls select_dtypes on the single-row slice to filter columns by type.
- If pattern is set, applies cols.str.contains(self.pattern, regex=True) to filter the remaining columns by name.
- Returns the final column index as a Python list via cols.tolist().
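The steps above can be sketched as a standalone function. This is an approximation written for illustration, not the scikit-learn source: the name select_columns, the error message, and the sample data are assumptions.

```python
import numpy as np
import pandas as pd
from sklearn.compose import make_column_selector


def select_columns(df, pattern=None, dtype_include=None, dtype_exclude=None):
    """Hypothetical re-implementation of the selection steps."""
    if not hasattr(df, "iloc"):
        # Assumed wording; the real error message may differ.
        raise ValueError("select_columns expects a pandas DataFrame")
    # Step 1: a one-row slice is enough to inspect dtypes.
    df_row = df.iloc[:1]
    # Step 2: filter by dtype when either dtype argument is given.
    if dtype_include is not None or dtype_exclude is not None:
        df_row = df_row.select_dtypes(include=dtype_include, exclude=dtype_exclude)
    cols = df_row.columns
    # Step 3: filter the surviving column names with the regex.
    if pattern is not None:
        cols = cols[cols.str.contains(pattern, regex=True)]
    # Step 4: return the names as a plain list.
    return cols.tolist()


# Illustrative data for comparing the sketch with the real selector.
df = pd.DataFrame({"city": ["London", "Paris"], "rating": [5, 3], "rank": [1, 2]})
sketch = select_columns(df, pattern="^ra", dtype_include=np.number)
actual = make_column_selector(pattern="^ra", dtype_include=np.number)(df)
print(sketch, actual)
```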
Usage Examples
import numpy as np
import pandas as pd
from sklearn.compose import make_column_selector, make_column_transformer
from sklearn.preprocessing import StandardScaler, OneHotEncoder

# Create a mixed-type DataFrame
X = pd.DataFrame({
    "city": ["London", "London", "Paris", "Sallisaw"],
    "rating": [5, 3, 4, 5]
})

# Select numeric columns by dtype
num_selector = make_column_selector(dtype_include=np.number)
print(num_selector(X))  # ['rating']

# Select categorical columns by dtype
cat_selector = make_column_selector(dtype_include=[object, "string"])
print(cat_selector(X))  # ['city']

# Use selectors inside a ColumnTransformer
ct = make_column_transformer(
    (StandardScaler(), make_column_selector(dtype_include=np.number)),
    (OneHotEncoder(), make_column_selector(dtype_include=[object, "string"]))
)
ct.fit_transform(X)