Implementation:Kserve Kserve MNIST CNN Net
| Knowledge Sources | |
|---|---|
| Domains | Image Classification, TorchServe |
| Last Updated | 2026-02-13 00:00 GMT |
Overview
A concrete PyTorch model definition from the KServe sample code: a convolutional neural network (CNN) for MNIST handwritten digit classification.
Description
Net extends torch.nn.Module and implements a CNN architecture for classifying 28x28 grayscale MNIST digit images into 10 classes. The architecture consists of:
- Two convolutional layers: conv1 (1 input channel, 32 output channels, 3x3 kernel) and conv2 (32 input channels, 64 output channels, 3x3 kernel)
- Two dropout layers: dropout1 (2D dropout, p=0.25) and dropout2 (2D dropout, p=0.5)
- Two fully connected layers: fc1 (9216 to 128 features) and fc2 (128 to 10 features)
The forward() method chains: conv1 -> ReLU -> conv2 -> max_pool2d(2) -> dropout1 -> flatten -> fc1 -> ReLU -> dropout2 -> fc2 -> log_softmax.
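The 9216 input features of fc1 follow from the spatial sizes along the forward() chain. A minimal pure-Python sketch traces them, assuming stride 1 and no padding for both 3x3 convolutions (the description above gives only kernel sizes); conv_out and fc1_in_features are hypothetical helper names.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a Conv2d or max-pool layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def fc1_in_features(height: int = 28, width: int = 28) -> int:
    """Trace a 1x28x28 MNIST image through the conv/pool stages of Net."""
    h, w = height, width
    # conv1: 3x3 kernel, assumed stride 1 and no padding -> 32 x 26 x 26
    h, w = conv_out(h, 3), conv_out(w, 3)
    # conv2: 3x3 kernel, assumed stride 1 and no padding -> 64 x 24 x 24
    h, w = conv_out(h, 3), conv_out(w, 3)
    channels = 64
    # max_pool2d(2): kernel 2, stride defaults to the kernel size -> 64 x 12 x 12
    h, w = conv_out(h, 2, stride=2), conv_out(w, 2, stride=2)
    # flatten keeps the batch dimension and multiplies out the rest
    return channels * h * w


print(fc1_in_features())  # 9216 = 64 * 12 * 12
```

Any other input resolution would change this product, which is why the architecture is fixed to 28x28 inputs.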
Usage
Use mnist.py, the file that defines this class, as the model definition file (--model-file) when packaging a trained MNIST CNN into a MAR archive with torch-model-archiver for deployment on TorchServe via KServe.
Code Reference
Source Location
- Repository: Kserve_Kserve
- File: docs/samples/v1beta1/torchserve/model-archiver/model-store/mnist/mnist.py
- Lines: 1-28
Signature
```python
class Net(nn.Module):
    def __init__(self):
        ...

    def forward(self, x):
        ...
```
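The signature above elides both method bodies. The sketch below fills them in from the layer list and forward() chain in the Description; it is a reconstruction for illustration, not a verbatim copy of mnist.py. The stride of 1 passed to both convolutions is an assumption (it is the value that makes fc1's 9216 input features work out), and recent PyTorch releases may emit a deprecation warning because dropout2, a Dropout2d layer, receives a 2-D (batch_size, 128) tensor.

```python
import torch
import torch.nn.functional as F
from torch import nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Stride 1 (assumed), no padding: 28x28 -> 26x26 -> 24x24
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.conv2(x)        # no ReLU here, matching the described chain
        x = F.max_pool2d(x, 2)   # 24x24 -> 12x12
        x = self.dropout1(x)
        x = torch.flatten(x, 1)  # (batch, 64 * 12 * 12) = (batch, 9216)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)  # log-probabilities over 10 classes
```

In eval mode both dropout layers are inactive, so Net().eval() applied to a (2, 1, 28, 28) tensor should return a (2, 10) tensor whose rows exponentiate to probabilities summing to 1.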
Import
```python
from mnist import Net
```
I/O Contract
Inputs
Constructor
| Name | Type | Required | Description |
|---|---|---|---|
| (none) | -- | -- | No arguments; architecture is fixed for MNIST (1-channel 28x28 input) |
forward()
| Name | Type | Required | Description |
|---|---|---|---|
| x | torch.Tensor | Yes | Input tensor of shape (batch_size, 1, 28, 28) representing grayscale MNIST images |
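forward() expects float tensors already shaped (batch_size, 1, 28, 28). A minimal sketch of turning one 28x28 grayscale image into that shape follows. The helper name to_model_input is hypothetical, and the normalization constants are an assumption taken from the widely used PyTorch MNIST training example; they must match whatever preprocessing the loaded weights were trained with.

```python
import torch

# Assumed MNIST mean/std (from the common PyTorch MNIST example); replace them
# if the model was trained with different preprocessing.
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def to_model_input(pixels):
    """Convert a 28x28 grid of 0-255 grayscale values to a (1, 1, 28, 28) tensor."""
    x = torch.tensor(pixels, dtype=torch.float32)
    if x.shape != (28, 28):
        raise ValueError(f"expected a 28x28 image, got {tuple(x.shape)}")
    x = x / 255.0                       # scale to [0, 1]
    x = (x - MNIST_MEAN) / MNIST_STD    # normalize
    return x.unsqueeze(0).unsqueeze(0)  # add channel and batch dimensions


blank = [[0] * 28 for _ in range(28)]
# Stack single images along dim 0 to build a batch
batch = torch.cat([to_model_input(blank), to_model_input(blank)])
print(batch.shape)  # torch.Size([2, 1, 28, 28])
```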
Outputs
forward()
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Log-probabilities (log_softmax over the class dimension) of shape (batch_size, 10), one per digit class; apply exp() to recover probabilities |
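Because forward() ends in log_softmax, each output row holds log-probabilities: exponentiating recovers probabilities that sum to 1, and the argmax is the same as that of the raw fc2 scores, since log_softmax only subtracts a per-row constant. The pure-Python sketch below illustrates this on made-up scores; in PyTorch the equivalents are output.exp() and output.argmax(dim=1).

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of raw scores."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]


# Hypothetical fc2 scores for one image (indices 0-9 are the digits)
logits = [0.1, -1.2, 0.3, 4.0, -0.5, 0.0, 1.1, -2.0, 0.7, 0.2]
log_probs = log_softmax(logits)              # what one output row contains
probs = [math.exp(lp) for lp in log_probs]   # exp() recovers probabilities
predicted = max(range(10), key=lambda i: log_probs[i])

print(round(sum(probs), 6))  # 1.0
print(predicted)             # 3
```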
Usage Examples
Basic Usage
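```python
import torch
from mnist import Net

# Create the MNIST CNN model
model = Net()
# Load trained weights; map_location="cpu" also lets GPU-saved weights load on CPU
model.load_state_dict(torch.load("mnist_cnn.pt", map_location="cpu"))
model.eval()  # inference mode: disables dropout

# Random placeholder with the expected (batch_size, 1, 28, 28) input shape
sample = torch.randn(1, 1, 28, 28)
with torch.no_grad():  # gradients are not needed for inference
    output = model(sample)
predicted_digit = output.argmax(dim=1).item()
print(f"Predicted digit: {predicted_digit}")
```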
TorchServe Model Archiver
Package the model with torch-model-archiver:

```shell
torch-model-archiver --model-name mnist \
  --version 1.0 \
  --model-file mnist.py \
  --serialized-file mnist_cnn.pt \
  --handler mnist_handler.py
```
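When packaging is scripted, the same invocation can be assembled as an argument list and run from Python. This is a minimal sketch: archiver_command is a hypothetical helper, and the "model-store" export directory is an illustrative choice; --export-path is a standard torch-model-archiver option that sets where mnist.mar is written (the current directory if omitted). The archiver only runs when the CLI and all three input files are present.

```python
import os
import shlex
import shutil
import subprocess


def archiver_command(export_path=None):
    """Assemble the torch-model-archiver invocation shown above as an argv list."""
    cmd = [
        "torch-model-archiver",
        "--model-name", "mnist",
        "--version", "1.0",
        "--model-file", "mnist.py",
        "--serialized-file", "mnist_cnn.pt",
        "--handler", "mnist_handler.py",
    ]
    if export_path is not None:
        # Directory that receives mnist.mar (hypothetical choice of name)
        cmd += ["--export-path", export_path]
    return cmd


cmd = archiver_command(export_path="model-store")
print(shlex.join(cmd))

# Run only when the CLI is installed and all input files exist.
inputs = ["mnist.py", "mnist_cnn.pt", "mnist_handler.py"]
if shutil.which(cmd[0]) and all(os.path.isfile(f) for f in inputs):
    os.makedirs("model-store", exist_ok=True)
    subprocess.run(cmd, check=True)
```

Passing the command as a list rather than a shell string avoids quoting problems if file paths contain spaces.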