

Implementation:Scikit learn Scikit learn LinearClassifierMixin Predict

From Leeroopedia


Field | Value
source | scikit-learn (https://github.com/scikit-learn/scikit-learn)
domains | Data_Science, Machine_Learning
last_updated | 2026-02-08 15:00 GMT

Overview

Concrete tool for generating class predictions from a fitted linear classifier provided by scikit-learn.

Description

The LinearClassifierMixin.predict method is the prediction implementation shared by the scikit-learn linear classifiers that inherit from LinearClassifierMixin, including LogisticRegression, SGDClassifier, RidgeClassifier, and Perceptron. It computes class labels by first evaluating the decision function (the linear score X @ coef_.T + intercept_) and then converting those scores into labels drawn from classes_.
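A quick way to see which estimators share this method is to check inheritance directly. The sketch below imports LinearClassifierMixin from the private module sklearn.linear_model._base (an assumption based on the source location given on this page; private paths can change between releases), checks the four classifiers named above, and confirms that LogisticRegression does not define its own predict.

```python
# Sketch: confirm which classifiers inherit LinearClassifierMixin.
# NOTE: sklearn.linear_model._base is a private module; this import path is an
# assumption taken from the documented source location and may change.
from sklearn.linear_model import (
    LogisticRegression,
    Perceptron,
    RidgeClassifier,
    SGDClassifier,
)
from sklearn.linear_model._base import LinearClassifierMixin

estimators = [LogisticRegression, SGDClassifier, RidgeClassifier, Perceptron]

# Map each class name to whether the mixin appears in its inheritance chain
inherits = {cls.__name__: issubclass(cls, LinearClassifierMixin) for cls in estimators}
print(inherits)

# LogisticRegression does not override predict, so attribute lookup on the
# class resolves to the very same function object defined on the mixin.
print(LogisticRegression.predict is LinearClassifierMixin.predict)
```

Inheritance alone does not rule out a subclass wrapping or overriding predict, which is why the identity check is done separately for one estimator.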

The method implements two code paths:

  • Binary classification -- For a two-class problem coef_ has a single row, so the decision function returns a 1-D array of scores (one per sample). The predicted class is determined by thresholding at zero: scores > 0 map to classes_[1], and scores <= 0 map to classes_[0].
  • Multiclass classification -- When the decision function returns a 2-D array (one score per class), the predicted class is the one with the maximum score (argmax along the class axis).
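The binary code path can be reproduced by hand from decision_function. This is a minimal sketch assuming a synthetic dataset from make_classification whose 0/1 labels are renamed to the hypothetical strings "ham" and "spam", so that it is visible that predict returns entries of classes_ rather than raw indices.

```python
# Sketch: reproduce the binary branch of predict by hand.
# Assumption: synthetic data with labels renamed to the hypothetical
# strings "ham"/"spam" to make the classes_ lookup visible.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y01 = make_classification(n_samples=200, n_features=5, random_state=0)
y = np.where(y01 == 1, "spam", "ham")

clf = LogisticRegression().fit(X, y)
print(clf.classes_)  # sorted labels: ['ham' 'spam']

scores = clf.decision_function(X)
print(scores.shape)  # 1-D for a binary problem: (200,)

# scores > 0 gives True/False; as ints 1/0 they index classes_[1] / classes_[0]
y_manual = clf.classes_[(scores > 0).astype(int)]
assert np.array_equal(clf.predict(X), y_manual)
```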

Usage

  • Generating class label predictions on test or validation data after fitting a linear classifier.
  • Producing batch predictions for deployment or evaluation.
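  • Called internally by ClassifierMixin.score (mean accuracy), by the default scoring used in cross-validation and grid search (for example cross_val_score and GridSearchCV), and by meta-estimators that delegate prediction to a fitted linear classifier.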
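The link between predict and scoring can be checked directly: ClassifierMixin.score computes the accuracy of predict's output, and cross_val_score with no scoring argument falls back to the estimator's score method. A minimal sketch on the Iris data, reusing the split from the usage examples on this page:

```python
# Sketch: score() and default cross-validation scoring both go through predict().
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import cross_val_score, train_test_split

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)

clf = LogisticRegression(max_iter=200).fit(X_train, y_train)

# score() is mean accuracy of predict() against the true labels
via_score = clf.score(X_test, y_test)
via_predict = accuracy_score(y_test, clf.predict(X_test))
assert via_score == via_predict

# With scoring=None, cross_val_score uses the estimator's score() per fold
cv_scores = cross_val_score(LogisticRegression(max_iter=200), X, y, cv=5)
print(via_score, cv_scores.mean())
```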

Code Reference

Source Location

sklearn/linear_model/_base.py, method LinearClassifierMixin.predict (class LinearClassifierMixin inherits from ClassifierMixin)

Signature

def predict(self, X):

Full Source

def predict(self, X):
    """
    Predict class labels for samples in X.

    Parameters
    ----------
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        The data matrix for which we want to get the predictions.

    Returns
    -------
    y_pred : ndarray of shape (n_samples,)
        Vector containing the class labels for each sample.
    """
    xp, _ = get_namespace(X)
    scores = self.decision_function(X)
    if len(scores.shape) == 1:
        indices = xp.astype(scores > 0, indexing_dtype(xp))
    else:
        indices = xp.argmax(scores, axis=1)

    return xp.take(self.classes_, indices, axis=0)
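In the method above, get_namespace and indexing_dtype come from scikit-learn's array API support, which lets the same code run on non-NumPy arrays. The sketch below fixes the namespace to NumPy and reproduces the two branches and the final take; predict_from_scores is a hypothetical helper name used only for illustration, not a scikit-learn function.

```python
# Sketch: NumPy-only version of the logic in LinearClassifierMixin.predict.
# predict_from_scores is a hypothetical helper, not part of scikit-learn.
import numpy as np


def predict_from_scores(scores, classes):
    """Map decision-function scores to labels drawn from classes."""
    scores = np.asarray(scores)
    classes = np.asarray(classes)
    if scores.ndim == 1:
        # Binary: (score > 0) -> True/False -> 1/0 -> classes[1] / classes[0]
        indices = (scores > 0).astype(np.intp)
    else:
        # Multiclass: column index of the largest score in each row
        indices = np.argmax(scores, axis=1)
    # Look the indices up in classes so the original label values are returned
    return np.take(classes, indices, axis=0)


print(predict_from_scores([-1.5, 0.0, 2.0], ["neg", "pos"]))  # ['neg' 'neg' 'pos']
print(predict_from_scores([[0.1, 2.0, -1.0], [3.0, 0.0, 1.0]], [10, 20, 30]))  # [20 10]
```

The sketch makes the tie-breaking explicit: a score of exactly 0 maps to classes_[0], and np.argmax returns the first column when several scores in a row are equal.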

Import

from sklearn.linear_model import LogisticRegression

# X_train, y_train and X_test are placeholders for your own data
clf = LogisticRegression()
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)

I/O Contract

Inputs

Parameter | Type | Default | Description
X | {array-like, sparse matrix} of shape (n_samples, n_features) | (required) | Feature matrix for which predictions are requested. Must have the same number of features as the training data.
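Because X may be dense or sparse, the same fitted model can be applied to a SciPy CSR matrix. A minimal sketch, assuming the Iris data and a model fitted on the dense array; it also shows that a single sample must still be passed with the 2-D shape (1, n_features).

```python
# Sketch: predict accepts a SciPy sparse matrix as well as a dense array.
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

X, y = load_iris(return_X_y=True)
clf = LogisticRegression(max_iter=200).fit(X, y)

# Same data, dense vs. CSR representation
y_dense = clf.predict(X)
y_sparse = clf.predict(csr_matrix(X))
assert np.array_equal(y_dense, y_sparse)

# A single sample needs shape (1, n_features), not a 1-D vector
y_one = clf.predict(X[0].reshape(1, -1))
print(y_one)
```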

Preconditions

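  • The estimator must have been fitted (i.e., fit must have been called). predict reads classes_ directly and calls decision_function, which uses coef_ and intercept_; on an unfitted estimator, decision_function raises sklearn.exceptions.NotFittedError.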
  • The input X must have the same number of columns (features) as the training data; otherwise input validation inside decision_function raises a ValueError.
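Both preconditions can be observed by violating them on a tiny toy dataset (the four-sample X and y below are made up for illustration). The sketch records the exception type raised in each case.

```python
# Sketch: what predict raises when the preconditions are not met.
import numpy as np
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import LogisticRegression

# Toy data, invented for this example: 4 samples, 2 features, 2 classes
X = np.array([[0.0, 1.0], [1.0, 0.0], [0.0, 2.0], [2.0, 0.0]])
y = np.array([0, 1, 0, 1])

# 1) Calling predict before fit
unfitted_error = None
try:
    LogisticRegression().predict(X)
except NotFittedError as exc:
    unfitted_error = type(exc).__name__

# 2) Calling predict with 3 columns on a model trained with 2
mismatch_error = None
clf = LogisticRegression().fit(X, y)
try:
    clf.predict(np.zeros((1, 3)))
except ValueError as exc:
    mismatch_error = type(exc).__name__

print(unfitted_error, mismatch_error)
```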

Outputs

Return | Type | Description
y_pred | ndarray of shape (n_samples,) | Predicted class labels for each sample. The labels are drawn from self.classes_, preserving the original label dtype and values.
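Because the final step looks indices up in classes_, arbitrary label values come back unchanged. A minimal sketch, assuming the Iris targets 0/1/2 are relabelled to the arbitrary integers 3/7/9 purely for illustration:

```python
# Sketch: predict returns values from classes_, not column indices.
# Assumption: Iris targets relabelled to 3/7/9 for illustration only.
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

X, y = load_iris(return_X_y=True)
y_custom = np.array([3, 7, 9])[y]  # 0 -> 3, 1 -> 7, 2 -> 9

clf = LogisticRegression(max_iter=200).fit(X, y_custom)
y_pred = clf.predict(X)

print(clf.classes_)       # [3 7 9]
print(y_pred.shape)       # (150,)
print(np.unique(y_pred))  # only values that appear in classes_
```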

Usage Examples

Predicting on test data:

from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)

clf = LogisticRegression(max_iter=200, random_state=42)
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)
print(y_pred[:10])  # array of predicted class labels

Comparing predictions with decision function scores:

import numpy as np

scores = clf.decision_function(X_test)
print(scores.shape)  # (45, 3) for 3-class problem

# Manual argmax should match predict output
y_manual = clf.classes_[np.argmax(scores, axis=1)]
assert np.array_equal(y_pred, y_manual)
