

Implementation:Kserve Kserve MNIST CNN Net

From Leeroopedia
Knowledge Sources
Domains: Image Classification, TorchServe
Last Updated: 2026-02-13 00:00 GMT

Overview

Concrete tool, provided by the KServe sample code, for defining a PyTorch convolutional neural network (CNN) that classifies MNIST handwritten digits.

Description

Net extends torch.nn.Module and implements a CNN architecture for classifying 28x28 grayscale MNIST digit images into 10 classes. The architecture consists of:

  • Two convolutional layers: conv1 (1 input channel, 32 output channels, 3x3 kernel) and conv2 (32 input channels, 64 output channels, 3x3 kernel)
  • Two dropout layers: dropout1 (2D dropout with p=0.25) and dropout2 (2D dropout with p=0.5)
  • Two fully connected layers: fc1 (9216 to 128) and fc2 (128 to 10)
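The 9216 input width of fc1 follows from the layer sizes above. A minimal sketch of the arithmetic, assuming stride 1 and no padding for both convolutions (the page states only the 3x3 kernel size) and a 2x2 max pool with stride 2; the helper name conv2d_out is hypothetical:

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a square convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1


side = 28                                     # MNIST images are 28x28
side = conv2d_out(side, kernel=3)             # conv1: 28 -> 26
side = conv2d_out(side, kernel=3)             # conv2: 26 -> 24
side = conv2d_out(side, kernel=2, stride=2)   # max_pool2d(2): 24 -> 12
flattened = 64 * side * side                  # conv2 produces 64 channels

print(side, flattened)  # 12 9216
```

If the kernel, padding, or pooling size changed, fc1's input width would have to change with it, which is why the constructor takes no arguments.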

The forward() method chains: conv1 -> ReLU -> conv2 -> max_pool2d(2) -> dropout1 -> flatten -> fc1 -> ReLU -> dropout2 -> fc2 -> log_softmax.
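To make the layer list and forward chain concrete, here is a sketch of a module matching that description. It is reconstructed from this page, not copied from the KServe sample's mnist.py, so details the page does not state (stride 1, no padding, the keyword-argument style) are assumptions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Sketch of the MNIST CNN described above (reconstruction, not the sample file)."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1)   # 1x28x28 -> 32x26x26
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1)  # -> 64x24x24
        self.dropout1 = nn.Dropout2d(0.25)
        # dropout2 sees a 2-D (batch, 128) tensor; recent PyTorch may warn about this
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)                          # 64 * 12 * 12 = 9216
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.conv2(x)            # no ReLU here, following the chain described above
        x = F.max_pool2d(x, 2)       # -> 64x12x12
        x = self.dropout1(x)
        x = torch.flatten(x, 1)      # -> (batch_size, 9216)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


model = Net().eval()                 # eval() disables both dropout layers
with torch.no_grad():
    out = model(torch.randn(2, 1, 28, 28))
print(out.shape)  # torch.Size([2, 10])
```

Calling eval() matters at inference time: in training mode both dropout layers randomly zero activations, so repeated predictions on the same image would differ.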

Usage

Use this class as the model definition file (--model-file) when packaging a trained MNIST CNN into a .mar model archive with torch-model-archiver for deployment on TorchServe via KServe.

Code Reference

Source Location

Signature

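import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        ...  # defines conv1, conv2, dropout1, dropout2, fc1, fc2

    def forward(self, x):
        ...  # returns log-probabilities of shape (batch_size, 10)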

Import

from mnist import Net

I/O Contract

Inputs

Constructor

Name | Type | Required | Description
(none) | -- | -- | No arguments; the architecture is fixed for MNIST (1-channel 28x28 input)

forward()

Name | Type | Required | Description
x | torch.Tensor | Yes | Input tensor of shape (batch_size, 1, 28, 28) representing grayscale MNIST images
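Real inputs usually start as a 28x28 grid of 8-bit pixel values. A minimal sketch of turning such images into the (batch_size, 1, 28, 28) float tensor that forward() expects; the helper name to_model_input is hypothetical, and scaling to [0, 1] followed by normalization with 0.1307 and 0.3081 (the MNIST mean and standard deviation used in PyTorch's MNIST example) is an assumption about how the weights were trained:

```python
import torch

MNIST_MEAN, MNIST_STD = 0.1307, 0.3081  # assumed training-time normalization


def to_model_input(image: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: uint8 28x28 image -> normalized (1, 1, 28, 28) float tensor."""
    if image.shape != (28, 28):
        raise ValueError(f"expected a 28x28 image, got {tuple(image.shape)}")
    x = image.to(torch.float32) / 255.0     # scale pixels to [0, 1]
    x = (x - MNIST_MEAN) / MNIST_STD        # standardize with the assumed statistics
    return x.unsqueeze(0).unsqueeze(0)      # add batch and channel dimensions


images = [
    torch.zeros(28, 28, dtype=torch.uint8),            # all-black image
    torch.full((28, 28), 255, dtype=torch.uint8),      # all-white image
]
batch = torch.cat([to_model_input(img) for img in images], dim=0)
print(batch.shape)  # torch.Size([2, 1, 28, 28])
```

If the deployed weights were trained without normalization, drop the standardization line; mismatched preprocessing degrades accuracy without raising any error.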

Outputs

forward()

Name | Type | Description
output | torch.Tensor | Log-probabilities (log_softmax over dim=1) of shape (batch_size, 10), one per digit class
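Because forward() returns log-probabilities rather than probabilities, callers that need a confidence score must exponentiate. A minimal sketch, using random logits passed through log_softmax as a stand-in for real model output:

```python
import torch
import torch.nn.functional as F

# Stand-in for model(batch): random logits run through log_softmax, as forward() does.
log_probs = F.log_softmax(torch.randn(3, 10), dim=1)

probs = log_probs.exp()                      # back to probabilities; each row sums to 1
confidence, predicted = probs.max(dim=1)     # top class and its probability per image

for digit, p in zip(predicted.tolist(), confidence.tolist()):
    print(f"digit={digit} p={p:.3f}")
```

Since exp is monotonic, argmax on the raw log-probabilities gives the same predicted digit; exponentiating is needed only for the probability values. For training, log-probability outputs pair with F.nll_loss, whereas nn.CrossEntropyLoss expects raw logits.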

Usage Examples

Basic Usage

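import torch
from mnist import Net

# Create the MNIST CNN model
model = Net()

# Load trained weights (a state_dict); map_location lets GPU-saved weights load on a CPU-only host
model.load_state_dict(torch.load("mnist_cnn.pt", map_location="cpu"))
model.eval()  # disable dropout for inference

# Run inference on a random tensor standing in for a preprocessed image
sample = torch.randn(1, 1, 28, 28)
with torch.no_grad():
    output = model(sample)  # log-probabilities, shape (1, 10)
predicted_digit = output.argmax(dim=1).item()
print(f"Predicted digit: {predicted_digit}")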

TorchServe Model Archiver

# Package with torch-model-archiver (writes mnist.mar to the current directory by default):
torch-model-archiver --model-name mnist \
    --version 1.0 \
    --model-file mnist.py \
    --serialized-file mnist_cnn.pt \
    --handler mnist_handler.py
