Overview
BaseHandler is the abstract base class for TorchServe inference handlers. It provides default model loading logic (supporting TorchScript, eager mode, ONNX, and AOT-compiled models), automatic device selection (CUDA, XPU, MPS, XLA, CPU), and the handle() entry point that orchestrates the preprocess -> inference -> postprocess pipeline for every inference request.
Description
The BaseHandler class is the foundation of TorchServe's handler system. It is an abstract base class (inheriting from abc.ABC) that provides default implementations for all four stages of the inference pipeline. Custom handlers subclass BaseHandler and override only the methods that need custom logic.
Key Responsibilities
- Model Loading: Detects the model format from the file extension and manifest, then loads accordingly:
  - `.pt` files are loaded via `torch.jit.load()` (TorchScript)
  - Eager-mode models are loaded via `importlib` + `torch.load()` for the state_dict
  - `.onnx` files are loaded via an ONNX Runtime session
  - `.so` files are loaded via `torch._export.aot_load()`
- Device Selection: Automatically selects the appropriate compute device based on hardware availability and configuration
- torch.compile: Optionally applies `torch.compile()` when a `pt2` configuration block is present in the model YAML config
- Label Mapping: Loads `index_to_name.json` for classifier output mapping
- Metrics: Records a `HandlerTime` metric for each request through the metrics system
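The label-mapping step can be illustrated with a small sketch (a hypothetical helper, not the exact TorchServe implementation, assuming `index_to_name.json` sits in the model directory):

```python
import json
from pathlib import Path

def load_label_mapping(model_dir):
    """Load index_to_name.json if present (sketch of BaseHandler's lookup)."""
    mapping_path = Path(model_dir) / "index_to_name.json"
    if mapping_path.is_file():
        with open(mapping_path) as f:
            return json.load(f)  # e.g. {"0": "cat", "1": "dog"}
    # No mapping file: postprocess falls back to raw class indices
    return None
```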
Usage
Subclass BaseHandler and override methods as needed:

```python
from ts.torch_handler.base_handler import BaseHandler
import torch

class MyHandler(BaseHandler):
    def preprocess(self, data):
        # Custom preprocessing
        inputs = []
        for row in data:
            val = row.get("body") or row.get("data")
            inputs.append(torch.tensor(val, dtype=torch.float32))
        return torch.stack(inputs).to(self.device)

    def postprocess(self, data):
        # Custom postprocessing
        return data.argmax(dim=1).tolist()
```
Code Reference
Source Location
| File | Lines | Repository |
| --- | --- | --- |
| ts/torch_handler/base_handler.py | L119-467 | pytorch/serve |
Signature
```python
class BaseHandler(abc.ABC):
    def __init__(self):
        self.model = None
        self.mapping = None
        self.device = None
        self.initialized = False
        self.context = None
        self.model_pt_path = None
        self.manifest = None
        self.map_location = None
        self.explain = False
        self.target = 0
        self.profiler_args = {}

    def initialize(self, context) -> None:
        """
        Load model weights and configure device.

        Args:
            context: Context object containing model artifact parameters
                (system_properties, manifest, model_yaml_config).

        Raises:
            RuntimeError: When no model weights could be loaded.
        """
        ...

    def preprocess(self, data) -> torch.Tensor:
        """
        Convert request input to a tensor.

        Args:
            data (list): List of the data from the request input.

        Returns:
            torch.Tensor: Tensor data of the input.
        """
        ...

    def inference(self, data, *args, **kwargs) -> torch.Tensor:
        """
        Execute the model forward pass under torch.inference_mode().

        Args:
            data (torch.Tensor): Input tensor matching the model input shape.

        Returns:
            torch.Tensor: Model prediction output.
        """
        ...

    def postprocess(self, data) -> list:
        """
        Convert the model output tensor to a serializable list.

        Args:
            data (torch.Tensor): Prediction output tensor.

        Returns:
            list: Predicted output as a Python list.
        """
        ...

    def handle(self, data, context) -> list:
        """
        Entry point for inference. Runs preprocess -> inference -> postprocess.

        Args:
            data (list): Input data for prediction.
            context (Context): Model artifacts and system information.

        Returns:
            list: List of prediction results.
        """
        ...
```
Import
```python
from ts.torch_handler.base_handler import BaseHandler
```
I/O Contract
| Method | Input | Output | Notes |
| --- | --- | --- | --- |
| `__init__()` | None | None | Initializes instance attributes to None / defaults |
| `initialize(context)` | `context`: Context object with system_properties, manifest, model_yaml_config | None (sets `self.model`, `self.device`, `self.initialized = True`) | Called once when the model worker starts |
| `preprocess(data)` | `data`: list of request input dicts | `torch.Tensor` on `self.device` | Default calls `torch.as_tensor(data, device=self.device)` |
| `inference(data)` | `data`: torch.Tensor | torch.Tensor | Runs under `torch.inference_mode()`, moves data to `self.device` |
| `postprocess(data)` | `data`: torch.Tensor | list | Default calls `data.tolist()` |
| `handle(data, context)` | `data`: list; `context`: Context | list of prediction results | Records `HandlerTime` metric in milliseconds |
Device Selection Logic
The initialize method selects devices in the following priority order:
| Priority | Condition | Device |
| --- | --- | --- |
| 1 | `torch.cuda.is_available()` and gpu_id set | `cuda:{gpu_id}` |
| 2 | `TS_IPEX_GPU_ENABLE=true` and `torch.xpu.is_available()` | `xpu:{gpu_id}` |
| 3 | `torch.backends.mps.is_available()` and gpu_id set | `mps` |
| 4 | `XLA_AVAILABLE` is True | `xm.xla_device()` |
| 5 | Fallback | `cpu` |
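The priority cascade in the table can be sketched torch-free; the boolean flags below stand in for the real hardware checks (hypothetical helper, not TorchServe's actual code):

```python
def select_device(cuda_available, xpu_enabled, mps_available,
                  xla_available, gpu_id):
    """Pick a device string; priority mirrors the table: CUDA > XPU > MPS > XLA > CPU."""
    if cuda_available and gpu_id is not None:
        return f"cuda:{gpu_id}"
    if xpu_enabled and gpu_id is not None:
        return f"xpu:{gpu_id}"
    if mps_available and gpu_id is not None:
        return "mps"
    if xla_available:
        return "xla"
    return "cpu"
```

Note that the GPU-backed branches only fire when a gpu_id has been assigned to the worker; otherwise the cascade falls through to CPU.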
Model Loading Logic
| Condition | Loader Method | Framework |
| --- | --- | --- |
| `model_file` present in manifest | `_load_pickled_model()` | Eager mode (state_dict) |
| File ends with `.pt` | `_load_torchscript_model()` | TorchScript via `torch.jit.load()` |
| File ends with `.onnx` and ONNX available | `setup_ort_session()` | ONNX Runtime |
| File ends with `.so` and AOT compile config | `_load_torch_export_aot_compile()` | `torch._export.aot_load()` |
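The dispatch in the table reduces to a check on the manifest plus the file extension; a simplified, hypothetical sketch (not the actual implementation):

```python
from pathlib import Path

def pick_loader(serialized_file, manifest, onnx_available=True,
                aot_compile_configured=False):
    """Return the loader name for a model artifact (sketch of the table's dispatch)."""
    # A model_file entry in the manifest means an eager-mode state_dict model
    if manifest.get("model", {}).get("modelFile"):
        return "_load_pickled_model"
    suffix = Path(serialized_file).suffix
    if suffix == ".pt":
        return "_load_torchscript_model"          # torch.jit.load()
    if suffix == ".onnx" and onnx_available:
        return "setup_ort_session"                # ONNX Runtime
    if suffix == ".so" and aot_compile_configured:
        return "_load_torch_export_aot_compile"   # torch._export.aot_load()
    raise RuntimeError(f"No loader for {serialized_file}")
```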
Usage Examples
Example 1: Default handler with TorchScript model
```yaml
# The simplest case: use BaseHandler directly with a TorchScript model.
# No custom handler needed - set handler="base_handler" in the MAR config.
# model_config.yaml
minWorkers: 1
maxWorkers: 4
batchSize: 8
pt2:
  compile:
    enable: true
    backend: inductor
```
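How the pt2 block drives the torch.compile decision can be sketched as follows (a hypothetical helper; the YAML is assumed to be already parsed into a dict):

```python
def compile_settings(model_yaml_config):
    """Return (enabled, backend) read from a parsed model YAML config (sketch)."""
    pt2 = (model_yaml_config or {}).get("pt2", {})
    compile_cfg = pt2.get("compile", {}) if isinstance(pt2, dict) else {}
    enabled = bool(compile_cfg.get("enable", False))
    # "inductor" is torch.compile's default backend
    backend = compile_cfg.get("backend", "inductor")
    return enabled, backend
```

When `enabled` is true, the handler would wrap the loaded model with `torch.compile(model, backend=backend)` before serving.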
Example 2: Custom handler with all stages overridden
```python
from ts.torch_handler.base_handler import BaseHandler
import torch

class NLPHandler(BaseHandler):
    def initialize(self, context):
        # Call parent to handle model loading and device selection
        super().initialize(context)
        # Load tokenizer from extra files
        model_dir = context.system_properties.get("model_dir")
        self.tokenizer = self._load_tokenizer(model_dir)

    def preprocess(self, data):
        texts = []
        for row in data:
            text = row.get("data") or row.get("body")
            if isinstance(text, (bytes, bytearray)):
                text = text.decode("utf-8")
            texts.append(text)
        tokens = self.tokenizer(
            texts, padding=True, truncation=True, return_tensors="pt"
        )
        return {k: v.to(self.device) for k, v in tokens.items()}

    def inference(self, data, *args, **kwargs):
        with torch.inference_mode():
            outputs = self.model(**data)
        return outputs.logits

    def postprocess(self, data):
        probs = torch.softmax(data, dim=1)
        predictions = torch.argmax(probs, dim=1)
        results = []
        for i, pred in enumerate(predictions):
            label = self.mapping.get(str(pred.item()), str(pred.item()))
            results.append({"label": label, "score": probs[i][pred].item()})
        return results
```
Example 3: handle() method execution flow
```python
import time

# The handle() method is called by TorchServe for each request batch.
# Internally it executes (simplified):
def handle(self, data, context):
    start_time = time.time()
    self.context = context
    metrics = self.context.metrics

    # Stage 1: Preprocess
    data_preprocess = self.preprocess(data)

    # Stage 2: Inference
    output = self.inference(data_preprocess)

    # Stage 3: Postprocess
    output = self.postprocess(output)

    # Record timing metric
    duration_ms = round((time.time() - start_time) * 1000, 2)
    metrics.add_time("HandlerTime", duration_ms, None, "ms")
    return output
```
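The same orchestration pattern can be exercised without TorchServe or even torch installed; this stand-in (hypothetical, with a string-length "model") chains the three stages exactly as handle() chains them:

```python
class MiniHandler:
    """Pure-Python stand-in for BaseHandler's preprocess -> inference -> postprocess flow."""

    def preprocess(self, data):
        # Mirror the request convention: each row carries "data" or "body"
        return [row.get("data") or row.get("body") for row in data]

    def inference(self, batch):
        # Stand-in "model": string length instead of a forward pass
        return [len(item) for item in batch]

    def postprocess(self, outputs):
        return outputs  # already a plain, serializable list

    def handle(self, data, context=None):
        return self.postprocess(self.inference(self.preprocess(data)))
```

This makes it easy to unit-test each stage in isolation before packaging the real handler into a MAR.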
Related Pages