Implementation:Pytorch Serve BaseDeepSpeedHandler

From Leeroopedia
Field | Value
Page Type | Implementation (API Doc)
Title | BaseDeepSpeedHandler
Implements | Principle:Pytorch_Serve_DeepSpeed_Inference
Source | ts/torch_handler/distributed/base_deepspeed_handler.py, ts/handler_utils/distributed/deepspeed.py, examples/large_models/deepspeed/custom_handler.py
Repository | TorchServe
Last Updated | 2026-02-13 00:00 GMT

Overview

BaseDeepSpeedHandler is the base handler class for DeepSpeed tensor parallelism in TorchServe. It extends BaseHandler and overrides initialize() to set self.device from the LOCAL_RANK environment variable (default 0). The companion function get_ds_engine() reads the DeepSpeed settings from the model YAML config, resolves the optional checkpoint index, and calls deepspeed.init_inference() to create an optimized inference engine that shards the model across GPUs.

Description

The DeepSpeed integration consists of three components:

1. BaseDeepSpeedHandler (ts/torch_handler/distributed/base_deepspeed_handler.py): A minimal base handler that reads the LOCAL_RANK environment variable and sets self.device to the corresponding GPU index. Custom handlers inherit from this class and call get_ds_engine() in their initialize() method.

2. DeepSpeed Engine Initialization (get_ds_engine() in ts/handler_utils/distributed/deepspeed.py): This function:

  • Reads the model directory from ctx.system_properties
  • Reads the model path from ctx.model_yaml_config["handler"]["model_path"]
  • Resolves the DeepSpeed config JSON file path from ctx.model_yaml_config["deepspeed"]["config"]
  • Optionally reads a checkpoint specification from ctx.model_yaml_config["deepspeed"]["checkpoint"]
  • If a checkpoint is specified, calls create_checkpoints_json() to generate a DeepSpeed-compatible checkpoint index
  • Calls deepspeed.init_inference(model, config=ds_config, base_dir=model_path, checkpoint=checkpoint)
  • Returns the DeepSpeed inference engine

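3. Checkpoint Index Generation (create_checkpoints_json()): Scans the model directory for checkpoint files matching the glob pattern *.[bp][it][n] and writes a JSON index in DeepSpeed format. Each bracket in the pattern matches exactly one character, so it selects three-character extensions such as .bin; a two-character extension such as .pt does not match it.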
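
The scan-and-write step can be sketched with the standard library alone. This is an illustration, not the TorchServe source: build_checkpoint_index is a hypothetical stand-in for create_checkpoints_json(), and the index layout (a type tag, the file list, and a version number) is an assumption about what a DeepSpeed-compatible index looks like. The file names are made up to show which extensions the glob pattern quoted above actually selects.

```python
import json
import tempfile
from pathlib import Path

# The glob pattern quoted above; each bracket matches exactly one character.
CHECKPOINT_GLOB = "*.[bp][it][n]"


def build_checkpoint_index(model_path, checkpoints_json):
    """Hypothetical stand-in for create_checkpoints_json()."""
    files = sorted(
        str(p) for p in Path(model_path).rglob(CHECKPOINT_GLOB) if p.is_file()
    )
    # Assumed index layout: a type tag, the file list, and a version number.
    index = {"type": "ds_model", "checkpoints": files, "version": 1.0}
    with open(checkpoints_json, "w") as f:
        json.dump(index, f)
    return index


with tempfile.TemporaryDirectory() as tmp:
    # Made-up file names covering a three-character and a two-character extension.
    for name in ("pytorch_model-00001.bin", "weights.pt", "config.json"):
        (Path(tmp) / name).touch()
    index = build_checkpoint_index(tmp, Path(tmp) / "checkpoints.json")
    matched = [Path(p).name for p in index["checkpoints"]]
    print(matched)
```

Only the .bin file ends up in the index here, because the pattern requires exactly three characters after the dot.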

Usage

Code Reference

Source Location:

  • ts/torch_handler/distributed/base_deepspeed_handler.py (lines 8-14)
  • ts/handler_utils/distributed/deepspeed.py (lines 22-57)
  • examples/large_models/deepspeed/custom_handler.py (lines 16-135)

Signature -- BaseDeepSpeedHandler:

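import os
from abc import ABC

from ts.context import Context
from ts.torch_handler.base_handler import BaseHandler

class BaseDeepSpeedHandler(BaseHandler, ABC):
    """
    Base default DeepSpeed handler.
    """

    def initialize(self, ctx: Context):
        # torchrun sets LOCAL_RANK per process; fall back to 0 when it is unset.
        self.device = int(os.getenv("LOCAL_RANK", 0))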

Signature -- get_ds_engine:

def get_ds_engine(model, ctx: Context):
    """
    Create and return a DeepSpeed inference engine for the given model.

    Args:
        model: The PyTorch model to wrap with DeepSpeed inference.
        ctx (Context): TorchServe context containing model_yaml_config
                       with "deepspeed" and "handler" sections.

    Returns:
        deepspeed.InferenceEngine: The DeepSpeed inference engine wrapping the model.

    Raises:
        ValueError: If deepspeed config is missing from model_yaml_config,
                    or if the specified config file does not exist.
    """

Signature -- create_checkpoints_json:

def create_checkpoints_json(model_path, checkpoints_json):
    """
    Scan model_path for checkpoint files and write a DeepSpeed checkpoint index.

    Args:
        model_path (str): Path to the directory containing model checkpoint files.
        checkpoints_json (str): Path where the checkpoint index JSON will be written.
    """

Import:

from ts.torch_handler.distributed.base_deepspeed_handler import BaseDeepSpeedHandler
from ts.handler_utils.distributed.deepspeed import get_ds_engine

External Dependencies:

  • deepspeed (deepspeed.init_inference())
  • transformers (HuggingFace model loading)

I/O Contract

Inputs to BaseDeepSpeedHandler.initialize():

  • ctx (Context): TorchServe context object.

Environment Variables (set by torchrun):

  • LOCAL_RANK: Local rank of this process, defaults to 0. Used to set self.device.
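
A minimal sketch of the rank lookup, assuming only the int(os.getenv("LOCAL_RANK", 0)) expression shown in the signature. The helper name resolve_local_rank is hypothetical; it accepts the environment as a parameter so the behaviour can be checked without launching torchrun.

```python
import os


def resolve_local_rank(environ=None):
    """Return the local rank as an int, falling back to 0 when unset."""
    env = os.environ if environ is None else environ
    # Environment values are strings, so the int() conversion is required.
    return int(env.get("LOCAL_RANK", 0))


print(resolve_local_rank({}))                   # variable unset: rank 0
print(resolve_local_rank({"LOCAL_RANK": "3"}))  # value supplied as a string
```

Because self.device ends up as a plain integer, later calls such as tensor.to(self.device) pass a device index rather than a device string.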

Inputs to get_ds_engine():

  • model (torch.nn.Module): The loaded PyTorch model.
  • ctx (Context): TorchServe context containing:
    • ctx.system_properties["model_dir"] (str): Model archive extraction directory
    • ctx.model_yaml_config["handler"]["model_path"] (str): Path to model checkpoints
    • ctx.model_yaml_config["deepspeed"]["config"] (str): DeepSpeed config JSON filename
    • ctx.model_yaml_config["deepspeed"]["checkpoint"] (str, optional): Checkpoint index filename
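
The lookups listed above can be sketched without TorchServe or DeepSpeed installed. This is a hedged illustration, not the library code: resolve_ds_paths is a hypothetical helper, the context is faked with SimpleNamespace, and joining the config and checkpoint filenames onto model_dir is an assumption based on the field descriptions.

```python
import os
import tempfile
from types import SimpleNamespace


def resolve_ds_paths(ctx):
    """Hypothetical helper mirroring the context lookups listed above."""
    model_dir = ctx.system_properties.get("model_dir")
    yaml_cfg = ctx.model_yaml_config
    if "deepspeed" not in yaml_cfg or "config" not in yaml_cfg["deepspeed"]:
        raise ValueError("no deepspeed config in model_yaml_config")
    # Assumption: filenames are resolved relative to model_dir.
    ds_config = os.path.join(model_dir, yaml_cfg["deepspeed"]["config"])
    if not os.path.exists(ds_config):
        raise ValueError(f"deepspeed config file not found: {ds_config}")
    checkpoint = yaml_cfg["deepspeed"].get("checkpoint")  # optional key
    if checkpoint is not None:
        checkpoint = os.path.join(model_dir, checkpoint)
    return {
        "model_path": yaml_cfg["handler"]["model_path"],
        "ds_config": ds_config,
        "checkpoint": checkpoint,
    }


with tempfile.TemporaryDirectory() as model_dir:
    # Create an empty config file so the existence check passes.
    open(os.path.join(model_dir, "ds-config.json"), "w").close()
    ctx = SimpleNamespace(
        system_properties={"model_dir": model_dir},
        model_yaml_config={
            "deepspeed": {"config": "ds-config.json"},
            "handler": {"model_path": "/path/to/model/checkpoints"},
        },
    )
    paths = resolve_ds_paths(ctx)
    print(paths["checkpoint"])  # None, because no checkpoint key was given
```

Both failure modes named in the docstring (missing deepspeed section, missing config file) surface as ValueError in this sketch.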

DeepSpeed Config JSON (ds-config.json):

{
    "dtype": "torch.float16",
    "replace_with_kernel_inject": true,
    "tensor_parallel": {
        "tp_size": 2
    }
}
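
Parsing the file with the standard library before packaging catches JSON syntax slips (a trailing comma, or Python-style True instead of true) early. The sketch below is only a sanity check under stated assumptions: read_tp_size is a hypothetical helper, and treating a missing tp_size as 1 is an assumption about the default rather than something this page states.

```python
import json

# The example config from above, inlined so the snippet is self-contained.
DS_CONFIG_TEXT = """
{
    "dtype": "torch.float16",
    "replace_with_kernel_inject": true,
    "tensor_parallel": {"tp_size": 2}
}
"""


def read_tp_size(config_text):
    """Return tensor_parallel.tp_size from a DeepSpeed config string."""
    config = json.loads(config_text)  # raises json.JSONDecodeError on bad syntax
    # Assumption: a missing value means no tensor parallelism (size 1).
    tp_size = config.get("tensor_parallel", {}).get("tp_size", 1)
    if not isinstance(tp_size, int) or tp_size < 1:
        raise ValueError(f"tp_size must be a positive integer, got {tp_size!r}")
    return tp_size


print(read_tp_size(DS_CONFIG_TEXT))
```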

Output of get_ds_engine():

  • deepspeed.InferenceEngine: The inference engine. Access the sharded model via ds_engine.module.

Usage Examples

Custom handler for HuggingFace model with DeepSpeed:

import torch
from abc import ABC
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from ts.context import Context
from ts.handler_utils.distributed.deepspeed import get_ds_engine
from ts.torch_handler.distributed.base_deepspeed_handler import BaseDeepSpeedHandler

class TransformersSeqClassifierHandler(BaseDeepSpeedHandler, ABC):
    def __init__(self):
        super(TransformersSeqClassifierHandler, self).__init__()
        self.max_length = None
        self.max_new_tokens = None
        self.tokenizer = None
        self.initialized = False

    def initialize(self, ctx: Context):
        super().initialize(ctx)
        model_dir = ctx.system_properties.get("model_dir")
        self.max_length = int(ctx.model_yaml_config["handler"]["max_length"])
        self.max_new_tokens = int(ctx.model_yaml_config["handler"]["max_new_tokens"])
        model_name = ctx.model_yaml_config["handler"]["model_name"]
        model_path = ctx.model_yaml_config["handler"]["model_path"]
        seed = int(ctx.model_yaml_config["handler"]["manual_seed"])
        torch.manual_seed(seed)

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        config = AutoConfig.from_pretrained(model_name)
        with torch.device("meta"):
            self.model = AutoModelForCausalLM.from_config(
                config, torch_dtype=torch.float16
            )
        self.model = self.model.eval()

        ds_engine = get_ds_engine(self.model, ctx)
        self.model = ds_engine.module
        self.initialized = True

    def preprocess(self, requests):
        input_texts = [data.get("data") or data.get("body") for data in requests]
        input_ids_batch, attention_mask_batch = [], []
        for input_text in input_texts:
            if isinstance(input_text, (bytes, bytearray)):
                input_text = input_text.decode("utf-8")
            # Pad every sequence to max_length so torch.cat below sees equal shapes.
            inputs = self.tokenizer.encode_plus(
                input_text, max_length=self.max_length,
                padding="max_length", add_special_tokens=True,
                return_tensors="pt", truncation=True,
            )
            input_ids_batch.append(inputs["input_ids"])
            attention_mask_batch.append(inputs["attention_mask"])
        input_ids_batch = torch.cat(input_ids_batch, dim=0).to(self.device)
        attention_mask_batch = torch.cat(attention_mask_batch, dim=0).to(self.device)
        return input_ids_batch, attention_mask_batch

    def inference(self, input_batch):
        input_ids_batch, attention_mask_batch = input_batch
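        # max_new_tokens bounds only the generated tokens; max_length would
        # also count the prompt tokens.
        outputs = self.model.generate(
            input_ids_batch, attention_mask=attention_mask_batch,
            max_new_tokens=self.max_new_tokens,
        )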
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

    def postprocess(self, inference_output):
        return inference_output
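
The request-decoding step at the start of preprocess() can be exercised on its own, without a GPU or tokenizer. The sketch below isolates that step; extract_texts is a hypothetical helper name and the sample payloads are made up for illustration.

```python
def extract_texts(requests):
    """Pull the text out of each request, decoding raw bytes as UTF-8."""
    texts = []
    for data in requests:
        # Each request carries its payload under "data" or "body".
        text = data.get("data") or data.get("body")
        if isinstance(text, (bytes, bytearray)):
            text = text.decode("utf-8")
        texts.append(text)
    return texts


batch = [{"data": b"Today the weather is"}, {"body": "My dog is cute"}]
print(extract_texts(batch))
```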

Corresponding model-config.yaml:

minWorkers: 1
maxWorkers: 1
maxBatchDelay: 100
responseTimeout: 120
parallelType: "tp"
deviceType: "gpu"
torchrun:
    nproc-per-node: 4
deepspeed:
    config: ds-config.json
handler:
    model_name: "facebook/opt-30b"
    model_path: "/path/to/model/checkpoints"
    max_length: 80
    max_new_tokens: 50
    manual_seed: 40
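
The example handler indexes ctx.model_yaml_config["handler"] directly, so a missing key only shows up as a KeyError during worker start-up. A small pre-flight check can list the gaps first. This is a sketch: missing_handler_keys is a hypothetical helper, and the dict below is the YAML above written out by hand, not the output of a real parser.

```python
# Keys the example handler's initialize() reads from the "handler" section.
REQUIRED_HANDLER_KEYS = (
    "model_name", "model_path", "max_length", "max_new_tokens", "manual_seed",
)


def missing_handler_keys(model_yaml_config):
    """Return the required handler keys that the config does not define."""
    handler = model_yaml_config.get("handler", {})
    return [key for key in REQUIRED_HANDLER_KEYS if key not in handler]


# The YAML above, written out as the nested dict the handler indexes into.
config = {
    "torchrun": {"nproc-per-node": 4},
    "deepspeed": {"config": "ds-config.json"},
    "handler": {
        "model_name": "facebook/opt-30b",
        "model_path": "/path/to/model/checkpoints",
        "max_length": 80,
        "max_new_tokens": 50,
        "manual_seed": 40,
    },
}
print(missing_handler_keys(config))
```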

Packaging command:

torch-model-archiver --model-name opt-30b \
    --version 1.0 \
    --handler deepspeed_handler.py \
    --extra-files /path/to/checkpoints,ds-config.json \
    -r requirements.txt \
    --config-file model-config.yaml \
    --archive-format tgz
