Overview
Custom_Metrics_Handler is a TorchServe handler for MNIST digit classification that demonstrates custom metrics integration. The MNISTDigitClassifier class extends ImageClassifier and creates three custom metrics (InferenceRequestCount, PreprocessCallCount, InitializeCallCount) during initialization, updates them throughout the handler lifecycle, and records them via the TorchServe metrics system using ts.metrics.dimension.Dimension and ts.metrics.metric_type_enum.MetricTypes.
Description
The MNISTDigitClassifier class serves as a reference implementation for adding custom metrics to a TorchServe handler. While it performs standard MNIST image classification (extending ImageClassifier), its primary purpose is demonstrating how to define, register, and update custom metrics that are exposed through the TorchServe metrics API on port 8082.
Key Responsibilities
- Custom Metric Registration: Creates three custom metrics during initialize() using the TorchServe metrics API:
  - InferenceRequestCount: Tracks total inference requests received
  - PreprocessCallCount: Tracks total preprocess invocations
  - InitializeCallCount: Tracks initialization events (should be 1 per worker)
- Metric Updates: Increments counters in the preprocess() and postprocess() methods
- Image Preprocessing: Normalizes input images for MNIST classification
- Prediction: Applies argmax to model output for digit classification
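The preprocessing and prediction steps listed above can be illustrated without torch. This is a minimal sketch, not the handler's code: the helper names `normalize` and `argmax` are hypothetical, and the `MEAN` and `STD` constants are assumed values (commonly quoted MNIST statistics) that do not come from the handler source.

```python
# Stand-alone sketch: plain Python lists stand in for tensors.
# MEAN and STD are assumptions (commonly quoted MNIST statistics).
MEAN, STD = 0.1307, 0.3081


def normalize(pixels):
    """Scale 0-255 pixel values to [0, 1], then standardize with MEAN/STD."""
    return [((p / 255.0) - MEAN) / STD for p in pixels]


def argmax(logits):
    """Return the index of the largest logit, i.e. the predicted digit."""
    return max(range(len(logits)), key=lambda i: logits[i])


# Ten logits, one per digit class 0-9; index 3 holds the largest value.
logits = [0.1, -1.2, 0.3, 4.7, 0.0, -0.5, 1.1, 0.2, -2.0, 0.9]
predicted_digit = argmax(logits)
print(predicted_digit)
```

In the handler itself the same idea is applied to a batch tensor, so the argmax runs along the class dimension and yields one digit per input image.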
Custom Metrics Architecture
| Metric Name | Type | Updated In | Description |
|---|---|---|---|
| InferenceRequestCount | Counter | postprocess() | Total number of inference requests processed |
| PreprocessCallCount | Counter | preprocess() | Total number of preprocess method calls |
| InitializeCallCount | Counter | initialize() | Number of handler initializations (1 per worker) |
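The mapping between lifecycle hooks and counters in the table can be sketched with a stand-in metrics object. `StubMetrics` and `LifecycleSketch` are hypothetical names introduced for illustration; they only mimic the shape of an `add_counter(name, value)` call and do not depend on TorchServe.

```python
from collections import Counter


class StubMetrics:
    """Hypothetical stand-in for the metrics object on the TorchServe context."""

    def __init__(self):
        self.counts = Counter()

    def add_counter(self, name, value):
        # Accumulate the increment under the metric name.
        self.counts[name] += value


class LifecycleSketch:
    """Mirrors which hook updates which counter; performs no real inference."""

    def __init__(self, metrics):
        self.metrics = metrics

    def initialize(self):
        self.metrics.add_counter("InitializeCallCount", 1)

    def preprocess(self, batch):
        self.metrics.add_counter("PreprocessCallCount", 1)
        return batch

    def postprocess(self, outputs):
        self.metrics.add_counter("InferenceRequestCount", 1)
        return outputs

    def handle(self, batch):
        return self.postprocess(self.preprocess(batch))


metrics = StubMetrics()
handler = LifecycleSketch(metrics)
handler.initialize()
for _ in range(3):
    handler.handle(["image-bytes"])
print(dict(metrics.counts))
```

In this sketch both request-path counters advance once per call to the hook, so with batching enabled they would count batches rather than individual requests.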
Dependencies
| Dependency | Purpose |
|---|---|
| ts.torch_handler.image_classifier | Parent class ImageClassifier (which extends BaseHandler) |
| ts.metrics.dimension.Dimension | Defines metric dimensions (labels) for metric registration |
| ts.metrics.metric_type_enum.MetricTypes | Enum for metric types (COUNTER, GAUGE, HISTOGRAM) |
| torch | Tensor operations for preprocessing and postprocessing |
Code Reference
Source Location
| File | Lines | Repository |
|---|---|---|
| examples/custom_metrics/mnist_handler.py | L10-131 | pytorch/serve |
Signature
from ts.torch_handler.image_classifier import ImageClassifier
from ts.metrics.dimension import Dimension
from ts.metrics.metric_type_enum import MetricTypes


class MNISTDigitClassifier(ImageClassifier):
    """
    MNIST digit classifier with custom TorchServe metrics.

    Extends ImageClassifier to add InferenceRequestCount,
    PreprocessCallCount, and InitializeCallCount custom metrics.
    """

    def __init__(self):
        super().__init__()
        self.metrics = None

    def initialize(self, context):
        """
        Load model and register custom metrics.

        Calls parent initialize() for model loading, then creates
        three custom counter metrics with model_name and worker_name
        dimensions.

        Args:
            context: TorchServe context with metrics system access.
        """
        super().initialize(context)
        self.metrics = context.metrics

        # Define dimensions for custom metrics
        model_name_dim = Dimension("ModelName", context.model_name)
        worker_name_dim = Dimension(
            "WorkerName", context.system_properties.get("gpu_id", "N/A")
        )

        # Register custom metrics. add_metric() is used here because it
        # accepts metric_type; add_counter() is a counter-only shorthand
        # that takes no metric_type argument.
        self.metrics.add_metric(
            "InferenceRequestCount",
            value=0,
            unit="count",
            dimensions=[model_name_dim, worker_name_dim],
            metric_type=MetricTypes.COUNTER,
        )
        self.metrics.add_metric(
            "PreprocessCallCount",
            value=0,
            unit="count",
            dimensions=[model_name_dim, worker_name_dim],
            metric_type=MetricTypes.COUNTER,
        )
        self.metrics.add_metric(
            "InitializeCallCount",
            value=1,
            unit="count",
            dimensions=[model_name_dim, worker_name_dim],
            metric_type=MetricTypes.COUNTER,
        )
        ...

    def preprocess(self, data):
        """
        Normalize input images and update PreprocessCallCount metric.

        Increments the PreprocessCallCount counter, then normalizes
        the input image tensors for MNIST classification.

        Args:
            data (list): List of request input dicts with image data.

        Returns:
            torch.Tensor: Normalized image tensor on self.device.
        """
        self.metrics.add_counter("PreprocessCallCount", 1)
        # Normalize images for MNIST
        ...

    def postprocess(self, data):
        """
        Apply argmax for digit prediction and update InferenceRequestCount.

        Increments the InferenceRequestCount counter, then applies
        argmax to the model output to produce digit predictions.

        Args:
            data (torch.Tensor): Model output logits.

        Returns:
            list: List of predicted digit integers (0-9).
        """
        self.metrics.add_counter("InferenceRequestCount", 1)
        return data.argmax(dim=1).tolist()
I/O Contract
| Method | Input | Output | Notes |
|---|---|---|---|
| initialize(context) | TorchServe context | None (registers 3 custom metrics) | Creates InferenceRequestCount, PreprocessCallCount, InitializeCallCount |
| preprocess(data) | List of request dicts with image bytes | Normalized torch.Tensor | Increments PreprocessCallCount |
| postprocess(data) | Model output logits (torch.Tensor) | List of digit predictions (int 0-9) | Increments InferenceRequestCount, applies argmax |
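The "list of request dicts" input to preprocess() can be sketched as follows. The `"data"` and `"body"` keys follow the convention used by TorchServe's built-in vision handler; whether this handler reads exactly those keys is an assumption, and `extract_payloads` is a hypothetical helper name.

```python
def extract_payloads(requests):
    """Return the raw payload of each request dict, preferring "data" over "body"."""
    payloads = []
    for row in requests:
        # Assumed convention: the payload sits under "data" or "body".
        payload = row.get("data") or row.get("body")
        if payload is None:
            raise ValueError("request has neither 'data' nor 'body'")
        payloads.append(payload)
    return payloads


# Two requests in one batch, each carrying placeholder image bytes.
batch = [{"data": b"\x89PNG..."}, {"body": b"\x89PNG..."}]
payloads = extract_payloads(batch)
print(len(payloads))
```

Each extracted payload would then be decoded into an image and normalized before the batch is stacked into a single tensor.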
Metrics Output Format
Custom metrics are exposed via the metrics API on port 8082 in Prometheus format:
# Example Prometheus output on http://localhost:8082/metrics
InferenceRequestCount{ModelName="mnist",WorkerName="0"} 42
PreprocessCallCount{ModelName="mnist",WorkerName="0"} 42
InitializeCallCount{ModelName="mnist",WorkerName="0"} 1
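To check these values programmatically, the Prometheus text format can be parsed with the standard library. The sketch below is a simplified parser written for this example, not a TorchServe or Prometheus client API. It handles the `name{label="value",...} number` shape shown above and skips comment lines, but it does not handle escaped quotes inside label values.

```python
import re

SAMPLE = """\
InferenceRequestCount{ModelName="mnist",WorkerName="0"} 42
PreprocessCallCount{ModelName="mnist",WorkerName="0"} 42
InitializeCallCount{ModelName="mnist",WorkerName="0"} 1
"""

# name, optional {labels}, value, optional integer timestamp
LINE_RE = re.compile(
    r"^(?P<name>[A-Za-z_:][A-Za-z0-9_:]*)"
    r"(?:\{(?P<labels>[^}]*)\})?"
    r"\s+(?P<value>\S+)(?:\s+-?\d+)?$"
)
LABEL_RE = re.compile(r'(\w+)="([^"]*)"')


def parse_metrics(text):
    """Return a list of (name, labels_dict, value) tuples from exposition text."""
    samples = []
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # blank lines and HELP/TYPE comments carry no sample
        match = LINE_RE.match(line)
        if match is None:
            continue  # ignore lines this simplified parser cannot read
        labels = dict(LABEL_RE.findall(match.group("labels") or ""))
        samples.append((match.group("name"), labels, float(match.group("value"))))
    return samples


samples = parse_metrics(SAMPLE)
for name, labels, value in samples:
    print(name, labels, value)
```

The same function could be applied to the body returned by the metrics endpoint in place of `SAMPLE`.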
Usage Examples
Example 1: Package and serve the custom metrics handler
# Package the handler with the MNIST model
torch-model-archiver --model-name mnist \
    --version 1.0 \
    --model-file examples/custom_metrics/mnist_model.py \
    --serialized-file mnist_cnn.pt \
    --handler examples/custom_metrics/mnist_handler.py \
    --export-path model_store

# Start TorchServe
torchserve --start --model-store model_store \
    --models mnist=mnist.mar
Example 2: Send inference and check metrics
import requests

# Send an MNIST digit image for classification
with open("digit_3.png", "rb") as img:
    response = requests.post(
        "http://localhost:8080/predictions/mnist",
        data=img.read(),
        headers={"Content-Type": "application/octet-stream"},
    )
print(response.json())  # [3]

# Check custom metrics
metrics_response = requests.get("http://localhost:8082/metrics")
print(metrics_response.text)
# Contains: InferenceRequestCount{ModelName="mnist",...} 1
# Contains: PreprocessCallCount{ModelName="mnist",...} 1
Example 3: Defining custom Dimension and MetricTypes
from ts.metrics.dimension import Dimension
from ts.metrics.metric_type_enum import MetricTypes

# Create dimensions to label metrics
model_dim = Dimension("ModelName", "my_model")
worker_dim = Dimension("WorkerName", "worker_0")
gpu_dim = Dimension("GPUId", "0")

# Available MetricTypes:
# MetricTypes.COUNTER - monotonically increasing counter
# MetricTypes.GAUGE - value that can go up and down
# MetricTypes.HISTOGRAM - distribution of values

# Register a custom gauge metric. Run this inside a handler's
# initialize(self, context), where context is available.
# add_metric() accepts metric_type; add_counter() always records a counter.
context.metrics.add_metric(
    "QueueDepth",
    value=0,
    unit="count",
    dimensions=[model_dim, worker_dim],
    metric_type=MetricTypes.GAUGE,
)
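The difference between a counter and a gauge, as described in the MetricTypes comments above, can be sketched with two small stand-in classes. `SketchCounter` and `SketchGauge` are hypothetical names for this illustration and are not part of the TorchServe API.

```python
class SketchCounter:
    """Monotonic value: it can only be incremented."""

    def __init__(self):
        self.value = 0.0

    def inc(self, amount=1.0):
        if amount < 0:
            raise ValueError("counters only go up")
        self.value += amount


class SketchGauge:
    """Point-in-time value: each update overwrites the previous one."""

    def __init__(self):
        self.value = 0.0

    def set(self, value):
        self.value = value


requests_seen = SketchCounter()
queue_depth = SketchGauge()

# Three requests arrive while the observed queue depth rises and falls.
for depth in (3, 5, 2):
    requests_seen.inc()
    queue_depth.set(depth)

print(requests_seen.value, queue_depth.value)
```

A quantity such as queue depth suits a gauge because the latest reading matters, while a request total suits a counter because only the accumulated sum matters.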