Implementation:Cleanlab Cleanlab CNN Classifier For MNIST

From Leeroopedia


Domains: Deep Learning, Computer Vision, Data Quality
Last Updated: 2026-02-09 00:00 GMT

Overview

A cleanlab-compatible PyTorch CNN classifier wrapped in the scikit-learn estimator interface for MNIST and sklearn-digits image classification.

Description

The mnist_pytorch module provides two key classes. SimpleNet is a basic PyTorch convolutional neural network: two convolutional layers followed by two fully connected layers, with dropout after the second convolution and after the first fully connected layer, and a log-softmax output. CNN is a scikit-learn BaseEstimator wrapper around SimpleNet that implements fit(), predict(), and predict_proba(). CNN handles data loading, training (SGD with negative log-likelihood loss), and inference internally; the X it receives is an array of indices into the chosen dataset rather than a feature matrix. It supports both the full MNIST dataset (via torchvision) and the smaller sklearn-digits dataset. The module also includes the helper functions get_mnist_dataset and get_sklearn_digits_dataset for loading and transforming image data. It serves as a reference implementation showing how to make a custom PyTorch model compatible with cleanlab's sklearn-based API.
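CNN's distinguishing convention is that the estimator owns its dataset and fit()/predict_proba() receive index arrays into it. The sketch below illustrates that contract with a hypothetical, numpy-only IndexedMeanClassifier, in which a nearest-class-mean rule stands in for SimpleNet. It is not part of cleanlab; a real wrapper would subclass sklearn.base.BaseEstimator as CNN does.

```python
import numpy as np


class IndexedMeanClassifier:
    """Hypothetical toy estimator that mimics CNN's index-based contract.

    `data` is the full dataset the estimator owns; fit() and predict_proba()
    receive index arrays into it, just as CNN.fit() receives train_idx.
    A nearest-class-mean rule stands in for the PyTorch network.
    """

    def __init__(self, data=None, n_classes=10):
        # Keep constructor arguments unchanged, as sklearn's get_params()/clone() expect.
        self.data = data
        self.n_classes = n_classes

    def fit(self, train_idx, train_labels):
        X = self.data[np.asarray(train_idx)]
        y = np.asarray(train_labels)
        # One mean feature vector per class; assumes every class 0..n_classes-1
        # appears at least once in train_labels.
        self.means_ = np.stack([X[y == k].mean(axis=0) for k in range(self.n_classes)])
        return self

    def predict_proba(self, idx):
        X = self.data[np.asarray(idx)]
        # Negative squared distance to each class mean serves as a logit.
        logits = -((X[:, None, :] - self.means_[None, :, :]) ** 2).sum(axis=2)
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        probs = np.exp(logits)
        return probs / probs.sum(axis=1, keepdims=True)  # shape (N, K), rows sum to 1

    def predict(self, idx):
        # Predicted class = most probable class.
        return self.predict_proba(idx).argmax(axis=1)
```

Because the constructor only stores its arguments and everything learned in fit() lives in an attribute ending in an underscore, an instance can be re-created from its parameters, which is what cross-validation helpers rely on when they clone an estimator.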

Usage

Import this module when you need a ready-made cleanlab-compatible image classifier for MNIST-like data, when testing cleanlab's cross-validation and label issue detection on image datasets, or as a reference for wrapping your own PyTorch model to work with cleanlab.

Code Reference

Source Location

  • Repository: Cleanlab
  • File: cleanlab/experimental/mnist_pytorch.py
  • Lines: 1-369

Signature

class CNN(BaseEstimator):
    def __init__(
        self,
        batch_size=64,
        epochs=6,
        log_interval=50,
        lr=0.01,
        momentum=0.5,
        no_cuda=False,
        seed=1,
        test_batch_size=None,
        dataset="mnist",
        loader=None,
    )
class SimpleNet(nn.Module):
    def __init__(self)
    def forward(self, x, T=1.0)

Import

from cleanlab.experimental.mnist_pytorch import CNN, SimpleNet

I/O Contract

Inputs (CNN.__init__)

  • batch_size (int, optional): Training batch size. Default 64.
  • epochs (int, optional): Number of training epochs. Default 6.
  • log_interval (int, optional): Number of batches between printed log messages; set to None to suppress logging. Default 50.
  • lr (float, optional): SGD learning rate. Default 0.01.
  • momentum (float, optional): SGD momentum. Default 0.5.
  • no_cuda (bool, optional): If True, disables CUDA. Default False.
  • seed (int, optional): Random seed. Default 1.
  • test_batch_size (int, optional): Batch size for testing and prediction. Default None, in which case the test-set size is used.
  • dataset (str, optional): Dataset to use, either "mnist" or "sklearn-digits". Default "mnist".
  • loader (str, optional): If set to "train" or "test", forces every operation to use that split. Default None.
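The constructor-level loader takes precedence over the per-call loader argument of fit(). That precedence can be sketched with a hypothetical resolve_loader function (not part of cleanlab). The "train" fallback mirrors fit()'s documented default; how CNN chooses a split for predict() and predict_proba() when neither is given is not stated on this page.

```python
from typing import Optional


def resolve_loader(forced: Optional[str], requested: Optional[str], default: str = "train") -> str:
    """Hypothetical helper: choose the dataset split for a single call.

    `forced` plays the role of CNN's constructor-level `loader`; when it is not
    None it wins over `requested`, the split asked for by the individual call.
    """
    if forced is not None:
        choice = forced
    elif requested is not None:
        choice = requested
    else:
        choice = default
    if choice not in ("train", "test"):
        raise ValueError(f"loader must be 'train' or 'test', got {choice!r}")
    return choice
```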

Inputs (CNN.fit)

  • train_idx (np.ndarray, required): Indices of the dataset examples to train on.
  • train_labels (np.ndarray, optional): Labels aligned with train_idx, so train_labels[i] is the label of example train_idx[i]. If None, the dataset's built-in labels are used.
  • sample_weight (np.ndarray, optional): Per-sample weights, used for a class-weighted loss.
  • loader (str, optional): Dataset split to draw examples from, "train" or "test". Default "train".
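Since train_labels is aligned element by element with train_idx, one way to picture the pairing is to scatter the labels into a full-length array indexed by dataset position. The hypothetical align_labels helper below is only an illustration of that pairing, not necessarily how CNN stores labels internally.

```python
import numpy as np


def align_labels(train_idx, train_labels, n_examples):
    """Hypothetical helper: scatter labels given for train_idx into a full-length array.

    Position j holds the label of dataset example j, or -1 if j is not in train_idx.
    """
    train_idx = np.asarray(train_idx)
    train_labels = np.asarray(train_labels)
    if len(train_idx) != len(train_labels):
        raise ValueError("train_idx and train_labels must have the same length")
    full = np.full(n_examples, -1, dtype=int)
    full[train_idx] = train_labels
    return full
```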

Outputs

  • predict() returns np.ndarray: 1D array of predicted integer class labels.
  • predict_proba() returns np.ndarray: 2D array of shape (N, K) of predicted class probabilities, where N is the number of requested examples and K the number of classes. Obtained by exponentiating the network's log-softmax outputs.

Architecture

The SimpleNet CNN architecture consists of:

  • Conv1: 1 input channel, 10 output channels, 5x5 kernel, followed by ReLU and 2x2 max pooling.
  • Conv2: 10 input channels, 20 output channels, 5x5 kernel, with Dropout2d, followed by ReLU and 2x2 max pooling.
  • FC1: 320 inputs, 50 outputs, ReLU activation, dropout.
  • FC2: 50 inputs, 10 outputs (one per digit class).
  • Output: log-softmax over the 10 classes.
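The 320 inputs of FC1 follow from the spatial size after the two convolution and pooling stages. The arithmetic below assumes a 28x28 input (the MNIST image size), stride-1 unpadded convolutions (PyTorch's nn.Conv2d defaults), and pooling with stride equal to the kernel size. It also implies that the 8x8 sklearn-digits images must reach the network at this size, presumably resized by the dataset helper.

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Standard output-size formula for a 2D convolution (no dilation).
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2):
    # Max pooling with stride equal to the kernel size, as in 2x2 pooling.
    return size // kernel


side = 28                            # MNIST images are 28x28
side = pool_out(conv_out(side, 5))   # Conv1 (5x5): 28 -> 24, pool: 24 -> 12
side = pool_out(conv_out(side, 5))   # Conv2 (5x5): 12 -> 8, pool: 8 -> 4
fc1_inputs = 20 * side * side        # 20 channels of 4x4 feature maps, flattened
```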

Usage Examples

Basic Usage: Train and Predict

import numpy as np
from cleanlab.experimental.mnist_pytorch import CNN

# Use the smaller sklearn-digits dataset
clf = CNN(dataset="sklearn-digits", epochs=3, log_interval=None)

# Train on all 1247 examples in the sklearn-digits training split
train_idx = np.arange(1247)
clf.fit(train_idx)

# Predict on all 550 examples in the sklearn-digits test split
test_idx = np.arange(550)
pred_probs = clf.predict_proba(test_idx, loader="test")
predictions = clf.predict(test_idx, loader="test")
print(f"Predictions shape: {predictions.shape}")
print(f"Pred probs shape: {pred_probs.shape}")

Usage with Cleanlab

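import numpy as np
from sklearn.datasets import load_digits
from cleanlab.count import estimate_cv_predicted_probabilities
from cleanlab.experimental.mnist_pytorch import CNN
from cleanlab.filter import find_label_issues

# CNN follows the sklearn estimator API, except that X is an array of dataset indices.
clf = CNN(dataset="sklearn-digits", epochs=3, log_interval=None)
train_idx = np.arange(1247)
# Assumption: the sklearn-digits train split is the first 1247 examples of load_digits().
labels = load_digits(return_X_y=True)[1][:1247]

# Out-of-sample predicted probabilities via 5-fold cross-validation. sklearn's
# cross_val_predict(method="predict_proba") expects a classes_ attribute, which CNN may lack.
pred_probs = estimate_cv_predicted_probabilities(train_idx, labels, clf=clf, cv_n_folds=5)
issues = find_label_issues(labels, pred_probs, return_indices_ranked_by="self_confidence")
print(f"Found {len(issues)} potential label issues")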
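find_label_issues can return the flagged indices ordered by a score; "self_confidence" is one of the ranking options. That score alone can be sketched as the predicted probability of each example's given label. The hypothetical rank_by_self_confidence function below is a simplified stand-in: it only orders examples and does not perform cleanlab's confident-learning filtering, which is what decides which examples count as label issues.

```python
import numpy as np


def rank_by_self_confidence(labels, pred_probs):
    """Hypothetical helper: order examples from least to most self-confident.

    Self-confidence is pred_probs[i, labels[i]], the model's out-of-sample
    probability for the label each example was given.
    """
    labels = np.asarray(labels)
    pred_probs = np.asarray(pred_probs)
    self_conf = pred_probs[np.arange(len(labels)), labels]
    return np.argsort(self_conf, kind="stable")
```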
