Implementation:Microsoft DeepSpeedExamples Test Function CIFAR
Metadata
| Field | Value |
|---|---|
| Page Type | Implementation |
| Repository | Microsoft/DeepSpeedExamples |
| Title | Test_Function_CIFAR |
| Type | Function Doc |
| Source File | training/cifar/cifar10_deepspeed.py |
| Lines | 212-277 |
| Implements | Principle:Microsoft_DeepSpeedExamples_Classification_Evaluation |
Overview
Concrete tool for evaluating CIFAR-10 classification accuracy with DeepSpeed mixed-precision support.
Description
The test() function in cifar10_deepspeed.py evaluates a trained DeepSpeed model on the CIFAR-10 test set. It computes both overall accuracy and per-class accuracy for all 10 CIFAR-10 classes, handling mixed-precision dtype casting and distributed rank-aware reporting.
The function performs the following steps:
- Defines the CIFAR-10 class names for human-readable reporting
- Creates a test DataLoader from the provided test dataset (non-distributed, no shuffling)
- Sets the model engine to evaluation mode via model_engine.eval()
- Iterates over the test set under torch.no_grad(), computing predictions
- Casts input images to the target dtype (fp16/bf16) if mixed precision was used during training
- Accumulates correct/total counts for both overall and per-class accuracy
- Reports results only on local rank 0 to avoid duplicate output in distributed settings
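The steps above can be condensed into a standalone function. This is a minimal sketch, not the repository code: it assumes a plain torch.nn.Module stands in for the DeepSpeedEngine, uses the hypothetical name evaluate_accuracy, and leaves out the per-class counts and the rank check.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def evaluate_accuracy(model, testset, device, target_dtype=None, batch_size=4):
    """Hypothetical helper: return (correct, total) over testset.

    Mirrors the steps of test(): non-shuffled DataLoader, eval mode,
    no_grad, optional dtype cast, then argmax over the class dimension.
    """
    loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=0)
    correct, total = 0, 0
    model.eval()  # dropout off, batch norm uses running statistics
    with torch.no_grad():  # no autograd graph is recorded during evaluation
        for images, labels in loader:
            if target_dtype is not None:
                images = images.to(target_dtype)  # match the training precision
            outputs = model(images.to(device))
            _, predicted = torch.max(outputs, 1)  # class index with the largest logit
            total += labels.size(0)
            correct += (predicted == labels.to(device)).sum().item()
    return correct, total


# Stand-in model: nn.Identity treats each "image" as a row of logits.
logits = torch.eye(10)[[0, 1, 2, 3, 4, 5]]  # one-hot rows predict classes 0..5
labels = torch.tensor([0, 1, 2, 3, 9, 9])   # the last two predictions are wrong
dataset = TensorDataset(logits, labels)
correct, total = evaluate_accuracy(torch.nn.Identity(), dataset, "cpu")
print(f"{correct}/{total}")  # 4/6
```

Because this sketch never indexes by a fixed batch-size range, any batch_size works here, including one that leaves a short final batch.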
This function is called at the end of training in main() at line 397:
test(model_engine, testset, local_device, target_dtype)
Code Reference
File: training/cifar/cifar10_deepspeed.py, Lines 212-277
```python
def test(model_engine, testset, local_device, target_dtype, test_batch_size=4):
    """Test the network on the test data.

    Args:
        model_engine (deepspeed.runtime.engine.DeepSpeedEngine): the DeepSpeed engine.
        testset (torch.utils.data.Dataset): the test dataset.
        local_device (str): the local device name.
        target_dtype (torch.dtype): the target datatype for the test data.
        test_batch_size (int): the test batch size.
    """
    # The 10 classes for CIFAR10.
    classes = (
        "plane",
        "car",
        "bird",
        "cat",
        "deer",
        "dog",
        "frog",
        "horse",
        "ship",
        "truck",
    )

    # Define the test dataloader.
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=test_batch_size, shuffle=False, num_workers=0
    )

    # For total accuracy.
    correct, total = 0, 0
    # For accuracy per class.
    class_correct = list(0.0 for i in range(10))
    class_total = list(0.0 for i in range(10))

    # Start testing.
    model_engine.eval()
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            if target_dtype != None:
                images = images.to(target_dtype)
            outputs = model_engine(images.to(local_device))
            _, predicted = torch.max(outputs.data, 1)
            # Count the total accuracy.
            total += labels.size(0)
            correct += (predicted == labels.to(local_device)).sum().item()

            # Count the accuracy per class.
            batch_correct = (predicted == labels.to(local_device)).squeeze()
            for i in range(test_batch_size):
                label = labels[i]
                class_correct[label] += batch_correct[i].item()
                class_total[label] += 1

    if model_engine.local_rank == 0:
        print(
            f"Accuracy of the network on the {total} test images: {100 * correct / total : .0f} %"
        )

        # For all classes, print the accuracy.
        for i in range(10):
            print(
                f"Accuracy of {classes[i] : >5s} : {100 * class_correct[i] / class_total[i] : 2.0f} %"
            )
```
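The per-class loop above reads labels[i] and batch_correct[i] for every i in range(test_batch_size), so it depends on every batch being full. One way to make per-class counting independent of the batch size, sketched here as an alternative rather than what the repository does, is to count with torch.bincount over the labels actually present in each batch. The helper name batch_class_counts is hypothetical.

```python
import torch


def batch_class_counts(predicted, labels, num_classes=10):
    """Hypothetical helper: per-class (correct, total) counts for one batch.

    Nothing is indexed by a fixed batch-size range, so a short final batch
    or a batch of size 1 is handled the same way as a full batch.
    """
    correct_labels = labels[predicted == labels]  # labels of correctly classified samples
    correct = torch.bincount(correct_labels, minlength=num_classes)
    total = torch.bincount(labels, minlength=num_classes)
    return correct, total


class_correct = torch.zeros(10, dtype=torch.long)
class_total = torch.zeros(10, dtype=torch.long)

# A full batch of 4 followed by a short final batch of 1.
for predicted, labels in [
    (torch.tensor([3, 5, 1, 3]), torch.tensor([3, 5, 2, 8])),
    (torch.tensor([2]), torch.tensor([2])),
]:
    c, t = batch_class_counts(predicted, labels)
    class_correct += c
    class_total += t

print(class_correct.tolist())  # [0, 0, 1, 1, 0, 1, 0, 0, 0, 0]
print(class_total.tolist())    # [0, 0, 2, 1, 0, 1, 0, 0, 1, 0]
```

Keeping the accumulators as int64 tensors avoids the float lists of the original; per-class accuracy is still class_correct[i] / class_total[i], which is undefined for a class that never appears.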
Signature
```python
def test(
    model_engine: deepspeed.runtime.engine.DeepSpeedEngine,
    testset: torch.utils.data.Dataset,
    local_device: str,
    target_dtype: Optional[torch.dtype],
    test_batch_size: int = 4,
) -> None:
    """Evaluate the CIFAR-10 model and print accuracy metrics.

    Computes overall and per-class accuracy on the test set.
    Prints results only on local rank 0 in distributed settings.
    """
```
I/O Contract
Inputs
| Parameter | Type | Default | Description |
|---|---|---|---|
| model_engine | DeepSpeedEngine | (required) | The trained DeepSpeed engine wrapping the Net model. Must support .eval(), .local_rank, and a forward pass via __call__. |
| testset | torch.utils.data.Dataset | (required) | CIFAR-10 test dataset (10,000 images), loaded via torchvision.datasets.CIFAR10(train=False). |
| local_device | str | (required) | Device name string (e.g., "cuda:0", "xpu:0") used to place test data on the correct device. |
| target_dtype | torch.dtype or None | (required) | Data type for mixed precision: torch.half (fp16), torch.bfloat16 (bf16), or None (fp32, no conversion). |
| test_batch_size | int | 4 | Number of images per test batch. Must be greater than 1 and evenly divide the test set size; otherwise the per-class counting loop raises an IndexError. |
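target_dtype follows the mixed-precision mode used for training. A minimal sketch of that mapping, assuming the standard DeepSpeed config sections "bf16" and "fp16", each with an "enabled" flag; the helper target_dtype_from_config is hypothetical, and cifar10_deepspeed.py may derive the value from the engine instead of the raw config dict.

```python
import torch


def target_dtype_from_config(ds_config):
    """Hypothetical helper: map a DeepSpeed config dict to the test-time dtype.

    Returns torch.bfloat16 if bf16 is enabled, torch.half if fp16 is enabled,
    and None (keep fp32 inputs) otherwise. bf16 is checked first only as a
    choice made in this sketch.
    """
    if ds_config.get("bf16", {}).get("enabled", False):
        return torch.bfloat16
    if ds_config.get("fp16", {}).get("enabled", False):
        return torch.half
    return None


print(target_dtype_from_config({"bf16": {"enabled": True}}))  # torch.bfloat16
print(target_dtype_from_config({"fp16": {"enabled": True}}))  # torch.float16
print(target_dtype_from_config({"train_batch_size": 16}))     # None
```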
Outputs
| Output | Type | Description |
|---|---|---|
| Return value | None | The function has no return value. |
| Printed (rank 0 only) | stdout | Overall accuracy: "Accuracy of the network on the N test images: XX %" |
| Printed (rank 0 only) | stdout | Per-class accuracy: "Accuracy of <class> : XX %" for each of 10 classes |
Internal Data Flow
```
testset
   |
   v
DataLoader(batch_size=4, shuffle=False, num_workers=0)
   |
   v
for (images, labels) in testloader:
   |
   +-- images.to(target_dtype)        [if target_dtype is not None]
   |
   +-- images.to(local_device)        [move to accelerator]
   |
   +-- model_engine(images)           [forward pass]
   |       |
   |       v
   |   outputs: (B, 10)               [raw logits]
   |       |
   |       v
   +-- torch.max(outputs, 1)          [argmax for predictions]
   |       |
   |       v
   +-- predicted == labels            [comparison for accuracy]
   |       |
   +-------+-- correct += sum         [overall accumulator]
   |       |
   +-------+-- class_correct[label]   [per-class accumulator]
   |
   v
if local_rank == 0:
    print(overall accuracy)
    print(per-class accuracy for all 10 classes)
```
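In this flow only the inputs and logits carry the mixed-precision dtype. torch.max along dimension 1 returns int64 indices whatever the logits' dtype, so predicted can be compared directly with the int64 labels produced by the DataLoader. A small check, using random stand-in logits rather than real model output:

```python
import torch

torch.manual_seed(0)
labels = torch.tensor([3, 8, 8, 0])                # int64, like DataLoader-collated labels
outputs = torch.randn(4, 10).to(torch.bfloat16)    # stand-in for (B, 10) bf16 logits

values, predicted = torch.max(outputs, 1)          # max logit and its class index per row
print(values.dtype, predicted.dtype)               # logits stay bf16, indices are int64
matches = predicted == labels                      # elementwise bool, shape (B,)
print(matches.shape, matches.dtype)
```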
Test DataLoader Configuration
| Parameter | Value | Rationale |
|---|---|---|
| batch_size | 4 (default) | Default test_batch_size; the per-class counting loop requires it to divide the 10,000 test images evenly |
| shuffle | False | Deterministic evaluation order |
| num_workers | 0 | Single-process data loading (avoids multiprocessing issues) |
Note that unlike the training DataLoader (created by deepspeed.initialize() with DistributedSampler), the test DataLoader is created manually without distributed sampling. Every rank evaluates the full test set.
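Because every rank evaluates the full test set, the evaluation work is duplicated rather than split. If that matters, the test set could be sharded instead. This is not what cifar10_deepspeed.py does; it is a sketch assuming torch.utils.data.distributed.DistributedSampler with explicit num_replicas and rank, which lets the shards be inspected without a process group. Each rank's correct/total counts would then need to be summed (for example with torch.distributed.all_reduce) before printing. With the default drop_last=False, the sampler pads the index list by repeating samples until it divides evenly, so a few images can be counted twice.

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(10))  # stand-in for the 10,000-image test set

shards = []
for rank in range(3):
    # Explicit num_replicas/rank: no torch.distributed initialization needed to inspect shards.
    sampler = DistributedSampler(dataset, num_replicas=3, rank=rank, shuffle=False)
    shards.append(list(sampler))

print(shards)  # [[0, 3, 6, 9], [1, 4, 7, 0], [2, 5, 8, 1]]; indices 0 and 1 appear twice
```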
Usage Example
```python
# Called at the end of main() after training completes:
def main(args):
    # ... training code ...

    ########################################################################
    # Step 4. Test the network on the test data.
    ########################################################################
    test(model_engine, testset, local_device, target_dtype)


# With a custom test batch size (it must evenly divide the 10,000 test images):
test(model_engine, testset, local_device, target_dtype, test_batch_size=100)
```
Expected Output Format
```
Accuracy of the network on the 10000 test images:  53 %
Accuracy of plane :  63 %
Accuracy of   car :  70 %
Accuracy of  bird :  34 %
Accuracy of   cat :  36 %
Accuracy of  deer :  42 %
Accuracy of   dog :  52 %
Accuracy of  frog :  68 %
Accuracy of horse :  51 %
Accuracy of  ship :  72 %
Accuracy of truck :  47 %
```
Note: Actual values vary based on training configuration (epochs, dtype, ZeRO stage, MoE settings).
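The exact spacing of the expected output comes from the format specifications in the print calls: " >5s" right-aligns each class name in a 5-character field, and " .0f" / " 2.0f" round to a whole percent with a leading space in the sign position. The sketch below builds the same lines as strings instead of printing them; format_report is a hypothetical name and the counts are made-up inputs.

```python
CLASSES = ("plane", "car", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck")


def format_report(correct, total, class_correct, class_total):
    """Hypothetical helper: build the report lines that test() prints on local rank 0."""
    lines = [f"Accuracy of the network on the {total} test images: {100 * correct / total : .0f} %"]
    for i in range(10):
        lines.append(
            f"Accuracy of {CLASSES[i] : >5s} : {100 * class_correct[i] / class_total[i] : 2.0f} %"
        )
    return lines


# Made-up counts: 1,000 test images per class.
report = format_report(5300, 10000, [630, 700] + [500] * 8, [1000] * 10)
print(report[0])  # Accuracy of the network on the 10000 test images:  53 %
print(report[2])  # Accuracy of   car :  70 %
```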
Important Notes
- test_batch_size assumption: The per-class counting loop iterates for i in range(test_batch_size), which assumes every batch contains exactly test_batch_size samples. If the test set size is not evenly divisible by test_batch_size, the last batch is smaller and labels[i] raises an IndexError. The default test_batch_size=4 evenly divides the 10,000 CIFAR-10 test images (10000 / 4 = 2500 batches). A batch size of 1 also fails, because squeeze() reduces the comparison result to a 0-dimensional tensor that cannot be indexed.
- Rank-aware reporting: The if model_engine.local_rank == 0 check prints the results once per node. local_rank is the rank within a node, so a multi-node job prints once on each node. All ranks still execute the full evaluation loop.
- eval() mode: model_engine.eval() disables dropout and makes batch normalization use its running statistics. The baseline Net uses neither, but calling eval() is still best practice, and it is required for MoE layers, whose gating behaves differently in training and evaluation mode.
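Whether a given test_batch_size is safe for the original per-class loop can be checked with integer arithmetic: it must be greater than 1 and leave no remainder when dividing the test set size. A small sketch; is_safe_test_batch_size is a hypothetical name.

```python
def is_safe_test_batch_size(num_samples, batch_size):
    """Hypothetical check: True if every batch is full and holds more than one sample.

    The loop `for i in range(test_batch_size)` reads labels[i] for every i,
    so a short final batch raises IndexError; a batch of 1 fails because
    squeeze() turns the comparison result into a 0-dimensional tensor.
    """
    return batch_size > 1 and num_samples % batch_size == 0


for bs in (1, 4, 32, 100):
    num_full, remainder = divmod(10000, bs)
    print(bs, num_full, remainder, is_safe_test_batch_size(10000, bs))
# 1 10000 0 False
# 4 2500 0 True
# 32 312 16 False
# 100 100 0 True
```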
Related Pages
- Principle:Microsoft_DeepSpeedExamples_Classification_Evaluation -- The principle this implementation realizes
- Implementation:Microsoft_DeepSpeedExamples_Net_DeepSpeed -- The model being evaluated
- Implementation:Microsoft_DeepSpeedExamples_DeepSpeed_Initialize_CIFAR -- Produces the model_engine, local_device, and target_dtype consumed here
- Implementation:Microsoft_DeepSpeedExamples_Net_Tutorial -- Baseline evaluation code this was extracted from
- Environment:Microsoft_DeepSpeedExamples_CIFAR10_Training_Environment