Implementation:Cleanlab Cleanlab Coteaching Train

From Leeroopedia


Knowledge Sources
Domains: Deep Learning, Noisy Labels, Robust Training
Last Updated: 2026-02-09 00:00 GMT

Overview

A PyTorch implementation of the Co-Teaching algorithm (Han et al., 2018) for training neural networks on noisily labeled data. Two models are trained side by side, and each selects the samples it considers clean for the other to learn from.

Description

The coteaching module provides the complete training pipeline for the Co-Teaching algorithm. The core idea is to train two neural networks simultaneously, where each model identifies the examples it considers most likely to be correctly labeled (those with the lowest loss) and passes them to the other model for training. This cross-selection mechanism is intended to keep the two models from memorizing the same noisy labels. The module includes the co-teaching loss function, learning rate scheduling, forget rate scheduling, a full training loop, and an evaluation function.
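
The cross-selection step described above can be sketched without any deep learning framework. The following is a minimal illustration, not the module's code: the function name select_small_loss and the per-example loss values are hypothetical, and the real loss_coteaching operates on logits and tensors rather than Python lists.

```python
def select_small_loss(losses, forget_rate):
    """Return indices of the lowest-loss examples, keeping a (1 - forget_rate) fraction."""
    num_remember = int((1 - forget_rate) * len(losses))
    # Rank example indices from smallest to largest loss.
    ranked = sorted(range(len(losses)), key=lambda i: losses[i])
    return ranked[:num_remember]


# Hypothetical per-example losses for one batch of 6 examples.
losses_model1 = [0.2, 2.5, 0.1, 0.4, 3.0, 1.9]
losses_model2 = [0.3, 0.2, 0.1, 2.8, 2.9, 0.6]

keep_by_model1 = select_small_loss(losses_model1, forget_rate=0.5)
keep_by_model2 = select_small_loss(losses_model2, forget_rate=0.5)

# Cross-update: each model trains on the examples the *other* model selected.
train_indices_for_model1 = keep_by_model2
train_indices_for_model2 = keep_by_model1
```

Because the two models rank examples differently, the index sets they hand to each other usually differ, which is the point of training two models instead of one.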

Usage

Import this module when you need to train neural networks directly on noisily-labeled data without first removing the noisy labels. It is an alternative to cleanlab's primary approach of identifying and filtering label issues. It is designed to work with cleanlab.experimental.cifar_cnn for CIFAR-10 benchmarks but can be adapted for other PyTorch models and datasets. Requires PyTorch and a CUDA-enabled GPU.

Code Reference

Source Location

  • Repository: Cleanlab
  • File: cleanlab/experimental/coteaching.py
  • Lines: 1-229

Key Functions

train

def train(
    train_loader,
    epoch,
    model1,
    optimizer1,
    model2,
    optimizer2,
    args,
    forget_rate_schedule,
    class_weights,
    accuracy,
):

loss_coteaching

def loss_coteaching(
    y_1,
    y_2,
    t,
    forget_rate,
    class_weights=None,
):

forget_rate_scheduler

def forget_rate_scheduler(epochs, forget_rate, num_gradual, exponent):

initialize_lr_scheduler

def initialize_lr_scheduler(lr=0.001, epochs=250, epoch_decay_start=80):

Import

from cleanlab.experimental.coteaching import train, loss_coteaching, forget_rate_scheduler

I/O Contract

Inputs (train)

| Name | Type | Required | Description |
|------|------|----------|-------------|
| train_loader | torch.utils.data.DataLoader | Yes | DataLoader providing (images, labels) batches |
| epoch | int | Yes | Current epoch number (0-indexed) |
| model1 | nn.Module | Yes | First PyTorch model with forward(self, x) method |
| optimizer1 | torch.optim.Adam | Yes | Optimizer for model1 |
| model2 | nn.Module | Yes | Second PyTorch model with forward(self, x) method |
| optimizer2 | torch.optim.Adam | Yes | Optimizer for model2 |
| args | argparse.Namespace | Yes | Must contain num_iter_per_epoch, print_freq, epochs, batch_size |
| forget_rate_schedule | np.ndarray | Yes | Array of forget rates per epoch from forget_rate_scheduler |
| class_weights | torch.Tensor or None | Yes | Class weights tensor of shape (num_classes,) or None |
| accuracy | callable | Yes | Function of form accuracy(output, target, topk=(1,)) |
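
The args input only needs the four attributes listed above, so it can be built by hand instead of parsed from the command line. A minimal sketch follows; every value is an illustrative assumption, not a default taken from the library.

```python
import argparse

# All values below are illustrative assumptions, not library defaults.
args = argparse.Namespace(
    num_iter_per_epoch=400,  # assumed cap on iterations per epoch
    print_freq=50,           # assumed logging interval, in iterations
    epochs=250,              # total number of training epochs
    batch_size=128,          # batch size used by the DataLoader
)

# The four attributes named in the train inputs.
required = ("num_iter_per_epoch", "print_freq", "epochs", "batch_size")
missing = [name for name in required if not hasattr(args, name)]
```

Checking for missing attributes up front gives a clearer failure than an AttributeError raised partway through an epoch.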

Outputs (train)

| Name | Type | Description |
|------|------|-------------|
| train_acc1 | float | Training accuracy for model1 |
| train_acc2 | float | Training accuracy for model2 |

Inputs (loss_coteaching)

| Name | Type | Required | Description |
|------|------|----------|-------------|
| y_1 | torch.Tensor | Yes | Output logits from model 1 |
| y_2 | torch.Tensor | Yes | Output logits from model 2 |
| t | torch.Tensor | Yes | Noisy target labels |
| forget_rate | float | Yes | Fraction of examples to forget (0 to 1), typically forget_rate_schedule[epoch] |
| class_weights | torch.Tensor or None | No | Optional class weights of shape (num_classes,) |

Outputs (loss_coteaching)

| Name | Type | Description |
|------|------|-------------|
| loss_1 | torch.Tensor | Normalized loss for model 1 (trained on model 2's selected samples) |
| loss_2 | torch.Tensor | Normalized loss for model 2 (trained on model 1's selected samples) |

Implementation Details

Forget Rate Scheduling

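The forget rate starts at 0 and increases linearly over the first num_gradual epochs, with the ramp ending at forget_rate^exponent:

forget_rate_schedule[:num_gradual] = linspace(0, forget_rate^exponent, num_gradual)

After this warmup period, the forget rate remains constant at the target rate, forget_rate. With exponent = 1 the ramp ends exactly at the target rate; with any other exponent the schedule steps from forget_rate^exponent to forget_rate at epoch num_gradual.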
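
To make the schedule concrete, here is a dependency-free sketch of the formula above. The function name forget_rate_schedule_sketch is hypothetical and the loop stands in for linspace; it is meant to mirror the formula, not to reproduce the library's implementation.

```python
def forget_rate_schedule_sketch(epochs, forget_rate, num_gradual, exponent):
    """Pure-Python stand-in for the linspace-based formula above."""
    # Assumes num_gradual <= epochs.
    schedule = [forget_rate] * epochs
    ramp_end = forget_rate ** exponent
    for i in range(num_gradual):
        if num_gradual == 1:
            schedule[i] = 0.0  # a one-point linspace contains only the start value
        else:
            # Evenly spaced points from 0 to ramp_end, both endpoints included.
            schedule[i] = ramp_end * i / (num_gradual - 1)
    return schedule


schedule = forget_rate_schedule_sketch(
    epochs=6, forget_rate=0.2, num_gradual=5, exponent=1
)
# Roughly 0, 0.05, 0.1, 0.15, 0.2 during warmup, then 0.2 afterwards.
```

Starting at a forget rate of 0 means every example is kept in the first epoch, and low-loss filtering only tightens as the warmup proceeds.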

Small Batch Handling

The training function skips the final batch if its size is below MINIMUM_BATCH_SIZE (16). This prevents noise amplification from gradient updates based on very few examples, which is especially important in noisy label settings.
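
The skip rule can be illustrated with batch sizes alone. This is a sketch under the assumption that only the final batch of an epoch is checked, as described above; batches_to_train_on is a hypothetical helper, not a function in the module.

```python
MINIMUM_BATCH_SIZE = 16  # threshold stated above


def batches_to_train_on(batch_sizes):
    """Yield (index, size) for each batch that is used, skipping an undersized final batch."""
    last = len(batch_sizes) - 1
    for i, size in enumerate(batch_sizes):
        if i == last and size < MINIMUM_BATCH_SIZE:
            continue  # too few examples for a reliable update
        yield i, size


# 905 examples with batch_size=128: seven full batches and a final batch of 9.
sizes_short_tail = [128] * 7 + [9]
used = list(batches_to_train_on(sizes_short_tail))
```

In this example the final batch of 9 examples is dropped, so only the seven full batches contribute gradient updates.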

Learning Rate Scheduling

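initialize_lr_scheduler returns a per-epoch learning rate plan and a per-epoch beta1 plan for the Adam optimizer. The learning rate is held at lr until epoch_decay_start (default 80) and then decays linearly. The beta1 plan is documented as changing from 0.9 to 0.9 at the decay start, so with these values beta1 is effectively constant.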
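
A linear decay plan of this kind can be sketched as a plain list with one entry per epoch. The exact decay formula is an assumption here (a straight line from lr at epoch_decay_start toward zero at the end of training), and linear_decay_plan is a hypothetical name rather than the module's function.

```python
def linear_decay_plan(lr=0.001, epochs=250, epoch_decay_start=80):
    """Return one learning rate per epoch: constant, then linearly decaying."""
    plan = [lr] * epochs
    for epoch in range(epoch_decay_start, epochs):
        # Assumed form: fraction of the decay window still remaining, times lr.
        plan[epoch] = lr * (epochs - epoch) / (epochs - epoch_decay_start)
    return plan


lr_plan = linear_decay_plan()
# lr_plan[0] through lr_plan[79] stay at 0.001; later entries shrink each epoch.
```

A precomputed plan like this can then be applied once per epoch by writing the scheduled value into each of the optimizer's parameter groups.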

Usage Examples

Basic Usage

from cleanlab.experimental.coteaching import (
    train, evaluate, forget_rate_scheduler,
    initialize_lr_scheduler, adjust_learning_rate,
)
from cleanlab.experimental.cifar_cnn import CNN
import torch

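# Assumed to be defined elsewhere (not shown here): train_loader, test_loader,
# args (see the train inputs above), and accuracy_fn.

# Create two models (.cuda() requires a CUDA-enabled GPU)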
model1 = CNN(input_channel=3, n_outputs=10).cuda()
model2 = CNN(input_channel=3, n_outputs=10).cuda()

# Set up optimizers
optimizer1 = torch.optim.Adam(model1.parameters(), lr=0.001)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=0.001)

# Create forget rate schedule
rate_schedule = forget_rate_scheduler(
    epochs=250, forget_rate=0.2, num_gradual=10, exponent=1
)

# Create learning rate schedule
alpha_plan, beta1_plan = initialize_lr_scheduler(
    lr=0.001, epochs=250, epoch_decay_start=80
)

# Training loop
for epoch in range(250):
    adjust_learning_rate(optimizer1, epoch, alpha_plan, beta1_plan)
    adjust_learning_rate(optimizer2, epoch, alpha_plan, beta1_plan)
    train_acc1, train_acc2 = train(
        train_loader, epoch, model1, optimizer1,
        model2, optimizer2, args, rate_schedule,
        class_weights=None, accuracy=accuracy_fn,
    )

# Evaluate
acc1, acc2 = evaluate(test_loader, model1, model2)
