Implementation:Pytorch Serve BasePippyHandler

From Leeroopedia
Page Type: Implementation (API Doc)
Title: BasePippyHandler
Implements: Principle:Pytorch_Serve_Pipeline_Parallelism
Source: ts/torch_handler/distributed/base_pippy_handler.py, ts/handler_utils/distributed/pt_pippy.py, examples/large_models/Huggingface_pippy/pippy_handler.py
Repository: TorchServe
Last Updated: 2026-02-13 00:00 GMT

Overview

BasePippyHandler is the base handler class for PiPPy pipeline parallelism in TorchServe. It extends BaseHandler to initialize the distributed RPC workers required by PiPPy. The companion module pt_pippy provides two key utility functions: initialize_rpc_workers() for setting up the RPC communication layer, and get_pipeline_driver() for compiling a model into a multi-stage pipeline distributed across GPUs.

Description

The PiPPy integration consists of three components; a sketch of how they compose follows the list:

1. BasePippyHandler (ts/torch_handler/distributed/base_pippy_handler.py): A base handler that reads LOCAL_RANK and WORLD_SIZE from environment variables (set by torchrun), determines the local device, and initializes the RPC workers. Custom handlers inherit from this class and override initialize() to load their model and compile the pipeline.

2. RPC Worker Initialization (initialize_rpc_workers() in ts/handler_utils/distributed/pt_pippy.py): Configures TensorPipeRpcBackendOptions with timeout and thread count from the YAML config, sets up device mappings between all workers, and calls rpc.init_rpc() to establish the communication network.

3. Pipeline Driver (get_pipeline_driver() in ts/handler_utils/distributed/pt_pippy.py): Takes a loaded model, traces it using FX (with PiPPyHFTracer for HuggingFace models), splits it into equal-sized stages using split_into_equal_size(world_size), and compiles the distributed pipeline using pippy.all_compile(). For HuggingFace models, it injects the pipeline forward method back into the original model object.
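
Taken together, a custom handler wires these pieces up roughly as follows (a minimal sketch; load_model() is a hypothetical placeholder for whatever loading logic the handler uses):

from abc import ABC

from ts.handler_utils.distributed.pt_pippy import get_pipeline_driver
from ts.torch_handler.distributed.base_pippy_handler import BasePippyHandler

class MyPippyHandler(BasePippyHandler, ABC):
    def initialize(self, ctx):
        # Components 1 and 2: read LOCAL_RANK/WORLD_SIZE and start the RPC workers
        super().initialize(ctx)
        # Handler-specific model loading (load_model is hypothetical)
        model = load_model(ctx.model_yaml_config["handler"]["model_path"])
        # Component 3: trace, split, and compile the pipeline across ranks
        self.model = get_pipeline_driver(model, self.world_size, ctx)
        self.initialized = True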

Usage

Code Reference

Source Location:

  • ts/torch_handler/distributed/base_pippy_handler.py (lines 13-23)
  • ts/handler_utils/distributed/pt_pippy.py (lines 20-134)
  • examples/large_models/Huggingface_pippy/pippy_handler.py (lines 24-157)

Signature -- BasePippyHandler:

import os
from abc import ABC

import torch

from ts.handler_utils.distributed.pt_pippy import initialize_rpc_workers
from ts.torch_handler.base_handler import BaseHandler

class BasePippyHandler(BaseHandler, ABC):
    """
    Base default handler to set up RPC workers for PiPPy large model inference
    """

    def initialize(self, ctx):
        # Rank and world size are exported by torchrun before the worker starts
        self.local_rank = int(os.environ["LOCAL_RANK"])
        self.world_size = int(os.environ["WORLD_SIZE"])
        # One device per local rank, wrapping around if ranks exceed GPU count
        n_devs = torch.cuda.device_count()
        self.device = self.local_rank % n_devs
        initialize_rpc_workers(self.local_rank, self.world_size, ctx)
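
Subclasses override initialize() but should call super().initialize(ctx) first: the RPC workers have to be up before get_pipeline_driver() can distribute pipeline stages across them.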

Signature -- initialize_rpc_workers:

import torch
import torch.distributed.rpc as rpc

def initialize_rpc_workers(local_rank, world_size, ctx):
    # Timeout and thread count come from the "pippy" section of the model YAML
    rpc_timeout = ctx.model_yaml_config["pippy"]["rpc_timeout"]
    num_worker_threads = ctx.model_yaml_config["pippy"]["num_worker_threads"]
    options = rpc.TensorPipeRpcBackendOptions(
        num_worker_threads=num_worker_threads, rpc_timeout=rpc_timeout
    )
    # Map this rank's GPU onto every peer's GPU so tensors move device-to-device
    # (e.g. with 4 GPUs and world_size=4, rank 0 maps worker3 as {0: 3})
    n_devs = torch.cuda.device_count()
    dev_id = local_rank % n_devs
    for i in range(world_size):
        options.set_device_map(f"worker{i}", {dev_id: i % n_devs})
    # Join the RPC group under the name "worker{local_rank}"
    rpc.init_rpc(
        f"worker{local_rank}",
        rank=local_rank,
        world_size=world_size,
        rpc_backend_options=options,
    )

Signature -- get_pipeline_driver:

def get_pipeline_driver(model, world_size, ctx):
    """Returns a pipeline driver for the given model.
    Args:
        model (torch.nn.Module): The model to pipeline.
        world_size (int): The number of pipeline stages.
        ctx (Context): The context containing configuration information.
    Returns:
        torch.nn.Sequential: The pipeline driver for the model.
    """

Import:

from ts.torch_handler.distributed.base_pippy_handler import BasePippyHandler
from ts.handler_utils.distributed.pt_pippy import initialize_rpc_workers, get_pipeline_driver

External Dependencies:

  • pippy (PiPPy, Pipeline Parallelism for PyTorch: split_into_equal_size, all_compile)
  • pippy.hf (PiPPyHFTracer, inject_pipeline_forward)
  • torch.distributed.rpc (TensorPipe RPC backend)
  • transformers (HuggingFace model loading)

I/O Contract

Inputs to BasePippyHandler.initialize():

  • ctx (Context): TorchServe context object containing:
    • ctx.model_yaml_config["pippy"]["rpc_timeout"] (int): RPC timeout in seconds
    • ctx.model_yaml_config["pippy"]["num_worker_threads"] (int): Number of RPC worker threads
    • ctx.model_yaml_config["pippy"]["model_type"] (str): "HF" for HuggingFace models
    • ctx.model_yaml_config["pippy"]["input_names"] (list[str]): Input argument names for FX tracing
    • ctx.model_yaml_config["pippy"]["chunks"] (int, optional): Number of microbatches, defaults to 1
    • ctx.model_yaml_config["handler"]["model_path"] (str): Path to model checkpoints

Environment Variables (set by torchrun):

  • LOCAL_RANK: Local rank of this process on the node
  • WORLD_SIZE: Total number of processes across all nodes
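
TorchServe launches the model workers through torchrun (configured via the torchrun: section of model-config.yaml shown below), so both variables are already set by the time the handler's initialize() runs.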

Inputs to get_pipeline_driver():

  • model (torch.nn.Module): The loaded model (can be on meta device for memory efficiency)
  • world_size (int): Number of pipeline stages
  • ctx (Context): TorchServe context with YAML configuration

Output of get_pipeline_driver():

  • For HuggingFace models (model_type="HF"): The original model with pipeline forward injected
  • For other models: The PiPPy pipeline driver object
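
The practical consequence for calling code (a hedged illustration; the non-HF call convention follows PiPPy's driver object rather than any TorchServe API):

from ts.handler_utils.distributed.pt_pippy import get_pipeline_driver

def run_pipeline(model, input_ids, world_size, ctx):
    # input_ids: an already-tokenized batch on this rank's device
    driver = get_pipeline_driver(model, world_size, ctx)
    if ctx.model_yaml_config["pippy"]["model_type"] == "HF":
        # driver is still the original transformers model, so generation
        # APIs work as-is, with each forward pass spanning the stages
        return driver.generate(input_ids, max_new_tokens=60)
    # driver is the raw PiPPy pipeline driver; call it like a module
    return driver(input_ids)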

Usage Examples

Custom handler for HuggingFace causal LM with PiPPy:

import torch
from abc import ABC
from transformers import AutoModelForCausalLM, AutoTokenizer
from ts.torch_handler.distributed.base_pippy_handler import BasePippyHandler
from ts.handler_utils.distributed.pt_pippy import get_pipeline_driver

class TransformersSeqClassifierHandler(BasePippyHandler, ABC):
    def __init__(self):
        super().__init__()
        self.initialized = False

    def initialize(self, ctx):
        super().initialize(ctx)  # reads LOCAL_RANK/WORLD_SIZE, starts RPC workers
        model_path = ctx.model_yaml_config["handler"]["model_path"]
        self.device = self.local_rank  # assumes one visible GPU per local rank

        # Load on the meta device: no real weights are materialized yet
        with torch.device("meta"):
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path, use_cache=False, torch_dtype=torch.float16
            )
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, return_tensors="pt")
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.max_length = ctx.model_yaml_config["handler"]["max_length"]
        # Compile the model into a world_size-stage pipeline
        self.model = get_pipeline_driver(self.model, self.world_size, ctx)
        self.initialized = True

    def preprocess(self, requests):
        input_texts = [data.get("data") or data.get("body") for data in requests]
        input_ids_batch = []
        for input_text in input_texts:
            if isinstance(input_text, (bytes, bytearray)):
                input_text = input_text.decode("utf-8")
            inputs = self.tokenizer.encode_plus(
                input_text, max_length=self.max_length,
                padding="max_length", add_special_tokens=True,
                return_tensors="pt",
            )
            input_ids_batch.append(inputs["input_ids"])
        return torch.cat(input_ids_batch, dim=0).to(self.device)

    def inference(self, input_batch):
        outputs = self.model.generate(input_batch, max_length=60)
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

    def postprocess(self, inference_output):
        return inference_output
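
Two details worth noting: every request is padded to max_length in preprocess() so the per-request tensors can be concatenated into one batch, and the self.device = self.local_rank override matches BasePippyHandler's own local_rank % device_count whenever each rank sees its own GPU on a single node.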

Corresponding model-config.yaml:

minWorkers: 1
maxWorkers: 1
maxBatchDelay: 200
responseTimeout: 300
parallelType: "pp"
deviceType: "gpu"
torchrun:
    nproc-per-node: 4
pippy:
    rpc_timeout: 1800
    model_type: "HF"
    chunks: 1
    input_names: ["input_ids"]
    num_worker_threads: 128
handler:
    model_path: "/path/to/model/checkpoints"
    max_length: 50
    max_new_tokens: 60
    manual_seed: 40
    dtype: fp16
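
The keys line up as follows: torchrun.nproc-per-node determines WORLD_SIZE and hence the number of pipeline stages, parallelType: "pp" tells TorchServe to launch the workers via torchrun for pipeline parallelism, and everything under pippy: reaches the handler through ctx.model_yaml_config["pippy"].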
