Implementation:PyTorch Serve BaseHandler

Overview

BaseHandler is the base handler class in TorchServe that implements the inference handler pattern. It provides the default model loading logic (supporting TorchScript, eager mode, ONNX, and AOT-compiled models), automatic device selection (CUDA, XPU, MPS, XLA, CPU), and the handle() entry point that orchestrates the preprocess -> inference -> postprocess pipeline for every inference request.

| Field | Value |
|---|---|
| Implementation Name | BaseHandler |
| Type | API Doc |
| Workflow | Model_Deployment |
| Domains | Model_Serving, Design_Patterns |
| Knowledge Sources | TorchServe |
| Last Updated | 2026-02-13 00:00 GMT |

Description

The BaseHandler class is the foundation of TorchServe's handler system. It is an abstract base class (inheriting from abc.ABC) that provides default implementations for all four stages of the inference pipeline. Custom handlers subclass BaseHandler and override only the methods that need custom logic.

Key Responsibilities

  • Model Loading: Detects model format from the file extension and manifest, then loads accordingly:
    • .pt files are loaded via torch.jit.load() (TorchScript)
    • Eager mode models are loaded via importlib + torch.load() for state_dict
    • .onnx files are loaded via ONNX Runtime session
    • .so files are loaded via torch._export.aot_load()
  • Device Selection: Automatically selects the appropriate compute device based on hardware availability and configuration
  • torch.compile: Optionally applies torch.compile() when a pt2 configuration block is present in the model YAML config (see the sketch after this list)
  • Label Mapping: Loads index_to_name.json for classifier output mapping
  • Metrics: Records HandlerTime for each request through the metrics system
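
As one illustration of the pt2 hook, here is a hedged sketch of how a handler might gate torch.compile() on the YAML config. The key names follow the pt2 block shown in Example 1 below; the helper name maybe_compile is hypothetical, not TorchServe API:

import torch

def maybe_compile(model, model_yaml_config):
    # Apply torch.compile() only when the YAML config enables it,
    # mirroring the pt2.compile block shown in Example 1.
    pt2_cfg = (model_yaml_config or {}).get("pt2", {})
    compile_cfg = pt2_cfg.get("compile", {})
    if compile_cfg.get("enable", False):
        backend = compile_cfg.get("backend", "inductor")
        model = torch.compile(model, backend=backend)
    return model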

Usage

from ts.torch_handler.base_handler import BaseHandler

Subclass BaseHandler and override methods as needed:

from ts.torch_handler.base_handler import BaseHandler
import torch


class MyHandler(BaseHandler):
    def preprocess(self, data):
        # TorchServe delivers a batch; each row carries its payload
        # under "data" or "body" depending on how the request was sent.
        inputs = []
        for row in data:
            val = row.get("body") or row.get("data")
            inputs.append(torch.tensor(val, dtype=torch.float32))
        return torch.stack(inputs).to(self.device)

    def postprocess(self, data):
        # Return a JSON-serializable list, one entry per batched request.
        return data.argmax(dim=1).tolist()

Code Reference

Source Location

| File | Lines | Repository |
|---|---|---|
| ts/torch_handler/base_handler.py | L119-467 | pytorch/serve |

Signature

class BaseHandler(abc.ABC):

    def __init__(self):
        self.model = None
        self.mapping = None
        self.device = None
        self.initialized = False
        self.context = None
        self.model_pt_path = None
        self.manifest = None
        self.map_location = None
        self.explain = False
        self.target = 0
        self.profiler_args = {}

    def initialize(self, context) -> None:
        """
        Load model weights and configure device.

        Args:
            context: JSON object containing the model artifact parameters
                     (system_properties, manifest, model_yaml_config).

        Raises:
            RuntimeError: When no model weights could be loaded.
        """
        ...

    def preprocess(self, data) -> torch.Tensor:
        """
        Convert request input to a tensor.

        Args:
            data (list): List of the data from the request input.

        Returns:
            torch.Tensor: Tensor data of the input.
        """
        ...

    def inference(self, data, *args, **kwargs) -> torch.Tensor:
        """
        Execute model forward pass under torch.inference_mode().

        Args:
            data (torch.Tensor): Input tensor matching model input shape.

        Returns:
            torch.Tensor: Model prediction output.
        """
        ...

    def postprocess(self, data) -> list:
        """
        Convert model output tensor to a serializable list.

        Args:
            data (torch.Tensor): Prediction output tensor.

        Returns:
            list: Predicted output as a Python list.
        """
        ...

    def handle(self, data, context) -> list:
        """
        Entry point for inference. Runs preprocess -> inference -> postprocess.

        Args:
            data (list): Input data for prediction.
            context (Context): Model artifacts and system information.

        Returns:
            list: List of prediction results.
        """
        ...

Import

from ts.torch_handler.base_handler import BaseHandler

I/O Contract

| Method | Input | Output | Notes |
|---|---|---|---|
| __init__() | None | None | Initializes instance attributes to None / defaults |
| initialize(context) | context: Context object with system_properties, manifest, model_yaml_config | None (sets self.model, self.device, self.initialized = True) | Called once when the model worker starts |
| preprocess(data) | data: list of request input dicts | torch.Tensor on self.device | Default calls torch.as_tensor(data, device=self.device) |
| inference(data) | data: torch.Tensor | torch.Tensor | Runs under torch.inference_mode(), moves data to self.device |
| postprocess(data) | data: torch.Tensor | list | Default calls data.tolist() |
| handle(data, context) | data: list; context: Context | list of prediction results | Records HandlerTime metric in milliseconds |
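
A minimal sketch of the default contract outside TorchServe. The doubling step stands in for a real model forward pass; the tensorize/listify lines mirror the defaults named in the Notes column above:

import torch

device = torch.device("cpu")
data = [[0.1, 0.2], [0.3, 0.4]]            # batch of two requests
x = torch.as_tensor(data, device=device)   # what the default preprocess() does
y = x * 2                                  # stand-in for the model forward pass
print(y.tolist())                          # what the default postprocess() does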

Device Selection Logic

The initialize method selects devices in the following priority order:

| Priority | Condition | Device |
|---|---|---|
| 1 | torch.cuda.is_available() and gpu_id set | cuda:{gpu_id} |
| 2 | TS_IPEX_GPU_ENABLE=true and torch.xpu.is_available() | xpu:{gpu_id} |
| 3 | torch.backends.mps.is_available() and gpu_id set | mps |
| 4 | XLA_AVAILABLE is True | xm.xla_device() |
| 5 | Fallback | cpu |
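
Approximated in plain Python, a simplified sketch of the priority order above (not the exact TorchServe source; the XLA branch is omitted because it requires the torch_xla package):

import os
import torch

def select_device(gpu_id=None):
    # Priority order mirrors the table above (XLA omitted).
    if torch.cuda.is_available() and gpu_id is not None:
        return torch.device(f"cuda:{gpu_id}")
    if (os.environ.get("TS_IPEX_GPU_ENABLE", "false").lower() == "true"
            and hasattr(torch, "xpu") and torch.xpu.is_available()):
        return torch.device(f"xpu:{gpu_id or 0}")
    if torch.backends.mps.is_available() and gpu_id is not None:
        return torch.device("mps")
    return torch.device("cpu")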

Model Loading Logic

| Condition | Loader Method | Framework |
|---|---|---|
| model_file present in manifest | _load_pickled_model() | Eager mode (state_dict) |
| File ends with .pt | _load_torchscript_model() | TorchScript via torch.jit.load() |
| File ends with .onnx and ONNX available | setup_ort_session() | ONNX Runtime |
| File ends with .so and AOT compile config | _load_torch_export_aot_compile() | torch._export.aot_load() |
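
The same dispatch, condensed into a hedged sketch. pick_loader is a hypothetical helper; the real handler also consults the manifest and YAML config rather than the extension alone:

import os

def pick_loader(model_pt_path, has_model_file):
    # Mirrors the table above; returns the name of the loader to call.
    if has_model_file:
        return "_load_pickled_model"              # eager mode state_dict
    ext = os.path.splitext(model_pt_path)[1]
    if ext == ".pt":
        return "_load_torchscript_model"          # torch.jit.load()
    if ext == ".onnx":
        return "setup_ort_session"                # ONNX Runtime
    if ext == ".so":
        return "_load_torch_export_aot_compile"   # torch._export.aot_load()
    raise RuntimeError("No model weights could be loaded")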

Usage Examples

Example 1: Default handler with TorchScript model

The simplest case: use BaseHandler directly with a TorchScript model. No custom handler is needed; set handler="base_handler" in the MAR config and provide a model_config.yaml such as:

# model_config.yaml
minWorkers: 1
maxWorkers: 4
batchSize: 8
pt2:
  compile:
    enable: true
    backend: inductor
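
To produce the model.pt this example serves, a hedged sketch (the model definition is purely illustrative; any scriptable nn.Module works the same way):

import torch

model = torch.nn.Sequential(
    torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2)
)
model.eval()
scripted = torch.jit.script(model)   # TorchScript export
scripted.save("model.pt")            # archive with handler="base_handler"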

Example 2: Custom handler with all stages overridden

from ts.torch_handler.base_handler import BaseHandler
import torch


class NLPHandler(BaseHandler):
    def initialize(self, context):
        # Call parent to handle model loading and device selection
        super().initialize(context)
        # Load tokenizer from extra files (via a user-supplied helper;
        # _load_tokenizer is not part of BaseHandler)
        model_dir = context.system_properties.get("model_dir")
        self.tokenizer = self._load_tokenizer(model_dir)

    def preprocess(self, data):
        texts = []
        for row in data:
            text = row.get("data") or row.get("body")
            if isinstance(text, (bytes, bytearray)):
                text = text.decode("utf-8")
            texts.append(text)
        tokens = self.tokenizer(
            texts, padding=True, truncation=True, return_tensors="pt"
        )
        return {k: v.to(self.device) for k, v in tokens.items()}

    def inference(self, data, *args, **kwargs):
        with torch.inference_mode():
            outputs = self.model(**data)
        return outputs.logits

    def postprocess(self, data):
        probs = torch.softmax(data, dim=1)
        predictions = torch.argmax(probs, dim=1)
        results = []
        for i, pred in enumerate(predictions):
            label = self.mapping.get(str(pred.item()), str(pred.item()))
            results.append({"label": label, "score": probs[i][pred].item()})
        return results

Example 3: handle() method execution flow

# The handle() method is called by TorchServe for each request batch.
# In simplified form it executes the following (the real method also
# short-circuits for describe/explain requests and optional profiling):

import time


def handle(self, data, context):
    start_time = time.time()
    self.context = context
    metrics = self.context.metrics

    # Stage 1: Preprocess
    data_preprocess = self.preprocess(data)

    # Stage 2: Inference
    output = self.inference(data_preprocess)

    # Stage 3: Postprocess
    output = self.postprocess(output)

    # Record timing metric in milliseconds
    stop_time = time.time()
    metrics.add_time(
        "HandlerTime", round((stop_time - start_time) * 1000, 2), None, "ms"
    )
    return output
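
A hedged sketch of driving handle() outside TorchServe with stub objects. MockMetrics and MockContext are illustrative stand-ins, not TorchServe API; the stubbed request header keeps the describe/explain paths disabled:

import torch
from ts.torch_handler.base_handler import BaseHandler


class MockMetrics:
    def add_time(self, name, value, idx, unit):
        print(f"{name}: {value} {unit}")


class MockContext:
    metrics = MockMetrics()

    def get_request_header(self, idx, key):
        return None  # no describe/explain headers in this sketch


class IdentityHandler(BaseHandler):
    def initialize(self, context):
        # Skip file-based loading; wire up a trivial model directly.
        self.device = torch.device("cpu")
        self.model = torch.nn.Identity()
        self.initialized = True


handler = IdentityHandler()
handler.initialize(MockContext())
print(handler.handle([[1.0, 2.0, 3.0]], MockContext()))  # -> [[1.0, 2.0, 3.0]]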
