Overview
Torchserve_Grpc_Client is a Python gRPC client for TorchServe. It provides functions and classes for calling TorchServe's gRPC inference and management endpoints, and it is built on the generated Protocol Buffer modules (inference_pb2, management_pb2). It supports unary inference, server-side streaming, and bidirectional streaming (InferStream2), and serves as both a testing utility and a reference implementation for gRPC client integration.
Description
The torchserve_grpc_client.py script (371 lines) implements a gRPC client for TorchServe's Protocol Buffer-based API. It provides simple functions for common operations (register, unregister, infer) and classes for streaming inference, which makes it useful for testing and scripting and as a reference when building production gRPC clients.
Key Responsibilities
- Stub Creation: Creates gRPC stubs for both the Inference and Management services from the generated Protocol Buffer code
- Unary Inference: Sends a single prediction request and receives a single response
- Server-Side Streaming: Sends a single request and receives a stream of responses via infer_stream()
- Bidirectional Streaming: Supports full-duplex streaming inference via the InferStream2 class and RequestIterator
- Model Management: Provides register() and unregister() functions for model lifecycle management over gRPC
Architecture
The module is organized into three layers:
- Stub Layer: get_inference_stub() and get_management_stub() create typed gRPC channel stubs
- Function Layer: Standalone functions (infer, infer_stream, register, unregister) for simple operations
- Class Layer: InferStream2, RequestIterator, and InferStream2SimpleClient for advanced bidirectional streaming
Code Reference
Source Location
| File | Lines | Repository |
| --- | --- | --- |
| ts_scripts/torchserve_grpc_client.py | L1-371 | pytorch/serve |
Stub Creation Functions
# Imports required by the two stub helpers below
import grpc
import inference_pb2_grpc
import management_pb2_grpc

def get_inference_stub(host="localhost", port=7070):
    """
    Create a gRPC stub for the TorchServe Inference API.
    Args:
        host (str): TorchServe gRPC host. Default: "localhost".
        port (int): TorchServe gRPC port. Default: 7070.
    Returns:
        InferenceAPIsServiceStub: gRPC stub for inference operations.
    """
    channel = grpc.insecure_channel(f"{host}:{port}")
    return inference_pb2_grpc.InferenceAPIsServiceStub(channel)

def get_management_stub(host="localhost", port=7071):
    """
    Create a gRPC stub for the TorchServe Management API.
    Args:
        host (str): TorchServe gRPC host. Default: "localhost".
        port (int): TorchServe gRPC management port. Default: 7071.
    Returns:
        ManagementAPIsServiceStub: gRPC stub for management operations.
    """
    channel = grpc.insecure_channel(f"{host}:{port}")
    return management_pb2_grpc.ManagementAPIsServiceStub(channel)
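Both stub helpers return as soon as the channel object is created; grpc.insecure_channel connects lazily, so an unreachable server only surfaces on the first RPC. The following is a minimal sketch, assuming the grpcio package is installed, of waiting for the channel to become ready before building a stub. The names build_target and wait_for_channel are hypothetical helpers and are not part of torchserve_grpc_client.py.

```python
def build_target(host="localhost", port=7070):
    """Format the "host:port" target string passed to grpc.insecure_channel."""
    return f"{host}:{port}"


def wait_for_channel(host="localhost", port=7070, timeout_s=5.0):
    """Open an insecure channel and block until it is ready.

    Hypothetical helper, not part of torchserve_grpc_client.py.
    Raises grpc.FutureTimeoutError if the server is not reachable in time.
    """
    import grpc  # imported lazily so build_target() works without grpcio

    channel = grpc.insecure_channel(build_target(host, port))
    # channel_ready_future returns a future that completes once the
    # channel reaches the READY connectivity state.
    grpc.channel_ready_future(channel).result(timeout=timeout_s)
    return channel
```

The returned channel could then be passed to inference_pb2_grpc.InferenceAPIsServiceStub in the same way the functions above do.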
Unary and Streaming Inference Functions
def infer(stub, model_name, input_data):
    """
    Send a unary inference request.
    Args:
        stub: InferenceAPIsServiceStub from get_inference_stub().
        model_name (str): Name of the registered model.
        input_data (bytes): Serialized input data for prediction.
    Returns:
        PredictionResponse: gRPC response containing prediction output.
    """
    ...

def infer_stream(stub, model_name, input_data):
    """
    Send an inference request and receive a stream of responses.
    Used for models that produce output incrementally (e.g., text generation).
    Args:
        stub: InferenceAPIsServiceStub from get_inference_stub().
        model_name (str): Name of the registered model.
        input_data (bytes): Serialized input data for prediction.
    Yields:
        PredictionResponse: Streamed gRPC responses.
    """
    ...
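When infer_stream() is used for text generation, each streamed response carries a chunk of bytes, and a multi-byte UTF-8 character can be split across two chunks. The sketch below shows one way to decode such a stream safely with an incremental decoder. FakeResponse is a hypothetical stand-in that models only a prediction bytes field; iter_text and collect_text are illustrative names, not functions from the script.

```python
import codecs
from dataclasses import dataclass


@dataclass
class FakeResponse:
    """Hypothetical stand-in for a streamed response; only `prediction` is modeled."""
    prediction: bytes


def iter_text(responses, encoding="utf-8"):
    """Yield decoded text for each chunk, buffering incomplete characters."""
    decoder = codecs.getincrementaldecoder(encoding)()
    for response in responses:
        text = decoder.decode(response.prediction)
        if text:
            yield text
    # Flush the decoder; raises UnicodeDecodeError if bytes are left dangling.
    tail = decoder.decode(b"", final=True)
    if tail:
        yield tail


def collect_text(responses, encoding="utf-8"):
    """Join all decoded chunks into a single string."""
    return "".join(iter_text(responses, encoding))


chunks = [FakeResponse(b"Once "), FakeResponse(b""), FakeResponse(b"upon a time")]
full_text = collect_text(chunks)
```

With a real stream, the iterable passed to iter_text would be the generator returned by infer_stream().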
Bidirectional Streaming Classes
class InferStream2:
    """
    Bidirectional streaming inference client. Lines 135-214.
    Manages a full-duplex gRPC stream where the client can send
    multiple input chunks and receive multiple output chunks
    concurrently. Useful for real-time inference scenarios.
    Attributes:
        stub: InferenceAPIsServiceStub for the gRPC channel.
        model_name (str): Target model name.
        responses: Iterator of streaming responses.
    """
    def __init__(self, stub, model_name):
        ...

    def send(self, input_data):
        """Send an input chunk to the stream."""
        ...

    def recv(self):
        """Receive the next output chunk from the stream."""
        ...

    def close(self):
        """Close the bidirectional stream."""
        ...

class RequestIterator:
    """
    Iterator that yields gRPC request messages for streaming. Lines 216-234.
    Used internally by InferStream2 to feed requests into the
    bidirectional stream as they become available.
    """
    def __init__(self):
        self._queue = queue.Queue()
        ...

    def add(self, request):
        """Add a request to the iterator queue."""
        ...

    def close(self):
        """Signal end of requests."""
        ...

    def __iter__(self):
        return self

    def __next__(self):
        ...

class InferStream2SimpleClient:
    """
    Simplified bidirectional streaming client. Lines 237-278.
    Wraps InferStream2 with a simpler interface for common
    bidirectional streaming use cases.
    """
    def __init__(self, host="localhost", port=7070, model_name=None):
        ...

    def infer_stream2(self, input_data_list):
        """
        Send multiple inputs and collect all streamed responses.
        Args:
            input_data_list (list[bytes]): List of input data chunks.
        Returns:
            list: Collected response objects.
        """
        ...
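The RequestIterator bodies are elided above, so the following is only a sketch of how a queue-backed iterator of this shape is commonly written, not the script's actual implementation. It assumes a sentinel object marks the end of the stream; the class name QueueRequestIterator and the sentinel are hypothetical.

```python
import queue


class QueueRequestIterator:
    """Hypothetical sketch of a queue-backed request iterator."""

    _SENTINEL = object()  # marks end-of-stream; never yielded to the caller

    def __init__(self):
        self._queue = queue.Queue()

    def add(self, request):
        """Enqueue a request; safe to call from another thread."""
        self._queue.put(request)

    def close(self):
        """Signal that no more requests will be added."""
        self._queue.put(self._SENTINEL)

    def __iter__(self):
        return self

    def __next__(self):
        item = self._queue.get()  # blocks until a request or the sentinel arrives
        if item is self._SENTINEL:
            # Re-queue the sentinel so later next() calls also stop
            # instead of blocking forever.
            self._queue.put(self._SENTINEL)
            raise StopIteration
        return item


requests = QueueRequestIterator()
requests.add(b"chunk1")
requests.add(b"chunk2")
requests.close()
drained = list(requests)
```

Because queue.Queue is thread-safe, an iterator like this can be handed to a streaming call while another thread keeps calling add().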
Model Management Functions
def register(stub, model_name, mar_url):
    """
    Register a model via the gRPC Management API.
    Args:
        stub: ManagementAPIsServiceStub from get_management_stub().
        model_name (str): Name for the model.
        mar_url (str): URL or path to the .mar archive.
    Returns:
        ManagementResponse: gRPC response confirming registration.
    """
    ...

def unregister(stub, model_name):
    """
    Unregister a model via the gRPC Management API.
    Args:
        stub: ManagementAPIsServiceStub from get_management_stub().
        model_name (str): Name of the model to unregister.
        Returns:
        ManagementResponse: gRPC response confirming unregistration.
    """
    ...
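In test scripts it is often useful to guarantee that a model registered for a test is unregistered again, even when the test fails. A minimal sketch of that pattern follows, assuming register and unregister have the signatures documented above. The registered_model context manager is a hypothetical helper, and the fake_* functions are stand-ins so the example runs without a server.

```python
from contextlib import contextmanager


@contextmanager
def registered_model(stub, model_name, mar_url, register_fn, unregister_fn):
    """Register a model on entry and always unregister it on exit.

    Hypothetical helper; register_fn and unregister_fn are expected to
    match register(stub, model_name, mar_url) and unregister(stub, model_name).
    """
    register_fn(stub, model_name, mar_url)
    try:
        yield model_name
    finally:
        unregister_fn(stub, model_name)


# Stand-ins that record calls instead of contacting TorchServe.
calls = []


def fake_register(stub, model_name, mar_url):
    calls.append(("register", model_name, mar_url))


def fake_unregister(stub, model_name):
    calls.append(("unregister", model_name))


with registered_model(
    None, "squeezenet", "squeezenet1_1.mar", fake_register, fake_unregister
) as name:
    calls.append(("use", name))
```

Passing the two functions in as arguments keeps the helper independent of the module and makes it easy to exercise with fakes.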
Import
# Run as a standalone script:
# python ts_scripts/torchserve_grpc_client.py [OPTIONS]

# When imported:
from ts_scripts.torchserve_grpc_client import (
    get_inference_stub,
    get_management_stub,
    infer,
    infer_stream,
    register,
    unregister,
    InferStream2,
    InferStream2SimpleClient,
)
I/O Contract
| Function/Class | Input | Output | Notes |
| --- | --- | --- | --- |
| get_inference_stub() | host, port | InferenceAPIsServiceStub | Creates insecure gRPC channel on port 7070 |
| get_management_stub() | host, port | ManagementAPIsServiceStub | Creates insecure gRPC channel on port 7071 |
| infer() | stub, model_name, input_data (bytes) | PredictionResponse | Unary request/response |
| infer_stream() | stub, model_name, input_data (bytes) | Iterator of PredictionResponse | Server-side streaming |
| InferStream2 | stub, model_name | Bidirectional stream | Full-duplex send/recv |
| RequestIterator | None | Iterable of requests | Thread-safe queue-based iterator |
| InferStream2SimpleClient | host, port, model_name | Collected responses list | Simplified bidirectional client |
| register() | stub, model_name, mar_url | ManagementResponse | Model registration via gRPC |
| unregister() | stub, model_name | ManagementResponse | Model unregistration via gRPC |
Protocol Buffer Dependencies
| Proto Module | Description |
| --- | --- |
| inference_pb2 | Generated Python code from inference service proto definition |
| inference_pb2_grpc | Generated gRPC stubs for inference service |
| management_pb2 | Generated Python code from management service proto definition |
| management_pb2_grpc | Generated gRPC stubs for management service |
Usage Examples
Example 1: Simple unary inference
from ts_scripts.torchserve_grpc_client import get_inference_stub, infer

# Create inference stub
stub = get_inference_stub(host="localhost", port=7070)

# Read input data
with open("kitten.jpg", "rb") as f:
    input_data = f.read()

# Run inference
response = infer(stub, "resnet18", input_data)
print(response.prediction)
Example 2: Streaming inference for text generation
import json

from ts_scripts.torchserve_grpc_client import get_inference_stub, infer_stream

stub = get_inference_stub()
prompt = json.dumps({"prompt": "Once upon a time"}).encode("utf-8")

# Receive streamed tokens
for response in infer_stream(stub, "llama2", prompt):
    print(response.prediction.decode("utf-8"), end="", flush=True)
Example 3: Model registration and unregistration via gRPC
from ts_scripts.torchserve_grpc_client import (
    get_management_stub,
    register,
    unregister,
)

mgmt_stub = get_management_stub(host="localhost", port=7071)

# Register a model
reg_response = register(mgmt_stub, "squeezenet", "squeezenet1_1.mar")
print(f"Registration: {reg_response.msg}")

# Unregister the model
unreg_response = unregister(mgmt_stub, "squeezenet")
print(f"Unregistration: {unreg_response.msg}")
Example 4: Bidirectional streaming with InferStream2
from ts_scripts.torchserve_grpc_client import InferStream2SimpleClient

# Create a simplified bidirectional streaming client
client = InferStream2SimpleClient(
    host="localhost",
    port=7070,
    model_name="streaming_model",
)

# Send multiple input chunks and collect responses
input_chunks = [b"chunk1", b"chunk2", b"chunk3"]
responses = client.infer_stream2(input_chunks)
for resp in responses:
    print(resp.prediction)
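Example 4 sends a fixed list of chunks, but a full-duplex stream also allows one thread to keep producing inputs while another consumes outputs. The sketch below illustrates that threading pattern without a server: fake_bidi_call is a hypothetical stand-in for a bidirectional RPC that echoes each request in upper case, and request_stream, run_duplex, and the sentinel are illustrative names rather than parts of the script.

```python
import queue
import threading

_DONE = object()  # sentinel that ends the request stream


def request_stream(q):
    """Yield queued requests until the sentinel arrives."""
    while True:
        item = q.get()  # blocks until the producer supplies something
        if item is _DONE:
            return
        yield item


def fake_bidi_call(request_iterator):
    """Stand-in for a bidirectional RPC: echoes each request in upper case."""
    for request in request_iterator:
        yield request.upper()


def run_duplex(chunks):
    """Send chunks from a background thread while reading responses here."""
    q = queue.Queue()

    def producer():
        for chunk in chunks:
            q.put(chunk)
        q.put(_DONE)

    sender = threading.Thread(target=producer)
    sender.start()
    # Responses are consumed as soon as each request has been processed.
    responses = list(fake_bidi_call(request_stream(q)))
    sender.join()
    return responses


results = run_duplex([b"chunk1", b"chunk2", b"chunk3"])
```

With a real stub, the generator returned by request_stream would be passed to the bidirectional streaming method in place of fake_bidi_call.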
Related Pages