Implementation:Pytorch Serve BaseDeepSpeedHandler
| Field | Value |
|---|---|
| Page Type | Implementation (API Doc) |
| Title | BaseDeepSpeedHandler |
| Implements | Principle:Pytorch_Serve_DeepSpeed_Inference |
| Source | ts/torch_handler/distributed/base_deepspeed_handler.py, ts/handler_utils/distributed/deepspeed.py, examples/large_models/deepspeed/custom_handler.py |
| Repository | TorchServe |
| Last Updated | 2026-02-13 00:00 GMT |
Overview
BaseDeepSpeedHandler is the base handler class for DeepSpeed tensor parallelism in TorchServe. It extends BaseHandler to set the local device based on the LOCAL_RANK environment variable. The companion function get_ds_engine() reads the DeepSpeed configuration from the model YAML config, loads checkpoint metadata, and calls deepspeed.init_inference() to create an optimized inference engine that automatically shards the model across GPUs.
Description
The DeepSpeed integration consists of three components:
1. BaseDeepSpeedHandler (ts/torch_handler/distributed/base_deepspeed_handler.py): A minimal base handler that reads the LOCAL_RANK environment variable and sets self.device to the corresponding GPU index. Custom handlers inherit from this class and call get_ds_engine() in their initialize() method.
2. DeepSpeed Engine Initialization (get_ds_engine() in ts/handler_utils/distributed/deepspeed.py): This function:
- Reads the model directory from ctx.system_properties
- Reads the model path from ctx.model_yaml_config["handler"]["model_path"]
- Loads the DeepSpeed config JSON file path from ctx.model_yaml_config["deepspeed"]["config"]
- Optionally reads a checkpoint specification from ctx.model_yaml_config["deepspeed"]["checkpoint"]
- If a checkpoint is specified, calls create_checkpoints_json() to generate a DeepSpeed-compatible checkpoint index
- Calls deepspeed.init_inference(model, config=ds_config, base_dir=model_path, checkpoint=checkpoint)
- Returns the DeepSpeed inference engine
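The configuration lookups above can be sketched as a plain-Python helper. The function name `resolve_ds_paths` is hypothetical, and the sketch deliberately stops before the final `deepspeed.init_inference()` call so it stays runnable without DeepSpeed installed:

```python
import os


def resolve_ds_paths(ctx):
    """Hypothetical sketch of get_ds_engine()'s config handling.

    Mirrors the lookups described above; the real function finishes by calling
    deepspeed.init_inference(model, config=ds_config, base_dir=model_path,
    checkpoint=checkpoint), which is omitted here.
    """
    model_dir = ctx.system_properties["model_dir"]
    if "deepspeed" not in ctx.model_yaml_config:
        raise ValueError("deepspeed section missing from model_yaml_config")
    # Path to the model checkpoints, as given in the handler section
    model_path = ctx.model_yaml_config["handler"]["model_path"]
    # The DeepSpeed config JSON lives inside the extracted model archive
    ds_config = os.path.join(model_dir, ctx.model_yaml_config["deepspeed"]["config"])
    checkpoint = None
    if "checkpoint" in ctx.model_yaml_config["deepspeed"]:
        checkpoint = os.path.join(
            model_dir, ctx.model_yaml_config["deepspeed"]["checkpoint"]
        )
        # The real code would now call create_checkpoints_json(model_path, checkpoint)
    return model_path, ds_config, checkpoint
```

The ValueError branch corresponds to the missing-config case documented in the function's docstring below.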
3. Checkpoint Index Generation (create_checkpoints_json()): Scans the model directory for checkpoint files matching the *.[bp][it][n] glob pattern (three-character extensions such as .bin; note that a bare .pt extension has only two characters and is not matched by this glob) and writes a JSON index in DeepSpeed format.
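A minimal sketch of that scan, assuming a flat index layout with "type", "checkpoints", and "version" keys; the function name and exact JSON schema here are illustrative, not the shipped implementation:

```python
import glob
import json
import os


def sketch_checkpoints_json(model_path, checkpoints_json):
    # The *.[bp][it][n] glob matches three-character extensions such as .bin
    files = sorted(
        os.path.basename(p)
        for p in glob.glob(os.path.join(model_path, "*.[bp][it][n]"))
    )
    # Assumed index layout; DeepSpeed's expected schema may carry extra fields
    data = {"type": "ds_model", "checkpoints": files, "version": 1.0}
    with open(checkpoints_json, "w") as fp:
        json.dump(data, fp)
```

Running this against a directory holding `pytorch_model.bin` and `vocab.txt` lists only the `.bin` file in the index, which illustrates the glob's selectivity.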
Usage
Code Reference
Source Location:
- ts/torch_handler/distributed/base_deepspeed_handler.py (lines 8-14)
- ts/handler_utils/distributed/deepspeed.py (lines 22-57)
- examples/large_models/deepspeed/custom_handler.py (lines 16-135)
Signature -- BaseDeepSpeedHandler:
class BaseDeepSpeedHandler(BaseHandler, ABC):
    """
    Base default DeepSpeed handler.
    """

    def initialize(self, ctx: Context):
        self.device = int(os.getenv("LOCAL_RANK", 0))
Signature -- get_ds_engine:
def get_ds_engine(model, ctx: Context):
    """
    Create and return a DeepSpeed inference engine for the given model.

    Args:
        model: The PyTorch model to wrap with DeepSpeed inference.
        ctx (Context): TorchServe context containing model_yaml_config
            with "deepspeed" and "handler" sections.

    Returns:
        deepspeed.InferenceEngine: The DeepSpeed inference engine wrapping the model.

    Raises:
        ValueError: If deepspeed config is missing from model_yaml_config,
            or if the specified config file does not exist.
    """
Signature -- create_checkpoints_json:
def create_checkpoints_json(model_path, checkpoints_json):
    """
    Scan model_path for checkpoint files and write a DeepSpeed checkpoint index.

    Args:
        model_path (str): Path to the directory containing model checkpoint files.
        checkpoints_json (str): Path where the checkpoint index JSON will be written.
    """
Import:
from ts.torch_handler.distributed.base_deepspeed_handler import BaseDeepSpeedHandler
from ts.handler_utils.distributed.deepspeed import get_ds_engine
External Dependencies:
- deepspeed (deepspeed.init_inference())
- transformers (HuggingFace model loading)
I/O Contract
Inputs to BaseDeepSpeedHandler.initialize():
- ctx (Context): TorchServe context object.
Environment Variables (set by torchrun):
- LOCAL_RANK: Local rank of this process; defaults to 0. Used to set self.device.
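The device-index derivation can be checked in isolation. The helper name `local_device_index` is illustrative; the body mirrors the one-line assignment in initialize():

```python
import os


def local_device_index():
    # torchrun sets LOCAL_RANK per worker process; default to GPU 0 when absent
    return int(os.getenv("LOCAL_RANK", 0))
```

With torchrun nproc-per-node: 4, the four worker processes receive LOCAL_RANK values 0 through 3, so each handler binds to a distinct GPU.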
Inputs to get_ds_engine():
- model (torch.nn.Module): The loaded PyTorch model.
- ctx (Context): TorchServe context containing:
  - ctx.system_properties["model_dir"] (str): Model archive extraction directory
  - ctx.model_yaml_config["handler"]["model_path"] (str): Path to model checkpoints
  - ctx.model_yaml_config["deepspeed"]["config"] (str): DeepSpeed config JSON filename
  - ctx.model_yaml_config["deepspeed"]["checkpoint"] (str, optional): Checkpoint index filename
DeepSpeed Config JSON (ds-config.json). The tp_size value sets the tensor-parallel degree and should match the torchrun nproc-per-node setting in model-config.yaml:
{
  "dtype": "torch.float16",
  "replace_with_kernel_inject": true,
  "tensor_parallel": {
    "tp_size": 4
  }
}
Output of get_ds_engine():
deepspeed.InferenceEngine: The inference engine. Access the sharded model via ds_engine.module.
Usage Examples
Custom handler for HuggingFace model with DeepSpeed:
import torch
from abc import ABC
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from ts.context import Context
from ts.handler_utils.distributed.deepspeed import get_ds_engine
from ts.torch_handler.distributed.base_deepspeed_handler import BaseDeepSpeedHandler
class TransformersSeqClassifierHandler(BaseDeepSpeedHandler, ABC):
    def __init__(self):
        super().__init__()
        self.max_length = None
        self.max_new_tokens = None
        self.tokenizer = None
        self.initialized = False

    def initialize(self, ctx: Context):
        super().initialize(ctx)
        model_dir = ctx.system_properties.get("model_dir")
        self.max_length = int(ctx.model_yaml_config["handler"]["max_length"])
        self.max_new_tokens = int(ctx.model_yaml_config["handler"]["max_new_tokens"])
        model_name = ctx.model_yaml_config["handler"]["model_name"]
        model_path = ctx.model_yaml_config["handler"]["model_path"]
        seed = int(ctx.model_yaml_config["handler"]["manual_seed"])
        torch.manual_seed(seed)

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        config = AutoConfig.from_pretrained(model_name)
        # Instantiate on the meta device: no weights are allocated until
        # DeepSpeed loads the sharded checkpoint
        with torch.device("meta"):
            self.model = AutoModelForCausalLM.from_config(
                config, torch_dtype=torch.float16
            )
        self.model = self.model.eval()

        ds_engine = get_ds_engine(self.model, ctx)
        self.model = ds_engine.module
        self.initialized = True

    def preprocess(self, requests):
        input_texts = [data.get("data") or data.get("body") for data in requests]
        input_ids_batch, attention_mask_batch = [], []
        for input_text in input_texts:
            if isinstance(input_text, (bytes, bytearray)):
                input_text = input_text.decode("utf-8")
            inputs = self.tokenizer.encode_plus(
                input_text, max_length=self.max_length,
                padding=True, add_special_tokens=True,
                return_tensors="pt", truncation=True,
            )
            input_ids_batch.append(inputs["input_ids"])
            attention_mask_batch.append(inputs["attention_mask"])
        input_ids_batch = torch.cat(input_ids_batch, dim=0).to(self.device)
        attention_mask_batch = torch.cat(attention_mask_batch, dim=0).to(self.device)
        return input_ids_batch, attention_mask_batch

    def inference(self, input_batch):
        input_ids_batch, attention_mask_batch = input_batch
        outputs = self.model.generate(
            input_ids_batch, attention_mask=attention_mask_batch,
            max_new_tokens=self.max_new_tokens,
        )
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

    def postprocess(self, inference_output):
        return inference_output
Corresponding model-config.yaml:
minWorkers: 1
maxWorkers: 1
maxBatchDelay: 100
responseTimeout: 120
parallelType: "tp"
deviceType: "gpu"
torchrun:
  nproc-per-node: 4
deepspeed:
  config: ds-config.json
handler:
  model_name: "facebook/opt-30b"
  model_path: "/path/to/model/checkpoints"
  max_length: 80
  max_new_tokens: 50
  manual_seed: 40
Packaging command:
torch-model-archiver --model-name opt-30b \
--version 1.0 \
--handler deepspeed_handler.py \
--extra-files /path/to/checkpoints,ds-config.json \
-r requirements.txt \
--config-file model-config.yaml \
--archive-format tgz
Related Pages
- Principle:Pytorch_Serve_DeepSpeed_Inference - Theory of tensor parallelism with DeepSpeed
- Pytorch_Serve_ParallelType_Config - Configuring parallelType for tensor parallelism
- Pytorch_Serve_Parallelism_Model_Config - Full model-config.yaml examples
- Pytorch_Serve_TorchModelServiceWorker - Worker process management with torchrun
- Environment:Pytorch_Serve_DeepSpeed_Environment - DeepSpeed library and distributed env vars
- Environment:Pytorch_Serve_CUDA_GPU_Environment - Multi-GPU CUDA environment
- Environment:Pytorch_Serve_Distributed_Training_Environment - Distributed process group and env vars
- Distributed_Computing
- Model_Parallelism