Implementation:Cleanlab Cleanlab Coteaching Train
| Knowledge Sources | |
|---|---|
| Domains | Deep Learning, Noisy Labels, Robust Training |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
A PyTorch implementation of the Co-Teaching algorithm (Han et al., 2018) for training neural networks on noisily-labeled data, where two models teach each other by selecting clean samples.
Description
The coteaching module provides the complete training pipeline for the Co-Teaching algorithm. The core idea is to train two neural networks simultaneously, where each model identifies the examples it considers most likely to be correctly labeled (those with the lowest loss) and passes them to the other model for training. This cross-selection mechanism prevents both models from memorizing the same noisy labels. The module includes the co-teaching loss function, learning rate scheduling, forget rate scheduling, a full training loop, and an evaluation function.
Usage
Import this module when you need to train neural networks directly on noisily-labeled data without first removing the noisy labels. It is an alternative to cleanlab's primary approach of identifying and filtering label issues. It is designed to work with cleanlab.experimental.cifar_cnn for CIFAR-10 benchmarks but can be adapted for other PyTorch models and datasets. Requires PyTorch and a CUDA-enabled GPU.
Code Reference
Source Location
- Repository: Cleanlab
- File: cleanlab/experimental/coteaching.py
- Lines: 1-229
Key Functions
train
def train(
    train_loader,
    epoch,
    model1,
    optimizer1,
    model2,
    optimizer2,
    args,
    forget_rate_schedule,
    class_weights,
    accuracy,
):
loss_coteaching
def loss_coteaching(
    y_1,
    y_2,
    t,
    forget_rate,
    class_weights=None,
):
forget_rate_scheduler
def forget_rate_scheduler(epochs, forget_rate, num_gradual, exponent):
initialize_lr_scheduler
def initialize_lr_scheduler(lr=0.001, epochs=250, epoch_decay_start=80):
Import
from cleanlab.experimental.coteaching import train, loss_coteaching, forget_rate_scheduler
I/O Contract
Inputs (train)
| Name | Type | Required | Description |
|---|---|---|---|
| train_loader | torch.utils.data.DataLoader | Yes | DataLoader providing (images, labels) batches |
| epoch | int | Yes | Current epoch number (0-indexed) |
| model1 | nn.Module | Yes | First PyTorch model with forward(self, x) method |
| optimizer1 | torch.optim.Adam | Yes | Optimizer for model1 |
| model2 | nn.Module | Yes | Second PyTorch model with forward(self, x) method |
| optimizer2 | torch.optim.Adam | Yes | Optimizer for model2 |
| args | argparse.Namespace | Yes | Must contain num_iter_per_epoch, print_freq, epochs, batch_size |
| forget_rate_schedule | np.ndarray | Yes | Array of forget rates per epoch from forget_rate_scheduler |
| class_weights | torch.Tensor or None | Yes | Class weights tensor of shape (num_classes,) or None |
| accuracy | callable | Yes | Function of form accuracy(output, target, topk=(1,)) |
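As a rough illustration of the args input, the sketch below builds a namespace carrying the four attributes named in the table. Every value is a hypothetical placeholder chosen for illustration, not a recommended setting, and real scripts would usually populate the namespace through an argparse parser.

```python
from argparse import Namespace

# Attribute names come from the inputs table; all values are hypothetical.
args = Namespace(
    num_iter_per_epoch=400,
    print_freq=50,
    epochs=250,
    batch_size=128,
)

# Quick sanity check that nothing the table lists is missing.
required = ("num_iter_per_epoch", "print_freq", "epochs", "batch_size")
missing = [name for name in required if not hasattr(args, name)]
```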
Outputs (train)
| Name | Type | Description |
|---|---|---|
| train_acc1 | float | Training accuracy for model1 |
| train_acc2 | float | Training accuracy for model2 |
Inputs (loss_coteaching)
| Name | Type | Required | Description |
|---|---|---|---|
| y_1 | torch.Tensor | Yes | Output logits from model 1 |
| y_2 | torch.Tensor | Yes | Output logits from model 2 |
| t | torch.Tensor | Yes | Noisy target labels |
| forget_rate | float | Yes | Fraction of examples to forget (0 to 1), typically rate_schedule[epoch] |
| class_weights | torch.Tensor or None | No | Optional class weights of shape (num_classes,) |
Outputs (loss_coteaching)
| Name | Type | Description |
|---|---|---|
| loss_1 | torch.Tensor | Normalized loss for model 1 (trained on model 2's selected samples) |
| loss_2 | torch.Tensor | Normalized loss for model 2 (trained on model 1's selected samples) |
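The cross-selection behind these two outputs can be illustrated without PyTorch. The sketch below is an assumption-laden simplification, not the library's code: it takes precomputed per-example losses as plain lists, and the helper names cross_select and mean_loss are hypothetical. Each model ranks the batch by its own loss, keeps the lowest-loss fraction (1 - forget_rate), and hands that selection to the other model.

```python
def cross_select(losses_1, losses_2, forget_rate):
    """Return (indices model 1 trains on, indices model 2 trains on)."""
    if len(losses_1) != len(losses_2):
        raise ValueError("both models must score the same batch")
    n = len(losses_1)
    num_remember = int((1.0 - forget_rate) * n)
    # Each model ranks examples by its own loss, lowest first.
    kept_by_1 = sorted(range(n), key=lambda i: losses_1[i])[:num_remember]
    kept_by_2 = sorted(range(n), key=lambda i: losses_2[i])[:num_remember]
    # Swap: model 1 learns from model 2's picks and vice versa.
    return kept_by_2, kept_by_1


def mean_loss(losses, indices):
    """Average loss over the selected examples."""
    return sum(losses[i] for i in indices) / len(indices)


losses_1 = [0.1, 2.5, 0.3, 0.2]  # model 1's per-example losses (made up)
losses_2 = [0.2, 0.1, 3.0, 0.4]  # model 2's per-example losses (made up)
idx_for_1, idx_for_2 = cross_select(losses_1, losses_2, forget_rate=0.25)
loss_1 = mean_loss(losses_1, idx_for_1)
loss_2 = mean_loss(losses_2, idx_for_2)
```

In this toy batch each model drops the one example it finds hardest, yet still trains on the example the other model dropped, which is how the two networks avoid reinforcing the same mistakes.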
Implementation Details
Forget Rate Scheduling
The forget rate starts at 0 and increases linearly over the first num_gradual epochs, following the formula:
forget_rate_schedule[:num_gradual] = linspace(0, forget_rate^exponent, num_gradual)
After this warmup period, the forget rate remains constant at the target rate forget_rate. The ramp ends exactly at the target rate when exponent is 1.
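The schedule described above can be sketched in plain Python. This is a minimal stand-in written from the formula, not the library's function: the names linspace and sketch_forget_rate_schedule are hypothetical, and the real module returns a NumPy array rather than a list.

```python
def linspace(start, stop, num):
    """Evenly spaced values from start to stop inclusive (pure-Python stand-in)."""
    if num == 1:
        return [float(start)]
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def sketch_forget_rate_schedule(epochs, forget_rate, num_gradual, exponent):
    """One forget rate per epoch: linear warmup, then constant."""
    if num_gradual > epochs:
        raise ValueError("num_gradual must not exceed epochs")
    schedule = [forget_rate] * epochs
    # Overwrite the warmup epochs with the linear ramp from the formula.
    schedule[:num_gradual] = linspace(0.0, forget_rate ** exponent, num_gradual)
    return schedule


schedule = sketch_forget_rate_schedule(
    epochs=6, forget_rate=0.2, num_gradual=3, exponent=1
)
```

With these toy settings the first three epochs ramp through 0, 0.1, and 0.2, and the remaining epochs stay at 0.2.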
Small Batch Handling
The training function skips the final batch if its size is below MINIMUM_BATCH_SIZE (16). This prevents noise amplification from gradient updates based on very few examples, which is especially important in noisy label settings.
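The skip rule can be pictured with a small generator. This is only a sketch of the idea with hypothetical names (iter_trainable_batches and the toy data); the threshold of 16 is the value stated above, and the real training loop operates on DataLoader batches rather than lists.

```python
MINIMUM_BATCH_SIZE = 16  # threshold stated in the text


def iter_trainable_batches(batches, minimum_size=MINIMUM_BATCH_SIZE):
    """Yield only batches large enough to produce a trustworthy update."""
    for batch in batches:
        if len(batch) < minimum_size:
            continue  # too few examples: skip instead of updating on them
        yield batch


# 40 toy examples in batches of 16 leave a final batch of only 8.
examples = list(range(40))
batch_size = 16
batches = [examples[i:i + batch_size] for i in range(0, len(examples), batch_size)]
kept_sizes = [len(b) for b in iter_trainable_batches(batches)]
```

Here the two full batches are kept and the trailing batch of 8 is dropped.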
Learning Rate Scheduling
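Uses linear learning rate decay starting from epoch_decay_start (default 80). The scheduler also produces a beta1 plan for the Adam optimizer that switches from 0.9 to 0.9 at the same epoch, so with these values beta1 is effectively unchanged.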
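A plain-Python sketch of such a schedule follows. It assumes the learning rate decays linearly toward zero at the final epoch, which the text does not spell out, and the function name sketch_lr_schedule and its beta1_start and beta1_end parameters are hypothetical. The two lists mirror the (alpha_plan, beta1_plan) pair that initialize_lr_scheduler returns in the usage example.

```python
def sketch_lr_schedule(lr=0.001, epochs=250, epoch_decay_start=80,
                       beta1_start=0.9, beta1_end=0.9):
    """Return per-epoch learning rates and Adam beta1 values."""
    alpha_plan = [lr] * epochs
    beta1_plan = [beta1_start] * epochs
    for epoch in range(epoch_decay_start, epochs):
        # Assumed form: fraction of the decay window still remaining.
        remaining = (epochs - epoch) / (epochs - epoch_decay_start)
        alpha_plan[epoch] = remaining * lr
        beta1_plan[epoch] = beta1_end
    return alpha_plan, beta1_plan


# Small numbers to make the shape easy to read.
alpha_plan, beta1_plan = sketch_lr_schedule(lr=0.1, epochs=10, epoch_decay_start=6)
```

In this toy run the rate holds at 0.1 through epoch 6 and then falls by equal steps through 0.075, 0.05, and 0.025.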
Usage Examples
Basic Usage
from cleanlab.experimental.coteaching import (
    train, evaluate, forget_rate_scheduler,
    initialize_lr_scheduler, adjust_learning_rate,
)
from cleanlab.experimental.cifar_cnn import CNN
import torch

# train_loader, test_loader, args, and accuracy_fn are assumed to be defined
# by the caller (see the I/O Contract above for their expected form).

# Create two models
model1 = CNN(input_channel=3, n_outputs=10).cuda()
model2 = CNN(input_channel=3, n_outputs=10).cuda()

# Set up optimizers
optimizer1 = torch.optim.Adam(model1.parameters(), lr=0.001)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=0.001)

# Create forget rate schedule
rate_schedule = forget_rate_scheduler(
    epochs=250, forget_rate=0.2, num_gradual=10, exponent=1
)

# Create learning rate schedule
alpha_plan, beta1_plan = initialize_lr_scheduler(
    lr=0.001, epochs=250, epoch_decay_start=80
)

# Training loop
for epoch in range(250):
    # Apply the scheduled learning rate and beta1 to both optimizers
    adjust_learning_rate(optimizer1, epoch, alpha_plan, beta1_plan)
    adjust_learning_rate(optimizer2, epoch, alpha_plan, beta1_plan)
    train_acc1, train_acc2 = train(
        train_loader, epoch, model1, optimizer1,
        model2, optimizer2, args, rate_schedule,
        class_weights=None, accuracy=accuracy_fn,
    )

# Evaluate
acc1, acc2 = evaluate(test_loader, model1, model2)