Implementation:Cleanlab Cleanlab CIFAR CNN

From Leeroopedia


Knowledge Sources
Domains Deep Learning, Image Classification, Noisy Labels
Last Updated 2026-02-09 00:00 GMT

Overview

A PyTorch CNN class designed as a baseline architecture for CIFAR-10 image classification, particularly suited for learning with noisy labels via co-teaching.

Description

The CNN class is a PyTorch neural network module (nn.Module) that implements a convolutional network with nine convolutional layers, a widely used baseline for CIFAR-10 noisy-label benchmarks. The nine layers are organized into three blocks of three. Channel widths are 128 in the first block, 256 in the second, and 512 -> 256 -> 128 in the third. Each convolutional layer is followed by batch normalization and a leaky ReLU activation. Max pooling followed by dropout is applied after each of the first two blocks. After the third block, global average pooling produces a 128-dimensional feature vector, which a single fully-connected layer maps to the output class logits. The code is adapted from the Co-teaching paper's reference implementation.
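To make the layer list concrete, the following plain-Python sketch counts the learnable parameters implied by the described layers. It does not need PyTorch. The helpers conv_params and cnn_param_count are hypothetical. The count assumes that every Conv2d keeps PyTorch's default bias=True and that each batch-normalization layer contributes a learnable scale and shift per channel. The implementation might register layers differently, for example by creating the logit batch-norm layer even when top_bn=False. For that reason, sum(p.numel() for p in model.parameters()) on a real CNN instance is the authoritative number.

```python
def conv_params(c_in: int, c_out: int, k: int = 3, bias: bool = True) -> int:
    """Weights plus (optional) bias of a k x k Conv2d layer."""
    return k * k * c_in * c_out + (c_out if bias else 0)


def cnn_param_count(input_channel: int = 3, n_outputs: int = 10, top_bn: bool = False) -> int:
    """Count learnable parameters of the described 9-conv architecture.

    Assumption: every conv uses bias=True and every batch-norm layer has a
    learnable scale and shift (2 values per channel).
    """
    # Channel widths feeding c1..c9, following the three described blocks
    channels = [input_channel, 128, 128, 128, 256, 256, 256, 512, 256, 128]
    total = 0
    for c_in, c_out in zip(channels[:-1], channels[1:]):
        total += conv_params(c_in, c_out)   # 3x3 convolution
        total += 2 * c_out                   # BatchNorm2d scale + shift
    total += 128 * n_outputs + n_outputs     # final Linear(128, n_outputs)
    if top_bn:
        total += 2 * n_outputs               # batch norm applied to the logits
    return total


print(cnn_param_count())  # 4434570
```

The count grows by only 20 parameters with top_bn=True for 10 classes. Almost all of the roughly 4.4 million parameters sit in the 256- and 512-channel convolutions.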

Usage

Use this class when you need a standard CNN baseline for CIFAR-10 image classification, especially when training on data with noisy labels. It is intended to be used with cleanlab.experimental.coteaching, which implements the co-teaching training procedure. PyTorch must be installed.
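Co-teaching trains two networks at the same time. On each mini-batch, each network ranks the samples by its own loss and treats the small-loss fraction as probably clean. Its peer network is then updated on that selection. The fraction that is discarded, called the forget rate, is typically ramped up over the first epochs. The framework-free sketch below illustrates only this selection step. The function names are hypothetical. It is a simplified illustration of the idea from the Co-teaching paper, not the code in cleanlab.experimental.coteaching, whose schedule and loss handling may differ.

```python
def small_loss_indices(losses, remember_rate):
    """Indices of the remember_rate fraction of samples with the smallest loss."""
    num_remember = int(remember_rate * len(losses))
    order = sorted(range(len(losses)), key=lambda i: losses[i])
    return order[:num_remember]


def coteaching_step(losses1, losses2, forget_rate):
    """Return (indices to update model 1 on, indices to update model 2 on).

    Each network picks its own small-loss samples; the *peer* network is
    then updated on that selection (the cross-update).
    """
    remember_rate = 1.0 - forget_rate
    picked_by_1 = small_loss_indices(losses1, remember_rate)
    picked_by_2 = small_loss_indices(losses2, remember_rate)
    return picked_by_2, picked_by_1


def forget_rate_at(epoch, max_forget_rate, num_gradual):
    """Ramp the forget rate linearly from 0 to max_forget_rate over num_gradual epochs."""
    return max_forget_rate * min(epoch / num_gradual, 1.0)


# Per-sample losses from the two networks on one mini-batch of 4 examples
losses1 = [0.1, 2.5, 0.3, 1.9]
losses2 = [0.2, 0.4, 3.0, 0.1]
update1_on, update2_on = coteaching_step(losses1, losses2, forget_rate=0.5)
print(update1_on, update2_on)  # [3, 0] [0, 2]
```

The two networks start from different random initializations. Because of this, they tend to make different mistakes, and exchanging their small-loss selections keeps each one from reinforcing its own errors on mislabeled samples.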

Code Reference

Source Location

  • Repository: Cleanlab
  • File: cleanlab/experimental/cifar_cnn.py
  • Lines: 34-108

Signature

class CNN(nn.Module):
    def __init__(self, input_channel=3, n_outputs=10, dropout_rate=0.25, top_bn=False):
        ...

    def forward(self, x):
        ...

Import

from cleanlab.experimental.cifar_cnn import CNN

I/O Contract

Inputs (Constructor)

Name | Type | Required | Description
input_channel | int | No (default: 3) | Number of input channels in the image (e.g., 3 for RGB)
n_outputs | int | No (default: 10) | Number of output classes for classification
dropout_rate | float | No (default: 0.25) | Dropout probability applied after each max-pooling layer
top_bn | bool | No (default: False) | Whether to apply batch normalization to the final logits
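With top_bn=True, the logits themselves are batch-normalized. For a (batch_size, n_outputs) tensor, batch normalization standardizes each class column using the mean and variance computed over the batch. The sketch below shows this training-mode computation in plain Python. The helper name is hypothetical. The sketch fixes the learnable scale at 1 and the shift at 0, and it uses the biased batch variance and PyTorch's default eps of 1e-5. It ignores the running statistics that replace batch statistics in evaluation mode.

```python
import math


def batch_norm_logits(logits, eps=1e-5):
    """Standardize each class column of a (batch_size, n_outputs) logit matrix.

    Training-mode batch normalization with scale 1 and shift 0, using the
    biased (population) variance over the batch.
    """
    n_rows, n_cols = len(logits), len(logits[0])
    out = [[0.0] * n_cols for _ in range(n_rows)]
    for j in range(n_cols):
        column = [row[j] for row in logits]
        mean = sum(column) / n_rows
        var = sum((v - mean) ** 2 for v in column) / n_rows
        for i in range(n_rows):
            out[i][j] = (logits[i][j] - mean) / math.sqrt(var + eps)
    return out


# Three samples, two classes
normalized = batch_norm_logits([[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
```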

Inputs (Forward)

Name | Type | Required | Description
x | torch.Tensor | Yes | Input image tensor of shape (batch_size, input_channel, height, width); expected 32x32 for CIFAR-10

Outputs

Name | Type | Description
logit | torch.Tensor | Output logits tensor of shape (batch_size, n_outputs) representing unnormalized class scores
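The network returns raw logits, not probabilities. Cleanlab's label-issue tools, such as cleanlab.filter.find_label_issues, take predicted probabilities (pred_probs), so the logits would first pass through a softmax. In PyTorch this is torch.softmax(logits, dim=1). The dependency-free sketch below shows the same computation for a single row. The helper name is illustrative.

```python
import math


def softmax(logits):
    """Convert one row of unnormalized logits into probabilities that sum to 1."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax([2.0, 1.0, 0.1])
print(round(sum(probs), 6))  # 1.0
```

Softmax preserves the ordering of the logits, so the predicted class (the argmax) is the same whether you read it from logits or from probabilities.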

Architecture Details

The network consists of three convolutional blocks:

  • Block 1 (128 channels): Three Conv2d layers (3x3, stride 1, padding 1) with batch normalization and leaky ReLU (slope 0.01), followed by 2x2 max pooling and 2D dropout.
  • Block 2 (256 channels): Three Conv2d layers (3x3, stride 1, padding 1) with batch normalization and leaky ReLU, followed by 2x2 max pooling and 2D dropout.
  • Block 3 (512 -> 256 -> 128 channels): Three Conv2d layers (3x3, stride 1, no padding) with batch normalization and leaky ReLU, followed by global average pooling.
  • Classifier: A single fully-connected layer mapping 128 features to n_outputs logits.
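The block descriptions determine the feature-map size at each stage. The following plain-Python sketch traces the side length of a square input through the network. The helper names are hypothetical, and the sketch uses floor rounding, as in PyTorch's default ceil_mode=False pooling. A 32x32 CIFAR-10 image reaches block 3 at 8x8, and the three unpadded convolutions shrink it to 2x2. Global average pooling then yields one 128-dimensional vector per image. By the same arithmetic, 28x28 is the smallest square input that still leaves a spatial position after block 3.

```python
def conv_out(size, kernel=3, stride=1, padding=0):
    """Side length after a square convolution or pooling window (floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_spatial_size(size):
    """Return side lengths: input, after each max pool, after each block-3 conv."""
    sizes = [size]
    for _ in range(2):                              # blocks 1 and 2
        for _ in range(3):
            size = conv_out(size, padding=1)        # padded 3x3 conv keeps the size
        size = conv_out(size, kernel=2, stride=2)   # 2x2 max pooling halves it
        sizes.append(size)
    for _ in range(3):                              # block 3
        size = conv_out(size, padding=0)            # unpadded 3x3 conv shrinks by 2
        if size < 1:
            raise ValueError("input too small for the unpadded block-3 convolutions")
        sizes.append(size)
    return sizes  # global average pooling then collapses the last size to 1


print(trace_spatial_size(32))  # [32, 16, 8, 6, 4, 2]
```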

Usage Examples

Basic Usage

import torch
from cleanlab.experimental.cifar_cnn import CNN

# Create a CNN for CIFAR-10 (3-channel RGB input, 10 classes)
model = CNN(input_channel=3, n_outputs=10, dropout_rate=0.25)
model.eval()  # inference mode: batch norm uses its running statistics

# Example forward pass on a random batch of sixteen 32x32 RGB images
images = torch.randn(16, 3, 32, 32)
with torch.no_grad():  # no gradients needed for a forward-only pass
    logits = model(images)
print(logits.shape)  # torch.Size([16, 10])

Usage with Co-Teaching

from cleanlab.experimental.cifar_cnn import CNN

# Create two models for co-teaching
model1 = CNN(input_channel=3, n_outputs=10)
model2 = CNN(input_channel=3, n_outputs=10)

# Move to GPU for training with coteaching.train()
model1.cuda()
model2.cuda()
