Principle:Microsoft DeepSpeedExamples Classification Evaluation

From Leeroopedia


Metadata

Field | Value
Page Type | Principle
Repository | Microsoft/DeepSpeedExamples
Title | Classification_Evaluation
Domains | Deep_Learning, Evaluation, Computer_Vision
Related Implementation | Implementation:Microsoft_DeepSpeedExamples_Test_Function_CIFAR

Overview

A standard evaluation methodology that measures overall and per-class accuracy on held-out test data for image classification models.

Description

After training a classification model, it must be evaluated on a held-out test set to measure generalization performance. The evaluation methodology in the CIFAR-10 workflow computes two complementary metrics:

  1. Overall Accuracy -- The fraction of correctly classified samples across the entire test set. This gives a single summary metric of model quality.
  2. Per-Class Accuracy -- The fraction of correctly classified samples within each class individually. This identifies class-specific weaknesses that overall accuracy may mask.
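Both metrics can be computed directly from a tensor of predicted class indices and a tensor of true labels. The sketch below is a vectorized alternative to a per-sample Python loop. It assumes 1-D integer tensors, and the function name `accuracy_metrics` is hypothetical, not part of the example repository.

```python
import torch


def accuracy_metrics(predicted: torch.Tensor, labels: torch.Tensor, num_classes: int = 10):
    """Return (overall_accuracy, per_class_accuracy) for 1-D integer tensors."""
    matches = predicted == labels
    overall = matches.float().mean().item()
    # Number of test samples per true class.
    class_total = torch.bincount(labels, minlength=num_classes)
    # Number of correctly classified samples per true class.
    class_correct = torch.bincount(labels[matches], minlength=num_classes)
    # A class with no samples yields NaN (0/0) instead of raising an error.
    per_class = class_correct.float() / class_total.float()
    return overall, per_class.tolist()


overall, per_class = accuracy_metrics(torch.tensor([0, 1, 1, 2]),
                                      torch.tensor([0, 1, 2, 2]), num_classes=3)
print(overall, per_class)
```

Here three of four predictions are correct overall, but class 2 is right only half the time. That is exactly the kind of weakness the per-class view surfaces.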

In the context of distributed training with DeepSpeed, evaluation has additional considerations:

  • Mixed Precision Handling -- Test inputs must be cast to the same dtype (fp16/bf16) used during training so that they match the dtype of the model's weights.
  • Distributed Execution -- All ranks execute the evaluation loop, but only rank 0 reports results to avoid duplicate output.
  • Model Mode -- The model must be set to evaluation mode (model.eval()) to disable training-specific behaviors such as dropout and to make batch normalization use its running statistics.
  • Gradient Disabling -- torch.no_grad() disables gradient computation during evaluation, saving memory and compute.
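The effect of evaluation mode can be seen on a small module that contains dropout. This is a minimal sketch that uses a hypothetical toy model, not the CIFAR-10 Net. In evaluation mode dropout becomes the identity, so repeated forward passes on the same input agree.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical toy model with a dropout layer (a training-specific behavior).
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(1, 4)

model.eval()  # recursively sets .training = False on every submodule
with torch.no_grad():
    first, second = model(x), model(x)

# With dropout disabled, repeated forward passes give identical outputs.
print(torch.equal(first, second))  # True
print(model[1].training)           # False
```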

Theoretical Basis

Overall Accuracy

Accuracy = correct_predictions / total_samples

For CIFAR-10, whose test set contains 10,000 images (1,000 per class), random guessing would yield about 10% accuracy. A simple baseline CNN typically reaches 50-60% accuracy after 2 epochs, or 70-80% after 30 epochs.
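Because the test set is balanced, a degenerate classifier that always predicts the same class lands exactly at chance level overall. The sketch below uses synthetic labels mirroring the 1,000-per-class split and a hypothetical "always predict class 0" model. It shows that chance-level overall accuracy can hide a completely lopsided per-class picture.

```python
# Synthetic labels for a balanced 10-class test set: 1,000 samples per class.
labels = [c for c in range(10) for _ in range(1000)]

# Hypothetical degenerate "model" that always predicts class 0.
predictions = [0] * len(labels)

correct = sum(p == y for p, y in zip(predictions, labels))
overall = correct / len(labels)

per_class = []
for c in range(10):
    idx = [i for i, y in enumerate(labels) if y == c]
    per_class.append(sum(predictions[i] == c for i in idx) / len(idx))

print(f"overall: {overall:.0%}")  # overall: 10%
print(per_class)  # [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```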

Per-Class Accuracy

Accuracy_class_i = correct_predictions_class_i / total_samples_class_i

Per-class accuracy reveals important patterns:

  • Class difficulty differences -- Some classes are inherently harder to distinguish (e.g., "cat" vs "dog" vs "deer" in CIFAR-10), even though the CIFAR-10 test set itself is balanced at 1,000 images per class
  • Model biases -- The model may learn to favor certain classes over others
  • Confusion patterns -- Classes with low accuracy suggest the model confuses them with visually similar classes
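Confusion patterns are easiest to read from a confusion matrix, where row i counts samples of true class i and column j counts predictions of class j. The sketch below is an illustration under that convention. The helper names `confusion_matrix` and `most_confused_pair` are hypothetical and do not come from the example code. Per-class accuracy is the diagonal divided by the row sums.

```python
import torch


def confusion_matrix(predicted: torch.Tensor, labels: torch.Tensor, num_classes: int = 10) -> torch.Tensor:
    """Rows are true classes, columns are predicted classes."""
    # Encode each (true, predicted) pair as one flat index, then count occurrences.
    flat = labels * num_classes + predicted
    counts = torch.bincount(flat, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)


def most_confused_pair(cm: torch.Tensor):
    """Return (true_class, predicted_class, count) for the largest off-diagonal entry."""
    off_diag = cm.clone()
    off_diag.fill_diagonal_(0)  # ignore correct predictions
    idx = int(torch.argmax(off_diag))  # flat index into the matrix
    n = cm.size(0)
    return idx // n, idx % n, int(off_diag.view(-1)[idx])


cm = confusion_matrix(torch.tensor([0, 1, 1, 1, 1, 2]),
                      torch.tensor([0, 0, 0, 1, 1, 2]), num_classes=3)
per_class = (cm.diag().float() / cm.sum(dim=1).float()).tolist()
print(cm.tolist(), most_confused_pair(cm), per_class)
```

In this toy case class 0 is mistaken for class 1 twice. A similar reading of a real CIFAR-10 matrix would show whether, for example, cat and dog absorb each other's errors.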

Evaluation in Distributed Settings

In distributed training, each rank holds a copy of the model (or a shard in ZeRO Stage 3). During evaluation:

  • Each rank runs inference on the full test set (test data is not distributed)
  • The test DataLoader does not use a DistributedSampler -- all ranks see the same data
  • Only rank 0 prints results to avoid N-fold duplicate output
  • This differs from training, where data is distributed across ranks
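The two distributed-evaluation rules above (a plain, unsharded test loader and rank-0-only reporting) can be sketched with standard PyTorch utilities. The helpers `is_reporting_rank` and `report` are hypothetical, and the random tensors stand in for the real test set. The rank check falls back to "report" when torch.distributed is not initialized, so the sketch also runs as a single process.

```python
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset


def is_reporting_rank() -> bool:
    """True on global rank 0, or in a single process without torch.distributed."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank() == 0
    return True


def report(message: str) -> None:
    # Print once instead of once per rank.
    if is_reporting_rank():
        print(message)


# Hypothetical stand-in for the test set: 6 random "images" with labels 0..5.
test_set = TensorDataset(torch.randn(6, 3, 32, 32), torch.arange(6))
# No DistributedSampler: every rank iterates over all samples in the same order.
testloader = DataLoader(test_set, batch_size=4, shuffle=False)

seen = sum(labels.size(0) for _, labels in testloader)
report(f"samples evaluated on this rank: {seen}")
```

The DeepSpeed example filters on `model_engine.local_rank == 0`. In a multi-node job, that condition holds on one process per node, so a global-rank check like the one in this sketch is one way to get exactly one copy of the output.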

torch.no_grad() Context

The torch.no_grad() context manager serves two purposes during evaluation:

  1. Memory efficiency -- Disables gradient tracking, so the intermediate activations that backpropagation would need are not retained, lowering peak memory
  2. Computational efficiency -- Skips the construction of the autograd computation graph, reducing overhead

This suits evaluation well because:

  • No parameter updates occur during evaluation (no backward() or step())
  • The saved memory allows larger batch sizes or simply reduces peak memory usage
  • The computational savings can be significant for large models
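The difference is visible on the output tensors themselves. In this minimal sketch with a hypothetical toy layer, a forward pass outside torch.no_grad() produces a tensor attached to the autograd graph (it has a `grad_fn`). The same pass inside the context produces a plain tensor with identical values.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # hypothetical toy model; its parameters require gradients
x = torch.randn(3, 4)

out_train = model(x)
# Outside no_grad, autograd records the operation so backward() could run later.
print(out_train.requires_grad, out_train.grad_fn is not None)  # True True

with torch.no_grad():
    out_eval = model(x)
# Inside no_grad, no graph is built and nothing is kept for a backward pass.
print(out_eval.requires_grad, out_eval.grad_fn)  # False None

# The forward values themselves are the same.
print(torch.allclose(out_train, out_eval))  # True
```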

Mixed Precision During Evaluation

When the model was trained with mixed precision (fp16 or bf16), evaluation inputs must be cast to the matching dtype:

Training with fp16:
    model weights are in fp16 (any fp32 master copies are held by the optimizer, not the module)
    evaluation inputs must be cast to fp16 before forward pass

Training with bf16:
    model weights are in bf16
    evaluation inputs must be cast to bf16 before forward pass

Training with fp32:
    no conversion needed (target_dtype = None)

Failure to match dtypes typically raises a dtype-mismatch RuntimeError in the first layer's forward pass.
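The three cases reduce to choosing a target dtype and casting the inputs before the forward pass. This sketch assumes two hypothetical boolean flags (`fp16`, `bf16`) standing in for however the training configuration exposes precision. It demonstrates the mismatch on a toy bf16 layer; bf16 is used rather than fp16 because half-precision matmul is not available on CPU in some PyTorch versions.

```python
from typing import Optional

import torch
import torch.nn as nn


def select_target_dtype(fp16: bool, bf16: bool) -> Optional[torch.dtype]:
    """Map (hypothetical) precision flags to the dtype test inputs are cast to."""
    if fp16:
        return torch.float16
    if bf16:
        return torch.bfloat16
    return None  # fp32 training: no conversion needed


# Toy stand-in for a bf16-trained model: its weights are bf16.
model = nn.Linear(4, 2).to(torch.bfloat16)
images = torch.randn(3, 4)  # inputs arrive as fp32, as from a typical DataLoader

mismatch_raised = False
try:
    with torch.no_grad():
        model(images)  # fp32 input against bf16 weights
except RuntimeError as err:
    mismatch_raised = True
    print("dtype mismatch:", err)

target_dtype = select_target_dtype(fp16=False, bf16=True)
if target_dtype is not None:
    images = images.to(target_dtype)
with torch.no_grad():
    outputs = model(images)
print(outputs.dtype)  # torch.bfloat16
```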

Evaluation Metrics for CIFAR-10

The 10 CIFAR-10 classes evaluated are:

Index | Class Name | Typical Difficulty
0 | plane | Medium
1 | car | Easy (distinctive shape)
2 | bird | Hard (varied poses)
3 | cat | Hard (similar to dog/deer)
4 | deer | Medium
5 | dog | Hard (similar to cat)
6 | frog | Easy (distinctive color)
7 | horse | Medium
8 | ship | Easy (distinctive context)
9 | truck | Medium (similar to car)
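The index-to-name mapping in the table is what turns per-class counters into a readable report. The sketch below assumes per-class correct and total counts are already accumulated in index order. The helper `rank_by_accuracy` is hypothetical; it sorts classes from hardest to easiest so weak classes stand out, and it skips classes with no samples.

```python
from typing import List, Sequence, Tuple

# Class names in index order, matching the table above.
classes = ("plane", "car", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck")


def rank_by_accuracy(class_correct: Sequence[int],
                     class_total: Sequence[int]) -> List[Tuple[str, float]]:
    """Return (class_name, accuracy %) pairs, hardest class first."""
    results = []
    for name, correct, total in zip(classes, class_correct, class_total):
        if total == 0:
            continue  # no test samples for this class; accuracy is undefined
        results.append((name, 100.0 * correct / total))
    return sorted(results, key=lambda item: item[1])
```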

Evaluation Pattern

The standard evaluation pattern used in both the baseline and DeepSpeed versions:

import torch

# Assumes model, testloader, device, classes and target_dtype are defined
# earlier in the script. target_dtype is torch.float16, torch.bfloat16 or None;
# the baseline script performs no cast, which is equivalent to None.

# Set model to evaluation mode
model.eval()

correct, total = 0, 0
class_correct = [0 for _ in range(10)]
class_total = [0 for _ in range(10)]

with torch.no_grad():
    for data in testloader:
        images, labels = data
        # Cast to target dtype if using mixed precision
        if target_dtype is not None:
            images = images.to(target_dtype)
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)

        # Overall accuracy
        matches = predicted == labels
        total += labels.size(0)
        correct += matches.sum().item()

        # Per-class accuracy: iterate over the actual batch (the last batch may
        # be smaller than the configured batch size), and avoid .squeeze(),
        # which turns a size-1 batch into a 0-d tensor that cannot be indexed
        for label, match in zip(labels.tolist(), matches.tolist()):
            class_correct[label] += int(match)
            class_total[label] += 1

# Report (only on rank 0 in distributed setting)
print(f"Overall accuracy: {100 * correct / total:.0f}%")
for i in range(10):
    print(f"Accuracy of {classes[i]}: {100 * class_correct[i] / class_total[i]:.0f}%")

Baseline vs DeepSpeed Evaluation Differences

Aspect | Baseline (cifar10_tutorial.py) | DeepSpeed (cifar10_deepspeed.py)
Model wrapper | Raw Net instance | DeepSpeedEngine wrapping Net
eval() call | net.eval() (implicit in script flow) | model_engine.eval()
Dtype handling | None (fp32 only) | Cast to target_dtype (fp16/bf16/None)
Device handling | Manual .to(device) | .to(local_device) from engine
Rank filtering | Not applicable (single process) | if model_engine.local_rank == 0
Test function | Inline code | Extracted test() function
