Overview
Service.predict and TsModelLoader.load are the core Python-side components of the TorchServe inference pipeline. Service.predict(batch) decodes an incoming request batch, invokes the handler's handle() method, validates the output, records metrics, and encodes the result as a binary-protocol response. TsModelLoader.load() constructs the Service instance: it imports the handler module, resolves the entry point, creates the service wrapper, and initializes the model.
Description
Service Class
The Service class is a wrapper around the handler entry point. It manages the Context object (which holds model metadata, system properties, metrics, and YAML configuration) and provides the predict() method that the backend worker calls for each request batch.
Key responsibilities:
- Request Decoding: The static method retrieve_data_for_inference(batch) parses the binary protocol batch into request IDs, headers, and input data.
- Context Management: Sets request IDs, request processors, metrics, and the client socket on the Context object before each prediction.
- Handler Invocation: Calls the handler's handle() method (stored as self._entry_point).
- Output Validation: Ensures the handler returns a list and that its length matches the input batch size.
- Error Handling: Catches MemoryError, CUDA OOM, PredictionException, and general exceptions, returning appropriate HTTP error codes.
- Metrics Recording: Records PredictionTime in milliseconds for each batch.
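The responsibilities above can be sketched as a simplified control flow. This is a minimal, self-contained approximation, not the real TorchServe code: all names here (sketch_predict, the "data" key) are illustrative stand-ins, and the actual method works with the OTF binary protocol and a Context object.

```python
# Illustrative sketch of the predict() control flow described above.
# Names and the batch shape are stand-ins, not the TorchServe API.

def sketch_predict(entry_point, batch):
    """Decode a batch, invoke the handler, validate output, map errors."""
    input_batch = [req["data"] for req in batch]   # request decoding
    try:
        ret = entry_point(input_batch)             # handler invocation
    except MemoryError:
        return 507, "Out of resources"             # OOM -> 507
    except Exception:
        return 503, "Prediction failed"            # generic -> 503
    if not isinstance(ret, list):                  # output validation
        return 503, "Invalid model predict output"
    if len(ret) != len(batch):
        return 503, "number of batch response mismatched"
    return 200, ret                                # success

# Example: a handler that upper-cases each input in the batch
code, result = sketch_predict(lambda xs: [x.upper() for x in xs],
                              [{"data": "a"}, {"data": "b"}])
```

Note how the output-validation branches correspond one-to-one with the 503 rows in the Error Response Codes table below.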
TsModelLoader Class
The TsModelLoader is responsible for constructing a Service instance from a model directory:
- Manifest Loading: Reads MAR-INF/MANIFEST.json to discover model metadata.
- Handler Resolution: Attempts to import the handler as a custom module (_load_handler_file); falls back to a built-in handler (_load_default_handler).
- Entry Point Discovery: Checks for a function-based entry point first, then falls back to a class-based entry point with handle() and initialize() methods.
- Envelope Wrapping: If an envelope is specified, wraps the handler's handle() method with the envelope's handle() method for request format adaptation.
- Service Creation: Instantiates Service with the model name, directory, manifest, entry point, GPU ID, batch size, and metrics cache.
- Initialization: Calls initialize_fn(service.context) to trigger model loading in the handler.
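The entry-point discovery step can be approximated with a small self-contained sketch: prefer a module-level handle() function, otherwise fall back to a class exposing handle() and initialize(). The function name resolve_entry_point and the exact checks are assumptions for illustration, not the actual TsModelLoader internals.

```python
# Illustrative sketch of entry-point discovery: function first, then class.
import types

def resolve_entry_point(module):
    # 1. Function-based entry point: a module-level handle() function
    fn = getattr(module, "handle", None)
    if isinstance(fn, types.FunctionType):
        return fn, None
    # 2. Class-based entry point: a class with handle() and initialize()
    for obj in vars(module).values():
        if isinstance(obj, type) and hasattr(obj, "handle") and hasattr(obj, "initialize"):
            instance = obj()
            return instance.handle, instance.initialize
    raise ValueError("No valid entry point found in handler module")

# Example: a class-based handler placed in a throwaway module
mod = types.ModuleType("my_handler")

class MyHandler:
    def initialize(self, context=None):
        self.ready = True          # stands in for model loading
    def handle(self, data, context=None):
        return data                # echo handler

mod.MyHandler = MyHandler
entry, init = resolve_entry_point(mod)
```

With a class-based handler, init is the bound initialize() that load() would invoke as initialize_fn(service.context); with a function-based handler it is None.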
Usage
from ts.service import Service
from ts.model_loader import TsModelLoader
Code Reference
Source Location
| File | Lines | Description | Repository |
|------|-------|-------------|------------|
| ts/service.py | L113-182 | Service.predict() method | pytorch/serve |
| ts/service.py | L18-54 | Service.__init__() | pytorch/serve |
| ts/service.py | L60-108 | Service.retrieve_data_for_inference() | pytorch/serve |
| ts/model_loader.py | L68-145 | TsModelLoader.load() | pytorch/serve |
Signature
Service
class Service(object):
"""Wrapper for custom entry_point."""
def __init__(
self,
model_name: str,
model_dir: str,
manifest: dict,
entry_point: callable,
gpu: int,
batch_size: int,
limit_max_image_pixels: bool = True,
metrics_cache=None,
):
"""
Initialize the Service wrapper.
Reads model YAML config from the manifest (if present),
creates a Context object, and stores the handler entry point.
Args:
model_name (str): Name of the model.
model_dir (str): Path to the model directory.
manifest (dict): Parsed MANIFEST.json contents.
entry_point (callable): The handler's handle() function.
gpu (int): GPU device ID (None for CPU).
batch_size (int): Batch size for inference.
limit_max_image_pixels (bool): Limit max image pixels for PIL.
metrics_cache: MetricsCacheYamlImpl instance.
"""
...
def predict(self, batch: list) -> bytearray:
"""
Execute inference on a batch of requests.
Args:
batch (list): List of request dicts from the binary protocol.
Each dict contains:
- "requestId" (bytes): Unique request identifier
- "parameters" (list): List of parameter dicts with
"name", "contentType", and "value" fields
- "headers" (list, optional): Request-level headers
Returns:
bytearray: Binary protocol response created by
create_predict_response().
"""
...
@staticmethod
def retrieve_data_for_inference(batch: list) -> tuple:
"""
Parse a raw batch into headers, input data, and request ID mapping.
Args:
batch (list): Raw request batch from binary protocol.
Returns:
tuple: (headers, input_batch, req_to_id_map)
- headers (list[RequestProcessor]): Per-request headers
- input_batch (list[dict]): Per-request input data
- req_to_id_map (dict[int, str]): Batch index to request ID
Raises:
ValueError: If batch is None.
"""
...
TsModelLoader
class TsModelLoader(ModelLoader):
"""TorchServe 1.0 Model Loader."""
def load(
self,
model_name: str,
model_dir: str,
handler: Optional[str] = None,
gpu_id: Optional[int] = None,
batch_size: Optional[int] = None,
envelope: Optional[str] = None,
limit_max_image_pixels: Optional[bool] = True,
metrics_cache: Optional[MetricsCacheYamlImpl] = None,
) -> Service:
"""
Load a model and return a configured Service instance.
Args:
model_name (str): Name of the model.
model_dir (str): Path to the extracted model directory.
handler (str): Handler module path or built-in name.
gpu_id (int): GPU device ID (None for CPU).
batch_size (int): Inference batch size.
envelope (str): Request envelope name (e.g., "json", "body").
limit_max_image_pixels (bool): Limit max image pixels.
metrics_cache: MetricsCacheYamlImpl instance.
Returns:
Service: Configured and initialized Service instance.
Raises:
ValueError: If handler module cannot be loaded or has
invalid entry point structure.
"""
...
Import
from ts.service import Service
from ts.model_loader import TsModelLoader
I/O Contract
Service.__init__()
| Parameter | Type | Required | Default | Description |
|-----------|------|----------|---------|-------------|
| model_name | str | Yes | -- | Model identifier |
| model_dir | str | Yes | -- | Path to extracted model artifacts |
| manifest | dict | Yes | -- | Parsed MANIFEST.json |
| entry_point | callable | Yes | -- | Handler's handle(data, context) function |
| gpu | int | Yes | -- | GPU ID or None for CPU |
| batch_size | int | Yes | -- | Batch size |
| limit_max_image_pixels | bool | No | True | PIL max image pixel limit |
| metrics_cache | MetricsCacheYamlImpl | No | None | Metrics cache object |
Service.predict()
| Parameter | Type | Description |
|-----------|------|-------------|
| batch | list | List of request dicts from the binary protocol |

| Return | Type | Description |
|--------|------|-------------|
| Response | bytearray | Binary protocol response with predictions and status codes |
TsModelLoader.load()
| Parameter | Type | Required | Default | Description |
|-----------|------|----------|---------|-------------|
| model_name | str | Yes | -- | Model name |
| model_dir | str | Yes | -- | Model directory path |
| handler | Optional[str] | No | None | Handler module path |
| gpu_id | Optional[int] | No | None | GPU device ID |
| batch_size | Optional[int] | No | None | Batch size |
| envelope | Optional[str] | No | None | Envelope class name |
| limit_max_image_pixels | Optional[bool] | No | True | PIL limit |
| metrics_cache | Optional[MetricsCacheYamlImpl] | No | None | Metrics cache |

| Return | Type | Description |
|--------|------|-------------|
| service | Service | Fully initialized and ready-to-serve Service instance |
Error Response Codes
| Exception | HTTP Code | Message |
|-----------|-----------|---------|
| MemoryError | 507 | "Out of resources" |
| CUDA error / CUDA out of memory | 507 | "Out of resources" |
| PredictionException(message, code) | Custom (code) | Custom (message) |
| General Exception | 503 | "Prediction failed" |
| Non-list return type | 503 | "Invalid model predict output" |
| Batch size mismatch | 503 | "number of batch response mismatched" |
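A custom handler can surface its own status code by raising PredictionException. So the example runs without TorchServe installed, a minimal stand-in class is defined here (the real one lives in ts.utils.util), and map_error is a hypothetical helper that condenses the table above; it is not the actual error-handling code in predict().

```python
# Minimal stand-in for ts.utils.util.PredictionException, defined locally
# so this sketch is self-contained.
class PredictionException(Exception):
    def __init__(self, message, error_code=500):
        self.message = message
        self.error_code = error_code
        super().__init__(message)

def map_error(exc):
    """Map a raised exception to (http_code, message) per the table above."""
    if isinstance(exc, MemoryError):
        return 507, "Out of resources"
    if isinstance(exc, PredictionException):
        return exc.error_code, exc.message   # handler-chosen code/message
    return 503, "Prediction failed"          # catch-all

# A handler rejecting oversized input with a custom 413 status
try:
    raise PredictionException("Payload too large", 413)
except Exception as e:
    code, msg = map_error(e)
```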
Usage Examples
Example 1: How predict() is called by the worker
# This is the internal flow - the backend worker calls predict() for each batch.
# The binary protocol delivers a batch like this:
batch = [
{
"requestId": b"req-001",
"parameters": [
{
"name": "body",
"contentType": "application/octet-stream",
"value": b"<raw image bytes>"
}
],
"headers": [
{"name": b"Content-Type", "value": b"image/jpeg"}
]
},
{
"requestId": b"req-002",
"parameters": [
{
"name": "body",
"contentType": "application/octet-stream",
"value": b"<raw image bytes>"
}
],
}
]
# The worker calls:
response = service.predict(batch)
# response is a bytearray in OTF binary protocol format
Example 2: Loading a model with TsModelLoader
from ts.model_loader import TsModelLoader
loader = TsModelLoader()
service = loader.load(
model_name="squeezenet",
model_dir="/tmp/models/squeezenet/",
handler="image_classifier",
gpu_id=0,
batch_size=8,
envelope=None,
limit_max_image_pixels=True,
metrics_cache=None,
)
# service is now ready to call predict()
# The handler has been initialized (model loaded, device configured)
print(f"Model loaded: {service.context.model_name}")
Example 3: retrieve_data_for_inference() output structure
# Given a batch from the binary protocol:
headers, input_batch, req_id_map = Service.retrieve_data_for_inference(batch)
# headers: [RequestProcessor({'body': {'content-type': 'image/jpeg'}}), ...]
# input_batch: [{'body': b'<image bytes>'}, {'body': b'<image bytes>'}]
# req_id_map: {0: 'req-001', 1: 'req-002'}
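The input_batch and req_id_map structures above can be reproduced with a small self-contained parser. This is an approximation of retrieve_data_for_inference (it omits the headers/RequestProcessor output); the field names follow the binary-protocol dicts shown in Example 1, and sketch_retrieve is an illustrative name.

```python
def sketch_retrieve(batch):
    """Split a raw batch into per-request input data and an index->id map."""
    if batch is None:
        raise ValueError("batch is None")
    input_batch, req_id_map = [], {}
    for idx, request in enumerate(batch):
        req_id_map[idx] = request["requestId"].decode("utf-8")
        # Collect each parameter's value under its name, e.g. {"body": ...}
        input_batch.append({p["name"]: p["value"] for p in request["parameters"]})
    return input_batch, req_id_map

inputs, ids = sketch_retrieve([
    {"requestId": b"req-001",
     "parameters": [{"name": "body",
                     "contentType": "application/octet-stream",
                     "value": b"<raw image bytes>"}]},
])
```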
Example 4: Handler entry point resolution in TsModelLoader
# The loader resolves handlers in this priority order:
# 1. Custom handler file with function entry point:
# handler="my_handler:custom_handle"
# -> imports my_handler module, uses custom_handle function
# 2. Custom handler file with class entry point:
# handler="my_handler.py"
# -> imports module, finds single class with handle() method
# 3. Built-in handler:
# handler="image_classifier"
# -> imports ts.torch_handler.image_classifier
Example 5: Service constructor YAML config loading
# When the manifest contains a configFile reference, Service.__init__
# automatically loads it:
manifest = {
"model": {
"modelName": "bert",
"serializedFile": "model.pt",
"handler": "handler.py",
"modelVersion": "1.0",
"configFile": "model_config.yaml" # <-- triggers YAML loading
}
}
# Service.__init__ calls:
# get_yaml_config(os.path.join(model_dir, "model_config.yaml"))
# and passes the result to Context as model_yaml_config
Related Pages