Implementation: TensorFlow Serving MNIST Client Inference
| Knowledge Sources | |
|---|---|
| Domains | Testing, Inference |
| Last Updated | 2026-02-13 17:00 GMT |
Overview
Concrete tool for validating a deployed MNIST model by sending concurrent gRPC inference requests and computing error rate, provided by the TensorFlow Serving example scripts.
Description
The do_inference() function in mnist_client.py implements a concurrent gRPC client that validates a served MNIST model. It uses PredictionServiceStub to send asynchronous Predict requests up to a configurable concurrency level, collects responses via done-callbacks, compares each prediction against its ground-truth label, and reports the overall classification error rate.
The client uses a _ResultCounter class for thread-safe tracking of active requests, completed requests, and errors, with a throttling mechanism that limits concurrent in-flight requests.
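The throttling pattern can be sketched as follows. This is a minimal, illustrative reconstruction: the real _ResultCounter in mnist_client.py follows the same shape (a threading.Condition guarding active/done/error counts), but may differ in detail.

```python
import threading

class ResultCounter:
    """Sketch of the _ResultCounter pattern: thread-safe counters plus
    a throttle that caps the number of in-flight requests."""

    def __init__(self, num_tests, concurrency):
        self._num_tests = num_tests
        self._concurrency = concurrency
        self._error = 0
        self._done = 0
        self._active = 0
        self._condition = threading.Condition()

    def inc_error(self):
        with self._condition:
            self._error += 1

    def inc_done(self):
        with self._condition:
            self._done += 1
            self._condition.notify()

    def dec_active(self):
        with self._condition:
            self._active -= 1
            self._condition.notify()

    def throttle(self):
        with self._condition:
            # Block while the in-flight count is at the concurrency cap.
            while self._active == self._concurrency:
                self._condition.wait()
            self._active += 1

    def get_error_rate(self):
        with self._condition:
            # Block until every request has completed.
            while self._done != self._num_tests:
                self._condition.wait()
            return self._error / float(self._num_tests)
```

Each RPC callback calls inc_done() and dec_active(), which wakes both the main loop waiting in throttle() and the final get_error_rate() wait.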
Usage
Use this after starting a TensorFlow Serving instance with an exported MNIST model. The client connects via gRPC and validates the serving pipeline end-to-end.
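For reference, a serving instance can be started with the standard model server binary. The model name and export path below are assumptions for illustration; substitute your own SavedModel export directory.

```shell
# Start TensorFlow Serving with an exported MNIST model
# (--model_base_path must point at your own export directory)
tensorflow_model_server \
  --port=8500 \
  --model_name=mnist \
  --model_base_path=/tmp/mnist_model
```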
Code Reference
Source Location
- Repository: tensorflow/serving
- File: tensorflow_serving/example/mnist_client.py
- Lines: L123-153 (do_inference), L51-120 (_ResultCounter, _create_rpc_callback)
Signature
def do_inference(
hostport: str, # Host:port address of PredictionService (e.g. "localhost:8500")
work_dir: str, # Full path of working directory for test data
concurrency: int, # Maximum number of concurrent requests
num_tests: int # Number of test images to use
) -> float: # Returns classification error rate
"""Tests PredictionService with concurrent requests."""
Import
import grpc
import numpy
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc
import mnist_input_data
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| hostport | str | Yes | Server address in "host:port" format |
| work_dir | str | Yes | Directory for MNIST test data |
| concurrency | int | Yes | Max concurrent in-flight requests |
| num_tests | int | Yes | Number of test images to evaluate |
Outputs
| Name | Type | Description |
|---|---|---|
| error_rate | float | Fraction of incorrect predictions (0.0 to 1.0) |
Usage Examples
Command Line Validation
# Validate with 100 test images
python tensorflow_serving/example/mnist_client.py \
--server=localhost:8500 \
--num_tests=100
# Validate with concurrent requests
python tensorflow_serving/example/mnist_client.py \
--server=localhost:8500 \
--num_tests=1000 \
--concurrency=10
Programmatic Usage
import grpc
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc
import tensorflow as tf
# 1. Create gRPC channel and stub
channel = grpc.insecure_channel('localhost:8500')
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
# 2. Build predict request
# (image_data here is assumed to be one flattened 28x28 MNIST image,
# i.e. 784 float values)
request = predict_pb2.PredictRequest()
request.model_spec.name = 'mnist'
request.model_spec.signature_name = 'predict_images'
request.inputs['images'].CopyFrom(
    tf.make_tensor_proto(image_data, shape=[1, 784])
)
# 3. Send request (synchronous)
response = stub.Predict(request, 5.0) # 5 second timeout
# 4. Parse response
import numpy as np
scores = np.array(response.outputs['scores'].float_val)
predicted_class = np.argmax(scores)
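The example above is synchronous; do_inference() instead fires each RPC asynchronously and scores it in a done-callback. A hedged sketch of that callback pattern follows; make_rpc_callback, on_error, and on_done are illustrative names standing in for the source's _create_rpc_callback and its _ResultCounter methods.

```python
import numpy as np

def make_rpc_callback(label, on_error, on_done):
    """Sketch of the _create_rpc_callback pattern (illustrative names):
    builds a closure to pass to result_future.add_done_callback()."""
    def _callback(result_future):
        if result_future.exception() is not None:
            # The RPC itself failed.
            on_error()
        else:
            scores = np.array(
                result_future.result().outputs['scores'].float_val)
            if int(np.argmax(scores)) != label:
                # A misclassification also counts as an error.
                on_error()
        on_done()
    return _callback
```

In the client, each request is issued with result_future = stub.Predict.future(request, ...) and the callback registered via result_future.add_done_callback(...), which is what lets it keep up to concurrency requests in flight at once.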