

Implementation:Pytorch Serve Custom Metrics Handler

From Leeroopedia

Overview

Custom_Metrics_Handler is a TorchServe handler for MNIST digit classification that demonstrates custom metrics integration. The MNISTDigitClassifier class extends ImageClassifier and creates three custom metrics (InferenceRequestCount, PreprocessCallCount, InitializeCallCount) during initialization, updates them throughout the handler lifecycle, and records them via the TorchServe metrics system using ts.metrics.dimension.Dimension and ts.metrics.metric_type_enum.MetricTypes.

Field | Value
Implementation Name | Custom_Metrics_Handler
Type | Custom Handler
Workflow | Metrics_Monitoring
Domains | Model_Serving, Observability
Knowledge Sources | Pytorch_Serve
Last Updated | 2026-02-13 18:52 GMT

Description

The MNISTDigitClassifier class serves as a reference implementation for adding custom metrics to a TorchServe handler. Although it performs standard MNIST image classification (extending ImageClassifier), its primary purpose is to demonstrate how to define, register, and update custom metrics that TorchServe exports through its metrics API on port 8082.

Key Responsibilities

  • Custom Metric Registration: Creates three custom metrics during initialize() using the TorchServe metrics API:
    • InferenceRequestCount: Counts postprocess() calls, i.e. one per batch (equal to the number of requests when the batch size is 1)
    • PreprocessCallCount: Counts preprocess() calls, also one per batch
    • InitializeCallCount: Tracks initialization events (should be 1 per worker)
  • Metric Updates: Increments counters in preprocess() and postprocess() methods
  • Image Preprocessing: Normalizes input images for MNIST classification
  • Prediction: Applies argmax to model output for digit classification
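The two numeric steps in the list above (normalization and argmax prediction) can be sketched in plain Python, independent of TorchServe. This sketch assumes 0-255 grayscale pixels and the conventional MNIST mean and standard deviation (0.1307, 0.3081); the source does not state which constants the handler uses, and `normalize_pixels` and `argmax_per_row` are hypothetical helper names.

```python
# Assumed MNIST normalization constants (commonly used values, not taken from the handler source)
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def normalize_pixels(pixels):
    """Scale 0-255 grayscale pixels to [0, 1], then standardize with the MNIST mean/std."""
    return [((p / 255.0) - MNIST_MEAN) / MNIST_STD for p in pixels]


def argmax_per_row(logits_batch):
    """Return the index of the largest logit in each row, i.e. one predicted digit per request."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits_batch]


if __name__ == "__main__":
    print(normalize_pixels([0, 128, 255]))
    print(argmax_per_row([[0.1, 2.5, -1.0], [3.0, 0.0, 0.2]]))  # [1, 0]
```

Like `torch.argmax`, `max` with a key returns the first index on ties.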

Custom Metrics Architecture

Metric Name | Type | Updated In | Description
InferenceRequestCount | Counter | postprocess() | Number of postprocess() calls (one per batch; equals the number of requests processed when the batch size is 1)
PreprocessCallCount | Counter | preprocess() | Number of preprocess() calls (one per batch)
InitializeCallCount | Counter | initialize() | Number of handler initializations (1 per worker)
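TorchServe calls preprocess() and postprocess() once per batch, so each counter in the table advances once per batch rather than once per request. The toy registry below is a hypothetical stand-in, not the TorchServe metrics cache: it keys counters by name plus dimension values and replays three batches to show how batch and request counts diverge when batching is enabled.

```python
from collections import defaultdict


class ToyCounterRegistry:
    """Hypothetical stand-in for a metrics cache: counters keyed by (name, dimension values)."""

    def __init__(self):
        self._values = defaultdict(float)

    def add_counter(self, name, value, dimensions):
        if value < 0:
            raise ValueError("counters can only increase")
        # Dimensions are (name, value) pairs; sorting makes label order irrelevant
        key = (name, tuple(sorted(dimensions)))
        self._values[key] += value

    def get(self, name, dimensions):
        return self._values[(name, tuple(sorted(dimensions)))]


def simulate(batch_sizes, dims):
    """Replay the handler's increments: one preprocess and one postprocess call per batch."""
    registry = ToyCounterRegistry()
    registry.add_counter("InitializeCallCount", 1, dims)
    for _ in batch_sizes:
        registry.add_counter("PreprocessCallCount", 1, dims)
        registry.add_counter("InferenceRequestCount", 1, dims)
    return registry


dims = [("ModelName", "mnist"), ("WorkerName", "0")]
registry = simulate([4, 4, 2], dims)  # 10 requests delivered in 3 batches
print(registry.get("InferenceRequestCount", dims))  # 3.0
```

To count requests instead of batches, a handler could increment by the batch size (for example `len(data)` in preprocess()); that is a possible variation, not what the example handler does.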

Dependencies

Dependency | Purpose
ts.torch_handler.image_classifier | Parent class ImageClassifier (which extends VisionHandler, a BaseHandler subclass)
ts.metrics.dimension.Dimension | Defines metric dimensions (labels) for metric registration
ts.metrics.metric_type_enum.MetricTypes | Enum for metric types (COUNTER, GAUGE, HISTOGRAM)
torch | Tensor operations for preprocessing and postprocessing
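Each Dimension is a name/value pair that becomes one label on the exported metric. The sketch below uses a hypothetical `Dim` dataclass in place of ts.metrics.dimension.Dimension (so it runs without TorchServe installed) and renders a label set in Prometheus text exposition syntax, including the escaping that format requires for backslashes, double quotes, and newlines in label values.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Dim:
    """Hypothetical stand-in for ts.metrics.dimension.Dimension: one name/value label."""
    name: str
    value: str


def escape_label_value(value):
    # Prometheus text format: escape backslash first, then double quote and newline
    return value.replace("\\", "\\\\").replace('"', '\\"').replace("\n", "\\n")


def render_sample(metric_name, dims, value):
    """Render one Prometheus sample line such as: Name{A="x",B="y"} 1"""
    labels = ",".join(f'{d.name}="{escape_label_value(str(d.value))}"' for d in dims)
    return f"{metric_name}{{{labels}}} {value}"


print(render_sample("InitializeCallCount", [Dim("ModelName", "mnist"), Dim("WorkerName", "0")], 1))
# InitializeCallCount{ModelName="mnist",WorkerName="0"} 1
```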

Code Reference

Source Location

File | Lines | Repository
examples/custom_metrics/mnist_handler.py | L10-131 | pytorch/serve

Signature

```python
from ts.metrics.dimension import Dimension
from ts.metrics.metric_type_enum import MetricTypes
from ts.torch_handler.image_classifier import ImageClassifier


class MNISTDigitClassifier(ImageClassifier):
    """
    MNIST digit classifier with custom TorchServe metrics.

    Extends ImageClassifier to add InferenceRequestCount,
    PreprocessCallCount, and InitializeCallCount custom metrics.
    """

    def __init__(self):
        super().__init__()
        self.metrics = None
        self.metric_dimensions = None

    def initialize(self, context):
        """
        Load model and register custom metrics.

        Calls parent initialize() for model loading, then creates
        three custom counter metrics with ModelName and WorkerName
        dimensions.

        Args:
            context: TorchServe context with metrics system access.
        """
        super().initialize(context)
        self.metrics = context.metrics

        # Dimensions (labels) shared by all three custom metrics; the worker
        # is labeled by its GPU id, or "N/A" when the key is absent
        model_name_dim = Dimension("ModelName", context.model_name)
        worker_name_dim = Dimension("WorkerName", context.system_properties.get("gpu_id", "N/A"))
        self.metric_dimensions = [model_name_dim, worker_name_dim]

        # Register the custom counters. add_counter() takes no metric_type
        # argument, so the generic add_metric() is used to pass MetricTypes.COUNTER.
        self.metrics.add_metric(
            "InferenceRequestCount",
            value=0,
            unit="count",
            dimensions=self.metric_dimensions,
            metric_type=MetricTypes.COUNTER,
        )
        self.metrics.add_metric(
            "PreprocessCallCount",
            value=0,
            unit="count",
            dimensions=self.metric_dimensions,
            metric_type=MetricTypes.COUNTER,
        )
        self.metrics.add_metric(
            "InitializeCallCount",
            value=1,
            unit="count",
            dimensions=self.metric_dimensions,
            metric_type=MetricTypes.COUNTER,
        )
        ...

    def preprocess(self, data):
        """
        Normalize input images and update PreprocessCallCount metric.

        Increments the PreprocessCallCount counter, then normalizes
        the input image tensors for MNIST classification.

        Args:
            data (list): List of request input dicts with image data.

        Returns:
            torch.Tensor: Normalized image tensor on self.device.
        """
        # Reuse the registration dimensions so this increments the same
        # series instead of creating a differently labeled one
        self.metrics.add_counter("PreprocessCallCount", 1, dimensions=self.metric_dimensions)
        # Normalize images for MNIST
        ...

    def postprocess(self, data):
        """
        Apply argmax for digit prediction and update InferenceRequestCount.

        Increments the InferenceRequestCount counter, then applies
        argmax to the model output to produce digit predictions.

        Args:
            data (torch.Tensor): Model output logits.

        Returns:
            list: List of predicted digit integers (0-9), one per request.
        """
        self.metrics.add_counter("InferenceRequestCount", 1, dimensions=self.metric_dimensions)
        return data.argmax(dim=1).tolist()
```

I/O Contract

Method | Input | Output | Notes
initialize(context) | TorchServe context | None (registers 3 custom metrics) | Creates InferenceRequestCount, PreprocessCallCount, InitializeCallCount
preprocess(data) | List of request dicts with image bytes | Normalized torch.Tensor | Increments PreprocessCallCount
postprocess(data) | Model output logits (torch.Tensor) | List of digit predictions (int 0-9) | Increments InferenceRequestCount, applies argmax

Metrics Output Format

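When TorchServe runs with metrics_mode=prometheus, custom metrics are exposed in Prometheus format through the metrics API on port 8082; in the default log mode they are written to the model metrics log instead. The sample below is illustrative; real output also contains HELP and TYPE comment lines and may carry additional labels: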

```text
# Example Prometheus output on http://localhost:8082/metrics
InferenceRequestCount{ModelName="mnist",WorkerName="0"} 42
PreprocessCallCount{ModelName="mnist",WorkerName="0"} 42
InitializeCallCount{ModelName="mnist",WorkerName="0"} 1
```
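A scrape of the metrics endpoint can be checked programmatically by parsing the sample lines. The minimal parser below is a sketch that handles only simple `name{labels} value` lines like the sample output: it skips `#` comment lines (HELP/TYPE) and does not support timestamps or escaped quotes and commas inside label values. `parse_samples` and `counter_value` are hypothetical helper names.

```python
import re

# Matches: metric_name{label="value",...} number   (the label block is optional)
SAMPLE_RE = re.compile(r'^([a-zA-Z_:][a-zA-Z0-9_:]*)(?:\{(.*)\})?\s+(\S+)$')
LABEL_RE = re.compile(r'([a-zA-Z_][a-zA-Z0-9_]*)="([^"]*)"')


def parse_samples(text):
    """Return a list of (name, labels_dict, value) tuples from Prometheus text output."""
    samples = []
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and HELP/TYPE comments
        match = SAMPLE_RE.match(line)
        if match is None:
            continue  # unsupported line shape (e.g. a trailing timestamp)
        name, raw_labels, value = match.groups()
        labels = dict(LABEL_RE.findall(raw_labels or ""))
        samples.append((name, labels, float(value)))
    return samples


def counter_value(text, name, **labels):
    """Sum the samples named `name` whose labels include all given key/values; 0.0 if none."""
    matching = [
        value
        for sample_name, sample_labels, value in parse_samples(text)
        if sample_name == name and all(sample_labels.get(k) == v for k, v in labels.items())
    ]
    return sum(matching, 0.0)


scrape = """# TYPE InferenceRequestCount counter
InferenceRequestCount{ModelName="mnist",WorkerName="0"} 42
PreprocessCallCount{ModelName="mnist",WorkerName="0"} 42
"""
print(counter_value(scrape, "InferenceRequestCount", ModelName="mnist"))  # 42.0
```

In Example 2, `metrics_response.text` could be passed to `counter_value` before and after a request to confirm that a counter advanced.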

Usage Examples

Example 1: Package and serve the custom metrics handler

```shell
# Package the handler with the MNIST model
torch-model-archiver --model-name mnist \
  --version 1.0 \
  --model-file examples/custom_metrics/mnist_model.py \
  --serialized-file mnist_cnn.pt \
  --handler examples/custom_metrics/mnist_handler.py \
  --export-path model_store

# Start TorchServe
torchserve --start --model-store model_store \
  --models mnist=mnist.mar
```
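The same steps can be scripted from Python. This is a sketch under assumptions: it reuses the paths from the commands above, `build_commands` and `write_config` are hypothetical helpers, and the config file sets `metrics_mode=prometheus` so that custom metrics are served on port 8082 rather than only logged (verify the setting against the metrics documentation for your TorchServe version). The commands are only built and printed here; each list could be passed to `subprocess.run(cmd, check=True)` to execute it.

```python
import shlex
from pathlib import Path


def build_commands(model_store="model_store", config_path="config.properties"):
    """Return the archiver and server command lines as argument lists (hypothetical helper)."""
    archive = [
        "torch-model-archiver",
        "--model-name", "mnist",
        "--version", "1.0",
        "--model-file", "examples/custom_metrics/mnist_model.py",
        "--serialized-file", "mnist_cnn.pt",
        "--handler", "examples/custom_metrics/mnist_handler.py",
        "--export-path", model_store,
    ]
    serve = [
        "torchserve", "--start",
        "--model-store", model_store,
        "--ts-config", config_path,  # point TorchServe at the config written below
        "--models", "mnist=mnist.mar",
    ]
    return archive, serve


def write_config(config_path="config.properties"):
    # Assumption: Prometheus mode is needed for custom metrics to appear on the metrics API
    Path(config_path).write_text("metrics_mode=prometheus\n")


if __name__ == "__main__":
    for cmd in build_commands():
        print(shlex.join(cmd))
```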

Example 2: Send inference and check metrics

```python
import requests

# Send an MNIST digit image for classification
with open("digit_3.png", "rb") as img:
    response = requests.post(
        "http://localhost:8080/predictions/mnist",
        data=img.read(),
        headers={"Content-Type": "application/octet-stream"},
    )
response.raise_for_status()
# postprocess() returns one prediction per request in the batch, so this
# request's response body holds a single digit, e.g. 3
print(response.json())

# Check custom metrics
metrics_response = requests.get("http://localhost:8082/metrics")
print(metrics_response.text)
# Contains: InferenceRequestCount{ModelName="mnist",...} 1
# Contains: PreprocessCallCount{ModelName="mnist",...} 1
```

Example 3: Defining custom Dimension and MetricTypes

```python
from ts.metrics.dimension import Dimension
from ts.metrics.metric_type_enum import MetricTypes

# Create dimensions to label metrics
model_dim = Dimension("ModelName", "my_model")
worker_dim = Dimension("WorkerName", "worker_0")
gpu_dim = Dimension("GPUId", "0")

# Available MetricTypes:
# MetricTypes.COUNTER   - monotonically increasing counter
# MetricTypes.GAUGE     - value that can go up and down
# MetricTypes.HISTOGRAM - distribution of values

# Register a custom gauge metric. This runs inside a handler method such as
# initialize(self, context), where context.metrics is available. add_counter()
# always creates a counter, so a gauge is registered through add_metric()
# with an explicit metric_type.
context.metrics.add_metric(
    "QueueDepth",
    value=0,
    unit="count",
    dimensions=[model_dim, worker_dim],
    metric_type=MetricTypes.GAUGE,
)
```
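The difference between the three MetricTypes can be illustrated with a toy model. This is a hypothetical in-memory sketch of the usual Prometheus-style semantics (a counter only increases, a gauge holds the latest value, a histogram sorts observations into cumulative buckets), not TorchServe's implementation; the class names and bucket boundaries are arbitrary choices for illustration.

```python
import bisect


class ToyCounter:
    """Monotonic counter: negative increments are rejected."""

    def __init__(self):
        self.value = 0.0

    def inc(self, amount=1.0):
        if amount < 0:
            raise ValueError("a counter can only go up")
        self.value += amount


class ToyGauge:
    """Gauge: holds whatever value was set last, e.g. a queue depth."""

    def __init__(self):
        self.value = 0.0

    def set(self, value):
        self.value = value


class ToyHistogram:
    """Histogram with upper bounds `bounds` plus a final +Inf bucket."""

    def __init__(self, bounds=(1.0, 5.0, 10.0)):
        self.bounds = list(bounds)
        self.counts = [0] * (len(self.bounds) + 1)  # last slot is the +Inf bucket
        self.total = 0.0

    def observe(self, value):
        # bisect_left places a value equal to a bound in that bound's bucket ("le" semantics)
        self.counts[bisect.bisect_left(self.bounds, value)] += 1
        self.total += value

    def cumulative(self):
        running, out = 0, []
        for count in self.counts:
            running += count
            out.append(running)
        return out


queue = ToyGauge()
queue.set(3)
queue.set(1)
latency = ToyHistogram()
for ms in (0.5, 3.0, 7.0, 12.0):
    latency.observe(ms)
print(queue.value)  # 1
print(latency.cumulative())  # [1, 2, 3, 4]
```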
