Implementation:Microsoft DeepSpeedExamples Test Function CIFAR

From Leeroopedia
Revision as of 15:42, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/Microsoft_DeepSpeedExamples_Test_Function_CIFAR.md)


Metadata

Field | Value
Page Type | Implementation
Repository | Microsoft/DeepSpeedExamples
Title | Test_Function_CIFAR
Type | Function Doc
Source File | training/cifar/cifar10_deepspeed.py
Lines | 212-277
Implements | Principle:Microsoft_DeepSpeedExamples_Classification_Evaluation

Overview

Concrete tool for evaluating CIFAR-10 classification accuracy with DeepSpeed mixed-precision support.

Description

The test() function in cifar10_deepspeed.py evaluates a trained DeepSpeed model on the CIFAR-10 test set. It computes both overall accuracy and per-class accuracy for all 10 CIFAR-10 classes, handling mixed-precision dtype casting and distributed rank-aware reporting.

The function performs the following steps:

  1. Defines the CIFAR-10 class names for human-readable reporting
  2. Creates a test DataLoader from the provided test dataset (non-distributed, no shuffling)
  3. Sets the model engine to evaluation mode via model_engine.eval()
  4. Iterates over the test set under torch.no_grad()
  5. Casts input images to the target dtype (fp16/bf16) when target_dtype is not None, then moves them to local_device
  6. Runs the forward pass and takes the argmax of the outputs as the predicted class
  7. Accumulates correct/total counts for both overall and per-class accuracy
  8. Prints results only on the process whose local_rank is 0, to avoid duplicate output in distributed settings

This function is called at the end of training in main() at line 397:

test(model_engine, testset, local_device, target_dtype)

Code Reference

File: training/cifar/cifar10_deepspeed.py, Lines 212-277

def test(model_engine, testset, local_device, target_dtype, test_batch_size=4):
    """Test the network on the test data.

    Args:
        model_engine (deepspeed.runtime.engine.DeepSpeedEngine): the DeepSpeed engine.
        testset (torch.utils.data.Dataset): the test dataset.
        local_device (str): the local device name.
        target_dtype (torch.dtype): the target datatype for the test data.
        test_batch_size (int): the test batch size.

    """
    # The 10 classes for CIFAR10.
    classes = (
        "plane",
        "car",
        "bird",
        "cat",
        "deer",
        "dog",
        "frog",
        "horse",
        "ship",
        "truck",
    )

    # Define the test dataloader.
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=test_batch_size, shuffle=False, num_workers=0
    )

    # For total accuracy.
    correct, total = 0, 0
    # For accuracy per class.
    class_correct = list(0.0 for i in range(10))
    class_total = list(0.0 for i in range(10))

    # Start testing.
    model_engine.eval()
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            if target_dtype != None:
                images = images.to(target_dtype)
            outputs = model_engine(images.to(local_device))
            _, predicted = torch.max(outputs.data, 1)
            # Count the total accuracy.
            total += labels.size(0)
            correct += (predicted == labels.to(local_device)).sum().item()

            # Count the accuracy per class.
            batch_correct = (predicted == labels.to(local_device)).squeeze()
            for i in range(test_batch_size):
                label = labels[i]
                class_correct[label] += batch_correct[i].item()
                class_total[label] += 1

    if model_engine.local_rank == 0:
        print(
            f"Accuracy of the network on the {total} test images: {100 * correct / total : .0f} %"
        )

        # For all classes, print the accuracy.
        for i in range(10):
            print(
                f"Accuracy of {classes[i] : >5s} : {100 * class_correct[i] / class_total[i] : 2.0f} %"
            )

Signature

def test(
    model_engine: deepspeed.runtime.engine.DeepSpeedEngine,
    testset: torch.utils.data.Dataset,
    local_device: str,
    target_dtype: Optional[torch.dtype],
    test_batch_size: int = 4,
) -> None:
    """Evaluate the CIFAR-10 model and print accuracy metrics.

    Computes overall and per-class accuracy on the test set.
    Prints results only on rank 0 in distributed settings.
    """

I/O Contract

Inputs

Parameter | Type | Default | Description
model_engine | DeepSpeedEngine | (required) | The trained DeepSpeed engine wrapping the Net model. Must support .eval(), .local_rank, and the forward pass via __call__.
testset | torch.utils.data.Dataset | (required) | CIFAR-10 test dataset (10,000 images). Loaded via torchvision.datasets.CIFAR10(train=False).
local_device | str | (required) | Device name string (e.g., "cuda:0", "xpu:0") used to place test data on the correct device.
target_dtype | torch.dtype or None | (required) | Data type for mixed precision: torch.half (fp16), torch.bfloat16 (bf16), or None (fp32, no conversion).
test_batch_size | int | 4 | Number of images per test batch. Must evenly divide the test set size; otherwise the per-class counting loop fails on the short final batch.

Outputs

Output | Type | Description
Return value | None | The function has no return value
Printed (rank 0 only) | stdout | Overall accuracy: "Accuracy of the network on the N test images: XX %"
Printed (rank 0 only) | stdout | Per-class accuracy: "Accuracy of <class> : XX %" for each of 10 classes

Internal Data Flow

testset
    |
    v
DataLoader(batch_size=4, shuffle=False, num_workers=0)
    |
    v
for (images, labels) in testloader:
    |
    +-- images.to(target_dtype)    [if target_dtype is not None]
    |
    +-- images.to(local_device)    [move to GPU]
    |
    +-- model_engine(images)       [forward pass]
    |       |
    |       v
    |   outputs: (B, 10)           [raw logits]
    |       |
    |       v
    +-- torch.max(outputs, 1)      [argmax for predictions]
    |       |
    |       v
    +-- predicted == labels         [comparison for accuracy]
    |       |
    +-------+-- correct += sum     [overall accumulator]
    |       |
    +-------+-- class_correct[i]   [per-class accumulator]
    |
    v
if local_rank == 0:
    print(overall accuracy)
    print(per-class accuracy for all 10 classes)
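
The prediction and overall-accuracy steps in the flow above can be illustrated without tensors. The following is a minimal sketch, not the code from cifar10_deepspeed.py: plain Python lists stand in for the (B, 10) logits tensor, the helper names argmax_rows and overall_accuracy are hypothetical, and the toy batch uses 3 classes instead of 10 for brevity.

```python
from typing import List, Sequence


def argmax_rows(logits: Sequence[Sequence[float]]) -> List[int]:
    """Return the index of the largest logit in each row (one row per image)."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


def overall_accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Percentage of rows whose argmax equals the ground-truth label."""
    predicted = argmax_rows(logits)
    correct = sum(int(p == y) for p, y in zip(predicted, labels))
    return 100 * correct / len(labels)


# Toy batch of 4 images with made-up logits (hypothetical values).
toy_logits = [
    [0.1, 2.0, -1.0],
    [1.5, 0.2, 0.3],
    [0.0, 0.1, 0.9],
    [0.4, 0.3, 0.2],
]
toy_labels = [1, 0, 1, 0]

print(argmax_rows(toy_logits))                   # [1, 0, 2, 0]
print(overall_accuracy(toy_logits, toy_labels))  # 75.0
```

In the real function the same role is played by torch.max(outputs.data, 1), whose second return value holds the per-row argmax indices.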

Test DataLoader Configuration

Parameter | Value | Rationale
batch_size | 4 (default) | Small default that evenly divides the 10,000-image test set, which the per-class counting loop requires
shuffle | False | Deterministic evaluation order
num_workers | 0 | Single-process data loading (avoids multiprocessing issues)

Note that unlike the training DataLoader (created by deepspeed.initialize() with DistributedSampler), the test DataLoader is created manually without distributed sampling. Every rank evaluates the full test set.

Usage Example

# Called at the end of main() after training completes:
def main(args):
    # ... training code ...

    ########################################################################
    # Step 4. Test the network on the test data.
    ########################################################################
    test(model_engine, testset, local_device, target_dtype)

# With a custom test batch size. The value must evenly divide the
# 10,000 test images (10000 / 16 = 625 batches):
test(model_engine, testset, local_device, target_dtype, test_batch_size=16)
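
Because the per-class loop assumes that every batch is full, it can help to validate a custom test_batch_size before calling test(). The sketch below is one possible pre-flight check. The helper check_test_batch_size is hypothetical and is not part of cifar10_deepspeed.py.

```python
def check_test_batch_size(num_samples: int, test_batch_size: int) -> int:
    """Return the number of full batches, or raise if the last batch would be short."""
    if test_batch_size <= 0:
        raise ValueError("test_batch_size must be positive")
    num_batches, remainder = divmod(num_samples, test_batch_size)
    if remainder:
        raise ValueError(
            f"test_batch_size={test_batch_size} does not divide {num_samples} samples; "
            f"the last batch would hold {remainder} samples"
        )
    return num_batches


print(check_test_batch_size(10_000, 4))   # 2500
print(check_test_batch_size(10_000, 16))  # 625

# A size that does not divide the dataset is rejected up front.
try:
    check_test_batch_size(10_000, 32)
except ValueError as err:
    print(err)
```

In practice the first argument would be len(testset).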

Expected Output Format

Accuracy of the network on the 10000 test images:  53 %
Accuracy of plane :  63 %
Accuracy of   car :  70 %
Accuracy of  bird :  34 %
Accuracy of   cat :  36 %
Accuracy of  deer :  42 %
Accuracy of   dog :  52 %
Accuracy of  frog :  68 %
Accuracy of horse :  51 %
Accuracy of  ship :  72 %
Accuracy of truck :  47 %

Note: Actual values vary based on training configuration (epochs, dtype, ZeRO stage, MoE settings).
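
The column alignment in this output comes from the format specifications in the two print calls. The sketch below reuses those f-strings in standalone helpers to show the effect. The helper names and the counts passed in are illustrative assumptions, chosen only to reproduce two of the sample lines above.

```python
def format_overall(total: int, correct: int) -> str:
    # " .0f": no decimals, with a leading space reserved for the sign.
    return f"Accuracy of the network on the {total} test images: {100 * correct / total : .0f} %"


def format_class(name: str, class_correct: float, class_total: float) -> str:
    # " >5s": right-align the class name in a 5-character field.
    # " 2.0f": no decimals, minimum width 2, leading space for the sign.
    return f"Accuracy of {name : >5s} : {100 * class_correct / class_total : 2.0f} %"


print(format_overall(10000, 5300))     # Accuracy of the network on the 10000 test images:  53 %
print(format_class("car", 700, 1000))  # Accuracy of   car :  70 %
```

The two spaces before each percentage are one literal space from the string plus the sign placeholder produced by the format specification.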

Important Notes

  • test_batch_size assumption: The per-class counting loop iterates for i in range(test_batch_size), which assumes every batch contains exactly test_batch_size samples. If the test set size is not evenly divisible by test_batch_size, the last batch is smaller and indexing past its end raises an IndexError. The default test_batch_size=4 evenly divides the 10,000 CIFAR-10 test images (10000 / 4 = 2500 batches).
  • Rank-aware reporting: The if model_engine.local_rank == 0 check restricts printing to the process with local rank 0, so output appears once on a single node (and once per node in multi-node runs). All ranks still execute the full evaluation loop.
  • eval() mode: model_engine.eval() puts the model in evaluation mode, which disables dropout and makes batch normalization use its running statistics instead of updating them. The baseline Net uses neither layer, but calling eval() is still best practice, and it matters when MoE layers are enabled because their gating can behave differently between training and evaluation.
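
One way to remove the full-batch assumption is to iterate over the samples actually present in each batch rather than over range(test_batch_size). The following is a minimal sketch of that idea, not the repository's code: update_class_counts is a hypothetical helper, and plain Python lists stand in for tensors (with PyTorch, predicted.tolist() and labels.tolist() would supply them).

```python
from typing import List, Sequence


def update_class_counts(
    predicted: Sequence[int],
    labels: Sequence[int],
    class_correct: List[int],
    class_total: List[int],
) -> None:
    """Accumulate per-class hits and totals for one batch of any size."""
    for pred, label in zip(predicted, labels):
        class_correct[label] += int(pred == label)
        class_total[label] += 1


class_correct = [0] * 10
class_total = [0] * 10

# A full batch of 4 followed by a short final batch of 2 (made-up values).
update_class_counts([3, 5, 5, 0], [3, 5, 4, 0], class_correct, class_total)
update_class_counts([9, 1], [9, 9], class_correct, class_total)

print(class_total[9], class_correct[9])       # 2 1
print(sum(class_total), sum(class_correct))   # 6 4
```

Because zip stops at the end of the batch, the short final batch is counted without any index error.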
