Implementation:Cleanlab Cleanlab CNN Classifier For MNIST
| Knowledge Sources | |
|---|---|
| Domains | Deep Learning, Computer Vision, Data Quality |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
A cleanlab-compatible PyTorch CNN classifier wrapped in the scikit-learn estimator interface for MNIST and sklearn-digits image classification.
Description
The mnist_pytorch module provides two key classes. SimpleNet is a basic PyTorch convolutional neural network: two convolutional layers (the second with dropout) followed by two fully connected layers, ending in a log-softmax output (log-probabilities). CNN is a scikit-learn BaseEstimator wrapper around SimpleNet that implements fit(), predict(), and predict_proba(). CNN handles data loading, training with SGD and negative log-likelihood (NLL) loss, and inference internally, so callers pass example indices rather than image arrays. It supports both the full MNIST dataset (via torchvision) and the smaller sklearn-digits dataset. The module also includes the dataset helper functions get_mnist_dataset and get_sklearn_digits_dataset for loading and transforming image data. The module serves as a reference implementation showing how to make a custom PyTorch model compatible with cleanlab's sklearn-based API.
Usage
Import this module when you need a ready-made cleanlab-compatible image classifier for MNIST-like data, when testing cleanlab's cross-validation and label issue detection on image datasets, or as a reference for wrapping your own PyTorch model to work with cleanlab.
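The wrapping pattern can be sketched without PyTorch. The following is a minimal, hypothetical illustration: IndexedClassifier, its toy labels, and its frequency-based "model" are assumptions made for this example and are not part of cleanlab. It only shows the shape of the interface, in which the estimator owns its data and is driven by example indices in fit(), predict_proba(), and predict().

```python
from collections import Counter


class IndexedClassifier:
    """Hypothetical stand-in for a wrapped model; not part of cleanlab.

    It holds its own data and is driven by example indices.
    The "model" is just smoothed class frequencies, so every example
    receives the same probabilities.
    """

    def __init__(self, n_classes=3):
        self.n_classes = n_classes
        # Toy built-in labels, standing in for a dataset loaded internally.
        self._dataset_labels = [0, 0, 1, 1, 1, 2, 0, 1]

    def fit(self, train_idx, train_labels=None):
        # Fall back to the dataset's own labels when none are supplied.
        if train_labels is None:
            train_labels = [self._dataset_labels[i] for i in train_idx]
        counts = Counter(train_labels)
        total = len(train_labels) + self.n_classes  # add-one smoothing
        self.probs_ = [(counts[k] + 1) / total for k in range(self.n_classes)]
        return self

    def predict_proba(self, idx):
        # One row of class probabilities per requested index.
        return [list(self.probs_) for _ in idx]

    def predict(self, idx):
        # Predicted label = argmax over each probability row.
        return [max(range(self.n_classes), key=row.__getitem__)
                for row in self.predict_proba(idx)]


clf = IndexedClassifier().fit(train_idx=[0, 1, 2, 3, 4])
print(clf.predict_proba([5, 6]))
print(clf.predict([5, 6]))  # → [1, 1]
```

A real wrapper would also subclass sklearn.base.BaseEstimator, as CNN does, so that its constructor arguments can be read back through get_params() and the estimator can be cloned.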
Code Reference
Source Location
- Repository: Cleanlab
- File: cleanlab/experimental/mnist_pytorch.py
- Lines: 1-369
Signature
class CNN(BaseEstimator):
    def __init__(
        self,
        batch_size=64,
        epochs=6,
        log_interval=50,
        lr=0.01,
        momentum=0.5,
        no_cuda=False,
        seed=1,
        test_batch_size=None,
        dataset="mnist",
        loader=None,
    )

class SimpleNet(nn.Module):
    def __init__(self)
    def forward(self, x, T=1.0)
Import
from cleanlab.experimental.mnist_pytorch import CNN, SimpleNet
I/O Contract
Inputs (CNN.__init__)
| Name | Type | Required | Description |
|---|---|---|---|
| batch_size | int | No | Training batch size. Default 64. |
| epochs | int | No | Number of training epochs. Default 6. |
| log_interval | int | No | Batches between printed log messages. Set to None to suppress. Default 50. |
| lr | float | No | Learning rate for SGD. Default 0.01. |
| momentum | float | No | SGD momentum. Default 0.5. |
| no_cuda | bool | No | If True, disables CUDA. Default False. |
| seed | int | No | Random seed. Default 1. |
| test_batch_size | int | No | Batch size for test/prediction. Defaults to test set size. |
| dataset | str | No | Dataset to use: "mnist" or "sklearn-digits". Default "mnist". |
| loader | str | No | Force loader to "train" or "test" for all operations. Default None. |
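As a rough illustration of how lr and momentum interact, the sketch below applies the SGD-with-momentum update in the form documented for torch.optim.SGD (velocity = momentum * velocity + gradient, then parameter -= lr * velocity) to a one-dimensional quadratic. The objective, starting point, and step count are assumptions chosen for the example; lr and momentum use the defaults listed in the table.

```python
def sgd_momentum(grad_fn, w0, lr=0.01, momentum=0.5, steps=2000):
    """Minimize a 1-D function with SGD plus momentum (no dampening, no Nesterov)."""
    w, velocity = w0, 0.0
    for _ in range(steps):
        grad = grad_fn(w)
        velocity = momentum * velocity + grad  # accumulate past gradients
        w -= lr * velocity                     # step along the velocity
    return w


# Illustrative objective f(w) = (w - 3)^2, whose gradient is 2 * (w - 3).
w_final = sgd_momentum(lambda w: 2.0 * (w - 3.0), w0=0.0)
print(w_final)  # very close to the minimizer w = 3
```

With momentum 0.5 the velocity settles at roughly twice the current gradient, so the effective step size is about lr / (1 - momentum) = 0.02.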
Inputs (CNN.fit)
| Name | Type | Required | Description |
|---|---|---|---|
| train_idx | np.ndarray | Yes | Array of indices specifying which examples to use for training. |
| train_labels | np.ndarray | No | Array of labels corresponding to train_idx. If None, uses dataset's built-in labels. |
| sample_weight | np.ndarray | No | Per-example weights, used to weight the NLL loss by class. Default None (unweighted). |
| loader | str | No | "train" or "test" to select which dataset split to load. Default "train". |
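To make "class-weighted loss" concrete, the sketch below computes a weighted negative log-likelihood by hand. It follows the mean reduction documented for PyTorch's nll_loss when a weight vector is supplied: each example's loss is scaled by the weight of its target class, and the total is divided by the sum of those weights. The function name weighted_nll, the log-probabilities, and the weights are made up for illustration and do not come from the module.

```python
import math


def weighted_nll(log_probs, targets, class_weight=None):
    """Mean NLL over a batch, optionally weighted per class."""
    n_classes = len(log_probs[0])
    if class_weight is None:
        class_weight = [1.0] * n_classes
    # Each example contributes -w[y] * log p(y); normalize by the sum of w[y].
    weighted = [-class_weight[y] * row[y] for row, y in zip(log_probs, targets)]
    norm = sum(class_weight[y] for y in targets)
    return sum(weighted) / norm


# Made-up batch: two examples, two classes.
log_probs = [[math.log(0.9), math.log(0.1)],
             [math.log(0.2), math.log(0.8)]]
targets = [0, 1]

print(weighted_nll(log_probs, targets))              # plain mean NLL
print(weighted_nll(log_probs, targets, [1.0, 3.0]))  # class 1 counts three times as much
```

Raising the weight of class 1 pulls the batch loss toward the loss of the second example, which is the only one labeled 1.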
Outputs
| Name | Type | Description |
|---|---|---|
| predict() return | np.ndarray | 1D array of predicted integer class labels. |
| predict_proba() return | np.ndarray | 2D array of shape (N, K) with predicted class probabilities. Obtained by exponentiating log-softmax outputs. |
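The last row of the table relies on the fact that exponentiating a log-softmax output recovers ordinary probabilities. The following standalone sketch, with made-up logits for three classes, shows that each exponentiated row is non-negative and sums to 1.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of one row of logits."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_norm for z in logits]


logits = [2.0, 1.0, 0.1]                     # made-up scores for K = 3 classes
log_probs = log_softmax(logits)              # what a log-softmax output layer emits
probs = [math.exp(lp) for lp in log_probs]   # exponentiate to get probabilities

print(probs)
print(sum(probs))
```

Because exp is monotonic, the most probable class is the same whether it is read from the log-probabilities or from the probabilities.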
Architecture
The SimpleNet CNN architecture consists of:
| Layer | Details |
|---|---|
| Conv1 | 1 input channel, 10 output channels, 5x5 kernel, followed by 2x2 max pooling and ReLU |
| Conv2 | 10 input channels, 20 output channels, 5x5 kernel, with Dropout2d, followed by 2x2 max pooling and ReLU |
| FC1 | 320 inputs, 50 outputs, ReLU activation, dropout |
| FC2 | 50 inputs, 10 outputs (one per digit class) |
| Output | Log-softmax over 10 classes |
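The 320 inputs of FC1 follow from the layer sizes in the table. The sketch below traces the spatial size through both convolution and pooling stages, assuming a 28x28 single-channel input (the MNIST image size) and PyTorch's default Conv2d settings of stride 1 and no padding. The helper names conv_out and pool_out are illustrative.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a square convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel):
    """Output side length of non-overlapping max pooling (stride == kernel)."""
    return size // kernel


side = 28                               # assumed input: 28x28 single-channel image
side = pool_out(conv_out(side, 5), 2)   # Conv1 + pool: 28 -> 24 -> 12
side = pool_out(conv_out(side, 5), 2)   # Conv2 + pool: 12 -> 8 -> 4
flattened = 20 * side * side            # 20 channels from Conv2

print(side, flattened)  # → 4 320
```

Any input that is not 28x28 would change the flattened size, so smaller images have to be resized before they reach a network with these fixed layer widths.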
Usage Examples
Basic Usage: Train and Predict
import numpy as np
from cleanlab.experimental.mnist_pytorch import CNN
# Use the smaller sklearn-digits dataset
clf = CNN(dataset="sklearn-digits", epochs=3, log_interval=None)
# Train on every index of the sklearn-digits training split (1247 examples)
train_idx = np.arange(1247)
clf.fit(train_idx)
# Predict probabilities and labels for every index of the test split (550 examples)
test_idx = np.arange(550)
pred_probs = clf.predict_proba(test_idx, loader="test")
predictions = clf.predict(test_idx, loader="test")
print(f"Predictions shape: {predictions.shape}")
print(f"Pred probs shape: {pred_probs.shape}")
Usage with Cleanlab
from sklearn.model_selection import cross_val_predict
from cleanlab.experimental.mnist_pytorch import CNN
from cleanlab.filter import find_label_issues
# The CNN class is sklearn-compatible, so it works with cross_val_predict
clf = CNN(dataset="sklearn-digits", epochs=3, log_interval=None)
# Use cross-validation to get out-of-sample predicted probabilities
# then pass them to cleanlab's find_label_issues
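The idea of out-of-sample predicted probabilities can be illustrated with a dependency-free sketch of the cross-validation loop over indices. Everything in it is an assumption made for illustration: PriorClassifier is a hypothetical toy model that only predicts smoothed class frequencies, k_fold_indices uses contiguous folds with no shuffling or stratification, and the labels are made up. It shows the data flow only, not cleanlab's or scikit-learn's implementation.

```python
def k_fold_indices(n, k):
    """Split range(n) into k contiguous folds of near-equal size."""
    base, extra = divmod(n, k)
    folds, start = [], 0
    for f in range(k):
        stop = start + base + (1 if f < extra else 0)
        folds.append(list(range(start, stop)))
        start = stop
    return folds


class PriorClassifier:
    """Hypothetical toy model: predicts smoothed class frequencies."""

    def __init__(self, n_classes):
        self.n_classes = n_classes

    def fit(self, train_idx, train_labels):
        total = len(train_labels) + self.n_classes
        self.probs_ = [(train_labels.count(k) + 1) / total
                       for k in range(self.n_classes)]
        return self

    def predict_proba(self, idx):
        return [list(self.probs_) for _ in idx]


labels = [0, 1, 0, 1, 1, 0, 1, 0, 1, 1]   # made-up, possibly noisy labels
n, n_classes = len(labels), 2
pred_probs = [None] * n

for held_out in k_fold_indices(n, k=5):
    held = set(held_out)
    train_idx = [i for i in range(n) if i not in held]
    # Fit a fresh model without the held-out examples...
    clf = PriorClassifier(n_classes).fit(train_idx, [labels[i] for i in train_idx])
    # ...then predict only for the examples it never saw.
    for i, row in zip(held_out, clf.predict_proba(held_out)):
        pred_probs[i] = row

print(len(pred_probs), len(pred_probs[0]))  # → 10 2
```

With a real classifier, an (N, K) array assembled this way, together with the length-N label array, forms the labels and pred_probs arguments that find_label_issues accepts.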