Overview
GptHandler is a TorchServe handler for serving GPT Fast models. It extends BaseHandler and implements the full inference pipeline with support for torch.compile optimization, SentencePiece tokenization, tensor parallelism, streaming token generation, and speculative decoding. The handler is designed for high-performance autoregressive text generation.
Description
The GptHandler class (lines 28-363) provides a complete handler implementation for GPT-style large language models built on the GPT Fast implementation. It overrides the BaseHandler lifecycle methods (initialize, preprocess, inference, postprocess) and adds its own generation logic, with optional speculative decoding.
Key Responsibilities
- Model Loading: Loads the GPT Fast model and applies torch.compile() for optimized inference during initialize()
- Tokenization: Uses SentencePiece tokenizer for text-to-token and token-to-text conversion
- Text Generation: Implements autoregressive token generation with configurable temperature and top-k sampling
- Speculative Decoding: Supports speculative decoding with a draft model for faster generation through speculative_decode()
- Tensor Parallelism: Supports distributed inference across multiple GPUs
- Streaming Output: Supports streaming token-by-token responses back to the client
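Temperature divides the logits before the softmax (values below 1 sharpen the distribution, values above 1 flatten it), and top-k discards every candidate outside the k highest-scoring tokens. A minimal pure-Python sketch of that recipe follows, assuming logits arrive as a plain list of floats; the handler itself works on torch tensors, and `logits_to_probs` and `sample` are hypothetical names.

```python
import math
import random
from typing import List, Optional, Sequence


def logits_to_probs(logits: Sequence[float], temperature: float = 1.0,
                    top_k: Optional[int] = None) -> List[float]:
    """Convert raw logits into a sampling distribution."""
    # Dividing by the temperature sharpens (<1) or flattens (>1) the distribution;
    # the floor avoids division by zero when temperature is 0.
    scaled = [x / max(temperature, 1e-5) for x in logits]
    if top_k is not None and top_k > 0:
        # Mask everything below the k-th largest logit (ties may keep extra tokens).
        kth = sorted(scaled, reverse=True)[min(top_k, len(scaled)) - 1]
        scaled = [x if x >= kth else -math.inf for x in scaled]
    # Numerically stable softmax: subtract the max before exponentiating.
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample(logits: Sequence[float], temperature: float = 1.0,
           top_k: Optional[int] = None,
           rng: Optional[random.Random] = None) -> int:
    """Draw one token ID from the filtered, temperature-scaled distribution."""
    if rng is None:
        rng = random.Random()
    probs = logits_to_probs(logits, temperature, top_k)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]
```

Masked entries end up with probability exactly 0.0, so they can never be drawn.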
Core Methods
initialize(context): Loads model weights, initializes SentencePiece tokenizer, applies torch.compile(), and configures tensor parallelism
preprocess(data): Tokenizes input text using SentencePiece
inference(data): Calls generate(), or speculative_decode() when speculative decoding is enabled, to produce output tokens
postprocess(data): Decodes generated tokens back to text
generate(prompt_tokens, max_new_tokens, ...): Core autoregressive generation loop
speculative_decode(prompt_tokens, draft_model, ...): Speculative decoding with draft model verification
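The generate() loop described above runs the model on the sequence so far, samples one token, appends it, and repeats up to max_new_tokens times, optionally handing each new token to a streaming callback. The sketch below shows only that control flow. `next_token_fn` is a hypothetical stand-in for one compiled forward pass plus sampling, and the early stop on `eos_id` is an assumption, not documented handler behavior.

```python
from typing import Callable, List, Optional


def generate(next_token_fn: Callable[[List[int]], int],
             prompt_tokens: List[int],
             max_new_tokens: int,
             eos_id: Optional[int] = None,
             callback: Optional[Callable[[int], None]] = None) -> List[int]:
    """Autoregressive loop: append one token per step until a stop condition."""
    seq = list(prompt_tokens)
    for _ in range(max_new_tokens):
        # One decoding step (hypothetical stand-in for model forward + sampling).
        tok = next_token_fn(seq)
        seq.append(tok)
        if callback is not None:
            callback(tok)  # streaming hook: emit each token as soon as it exists
        if eos_id is not None and tok == eos_id:
            break  # assumed stop condition
    # Prompt plus generated tokens, matching the documented return value.
    return seq
```

A production loop would also keep a KV cache so that each step processes only the newest token instead of re-reading the whole sequence.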
Usage
# The handler is typically configured in model-config.yaml:
# handler:
#   model_path: "path/to/gpt_fast_model"
#   tokenizer_path: "path/to/tokenizer.model"
#   compile: true
#   speculative: false
# Creating a model archive with the GPT handler
torch-model-archiver --model-name gpt_fast --version 1.0 \
  --handler examples/large_models/gpt_fast/handler.py \
  --config-file model-config.yaml \
  --archive-format no-archive
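Inside the handler, the `handler:` section of model-config.yaml reaches initialize() through the context object. TorchServe exposes the parsed YAML as a dict on `context.model_yaml_config`. The following is a minimal sketch of reading it, assuming the key names shown in the comment above. `read_handler_config` is hypothetical, and its default values are illustrative rather than taken from the handler source.

```python
from types import SimpleNamespace
from typing import Any, Dict


def read_handler_config(context: Any) -> Dict[str, Any]:
    """Collect generation settings from the `handler:` section of model-config.yaml."""
    # TorchServe exposes the parsed YAML as a dict on context.model_yaml_config.
    yaml_cfg = getattr(context, "model_yaml_config", None) or {}
    cfg = yaml_cfg.get("handler", {})
    settings: Dict[str, Any] = {
        "model_path": cfg["model_path"],          # required: KeyError if missing
        "tokenizer_path": cfg["tokenizer_path"],  # required: KeyError if missing
        # Illustrative defaults only; they are not taken from the handler source.
        "compile": bool(cfg.get("compile", True)),
        "speculative": bool(cfg.get("speculative", False)),
        "max_new_tokens": int(cfg.get("max_new_tokens", 256)),
        "temperature": float(cfg.get("temperature", 0.8)),
        "top_k": cfg.get("top_k"),
    }
    if settings["speculative"]:
        # Speculative decoding additionally needs a draft model and a speculation depth.
        settings["draft_model_path"] = cfg["draft_model_path"]
        settings["speculate_k"] = int(cfg.get("speculate_k", 5))
    return settings


# Example: a stand-in for the TorchServe context, carrying the YAML from the comment above.
example_ctx = SimpleNamespace(model_yaml_config={
    "handler": {
        "model_path": "path/to/gpt_fast_model",
        "tokenizer_path": "path/to/tokenizer.model",
        "compile": True,
        "speculative": False,
    }
})
example_settings = read_handler_config(example_ctx)
```

Failing fast on missing required paths surfaces configuration mistakes at worker startup, not on the first request.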
Code Reference
Source Location
| File | Lines | Notes |
| --- | --- | --- |
| examples/large_models/gpt_fast/handler.py | L1-362 | pytorch/serve repository |
| examples/large_models/gpt_fast/handler.py | L28-363 | GptHandler class definition |
Signature
class GptHandler(BaseHandler):
    """
    TorchServe handler for GPT Fast model inference with optional
    speculative decoding and tensor parallelism support.

    Attributes:
        model: The compiled GPT Fast model.
        tokenizer: SentencePiece tokenizer instance.
        device (torch.device): Target compute device.
        draft_model: Optional draft model for speculative decoding.
        max_new_tokens (int): Maximum number of tokens to generate.
        temperature (float): Sampling temperature.
        top_k (int): Top-k sampling parameter.
    """

    def initialize(self, context):
        """
        Load model, tokenizer, and configure torch.compile.
        Sets up tensor parallelism if multiple GPUs are available.
        Applies torch.compile() for optimized inference.

        Args:
            context: TorchServe context with system_properties and model_yaml_config.
        """
        ...

    def preprocess(self, data):
        """
        Tokenize input text using SentencePiece.

        Args:
            data (list): List of request dicts with text input.
        Returns:
            torch.Tensor: Tokenized input tensor on device.
        """
        ...

    def inference(self, data, *args, **kwargs):
        """
        Generate text using autoregressive decoding.
        Calls generate() or speculative_decode() depending on configuration.

        Args:
            data (torch.Tensor): Tokenized prompt tensor.
        Returns:
            torch.Tensor: Generated token IDs.
        """
        ...

    def postprocess(self, data):
        """
        Decode generated token IDs back to text.

        Args:
            data (torch.Tensor): Generated token IDs.
        Returns:
            list[str]: Decoded text strings.
        """
        ...

    def generate(self, prompt_tokens, max_new_tokens, temperature=1.0,
                 top_k=None, callback=None):
        """
        Core autoregressive generation loop.
        Generates tokens one at a time using the compiled model,
        sampling from the logits distribution with temperature and top-k.

        Args:
            prompt_tokens (torch.Tensor): Input token IDs.
            max_new_tokens (int): Maximum tokens to generate.
            temperature (float): Sampling temperature (default 1.0).
            top_k (int|None): Top-k sampling filter.
            callback (callable|None): Optional callback for streaming.
        Returns:
            torch.Tensor: Full sequence including prompt and generated tokens.
        """
        ...

    def speculative_decode(self, prompt_tokens, draft_model, speculate_k,
                           temperature=1.0, top_k=None):
        """
        Speculative decoding with a smaller draft model.
        The draft model generates k candidate tokens, then the target
        model verifies them in a single forward pass. Accepted tokens
        are kept; the first rejected token is resampled from the target.

        Args:
            prompt_tokens (torch.Tensor): Input token IDs.
            draft_model: Smaller draft model for candidate generation.
            speculate_k (int): Number of speculative tokens per step.
            temperature (float): Sampling temperature.
            top_k (int|None): Top-k sampling filter.
        Returns:
            torch.Tensor: Generated sequence with verified tokens.
        """
        ...
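The draft-then-verify step in the speculative_decode() docstring can be illustrated with the standard speculative-sampling acceptance rule. Draft token x is kept with probability min(1, p(x)/q(x)), where p is the target distribution and q the draft distribution. On the first rejection, a replacement is drawn from the residual max(0, p - q). This pure-Python sketch assumes both models' per-position distributions are already available as lists. `verify_draft` is a hypothetical helper, not part of the handler, and the handler's exact resampling details may differ.

```python
import random
from typing import List, Sequence


def _sample(weights: Sequence[float], rng: random.Random) -> int:
    # random.choices treats the weights as relative, so they need not sum to 1.
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]


def verify_draft(draft_tokens: List[int],
                 draft_probs: List[Sequence[float]],
                 target_probs: List[Sequence[float]],
                 rng: random.Random) -> List[int]:
    """Accept/reject step of speculative sampling.

    draft_probs[i] is the draft distribution that produced draft_tokens[i];
    target_probs holds len(draft_tokens) + 1 rows from ONE target forward pass.
    Returns the tokens to append (between 1 and k + 1 of them).
    """
    k = len(draft_tokens)
    assert len(draft_probs) == k and len(target_probs) == k + 1
    accepted: List[int] = []
    for i, tok in enumerate(draft_tokens):
        p, q = target_probs[i][tok], draft_probs[i][tok]
        if q > 0 and rng.random() < min(1.0, p / q):
            accepted.append(tok)  # keep the draft token
            continue
        # First rejection: resample from the residual max(0, p - q).
        residual = [max(0.0, pt - qt) for pt, qt in zip(target_probs[i], draft_probs[i])]
        fallback = residual if sum(residual) > 0 else list(target_probs[i])
        accepted.append(_sample(fallback, rng))
        return accepted
    # Every draft token was accepted: take one bonus token from the last target row.
    accepted.append(_sample(target_probs[k], rng))
    return accepted
```

Because the target scores all k draft positions in one pass, each verification step yields at least one token and at most k + 1.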
Import
# Handler is loaded by TorchServe from the model archive.
# Internal imports used by the handler:
import torch
import sentencepiece as spm
from ts.torch_handler.base_handler import BaseHandler
I/O Contract
| Method | Input | Output | Notes |
| --- | --- | --- | --- |
| initialize(context) | context: Context with system_properties, model_yaml_config | None (sets self.model, self.tokenizer, self.device) | Applies torch.compile(); supports tensor parallelism |
| preprocess(data) | data: list[dict] with text in "data" or "body" key | torch.Tensor of token IDs on device | Uses SentencePiece tokenizer |
| inference(data) | data: torch.Tensor of prompt token IDs | torch.Tensor of generated token IDs | Calls generate() or speculative_decode() |
| postprocess(data) | data: torch.Tensor of token IDs | list[str] of decoded text | SentencePiece decode |
| generate(prompt_tokens, max_new_tokens, ...) | Prompt tensor, max tokens, temperature, top_k | torch.Tensor full sequence | Core autoregressive loop |
| speculative_decode(prompt_tokens, draft_model, speculate_k, ...) | Prompt tensor, draft model, speculation depth | torch.Tensor verified sequence | Draft-then-verify strategy |
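The table says preprocess() finds the prompt under a "data" or "body" key. In TorchServe that payload may arrive as raw bytes, or, for JSON requests, typically already parsed into a dict. The following hedged sketch normalizes those cases to plain strings before tokenization. `extract_prompts` is hypothetical; the real handler would then encode each string with its SentencePiece tokenizer.

```python
import json
from typing import Any, Dict, List


def extract_prompts(data: List[Dict[str, Any]]) -> List[str]:
    """Pull the prompt text out of each request in a TorchServe batch."""
    prompts = []
    for row in data:
        payload = row.get("data") or row.get("body")
        if isinstance(payload, (bytes, bytearray)):
            payload = payload.decode("utf-8")
        if isinstance(payload, str):
            # A raw text body may itself be a JSON document such as {"data": "..."}.
            try:
                payload = json.loads(payload)
            except json.JSONDecodeError:
                pass  # plain text prompt: keep the string as-is
        if isinstance(payload, dict):
            payload = payload.get("data", "")
        prompts.append(str(payload))
    return prompts
```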
Request Format
{
"data": "Once upon a time in a land far away",
"max_new_tokens": 256,
"temperature": 0.8,
"top_k": 50
}
Response Format
{
"generated_text": "Once upon a time in a land far away, there lived a wise old dragon..."
}
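TorchServe expects a handler to return one entry per request in the batch, in the same order as the requests arrived. The following is a minimal sketch of packaging decoded text into the response shape shown above. `build_responses` is hypothetical, and `decode` stands in for the tokenizer's decode method (for SentencePiece, `SentencePieceProcessor.decode`).

```python
import json
from typing import Callable, List, Sequence


def build_responses(batch_token_ids: Sequence[Sequence[int]],
                    decode: Callable[[List[int]], str]) -> List[str]:
    """One JSON response per request, in the same order as the incoming batch."""
    responses = []
    for ids in batch_token_ids:
        text = decode(list(ids))
        responses.append(json.dumps({"generated_text": text}))
    return responses
```

Serializing with json.dumps here is one design choice. A handler could also return the dicts directly and leave serialization to the serving layer.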
Usage Examples
Example 1: Serving the GPT Fast model
# Start TorchServe with the GPT Fast model
# (the gpt_fast/ directory produced by --archive-format no-archive must be inside model_store)
torchserve --start --ncs --model-store model_store \
  --models gpt_fast
# Send an inference request
curl -X POST http://localhost:8080/predictions/gpt_fast \
-H "Content-Type: application/json" \
-d '{"data": "The future of artificial intelligence is", "max_new_tokens": 128}'
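The same request can be issued from Python with only the standard library. This sketch builds the POST request separately from sending it, so the request can be inspected without a server. The function names are hypothetical, and the host, port, and model name mirror the curl command above.

```python
import json
import urllib.request


def build_prediction_request(prompt: str, max_new_tokens: int = 128,
                             host: str = "http://localhost:8080",
                             model_name: str = "gpt_fast") -> urllib.request.Request:
    """Build the same POST request as the curl command, without sending it."""
    body = json.dumps({"data": prompt, "max_new_tokens": max_new_tokens}).encode("utf-8")
    return urllib.request.Request(
        url=f"{host}/predictions/{model_name}",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send_prediction(request: urllib.request.Request, timeout: float = 300.0) -> str:
    """Send the request and return the response body (needs a running TorchServe)."""
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return response.read().decode("utf-8")
```

Against a running server, `send_prediction(build_prediction_request("The future of artificial intelligence is"))` returns the raw response text. The generous timeout is an arbitrary choice to leave room for long generations.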
Example 2: Configuration for speculative decoding
# model-config.yaml for speculative decoding
minWorkers: 1
maxWorkers: 1
handler:
  model_path: "path/to/gpt_fast_7b"
  tokenizer_path: "path/to/tokenizer.model"
  compile: true
  speculative: true
  draft_model_path: "path/to/gpt_fast_68m"
  speculate_k: 5
  max_new_tokens: 256
  temperature: 0.8
  top_k: 200
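The speculate_k setting trades extra draft-model work against how many tokens each target forward pass yields. Assume each draft token is accepted independently with probability α; this is a simplification, and α here is an assumed rate, not one measured for these models. Under that assumption, the expected number of tokens emitted per target pass is (1 - α^(k+1)) / (1 - α), the expected-length result from the speculative decoding analysis of Leviathan et al. (2023). A small calculator:

```python
def expected_tokens_per_target_pass(alpha: float, speculate_k: int) -> float:
    """Expected tokens emitted per target-model forward pass.

    Assumes each draft token is accepted independently with probability alpha;
    counts accepted draft tokens plus the one token the target always contributes.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    if alpha == 1.0:
        return float(speculate_k + 1)  # all drafts accepted, plus the bonus token
    return (1.0 - alpha ** (speculate_k + 1)) / (1.0 - alpha)
```

With α = 0.8 and speculate_k = 5 (the value in the config above), this gives about 3.69 tokens per target pass. The wall-clock speedup is smaller, because the draft model's k forward passes also take time.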
Example 3: Tensor-parallel multi-GPU configuration
# model-config.yaml for tensor parallelism across 4 GPUs
parallelType: "tp"
deviceType: "gpu"
deviceIds: [0, 1, 2, 3]
minWorkers: 1
maxWorkers: 1
handler:
  model_path: "path/to/gpt_fast_70b"
  tokenizer_path: "path/to/tokenizer.model"
  compile: true
  speculative: false
  max_new_tokens: 512
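With parallelType "tp" and four deviceIds, each GPU typically holds a slice of every large weight matrix. The following sketch simulates one common sharding scheme in a single CPU process. A linear layer's weight is split along its output dimension (what Megatron-style code calls a column-parallel layer), and each simulated rank computes its slice of the output. Concatenating the slices stands in for the all-gather a real multi-GPU run would perform with torch.distributed. Matrices are plain nested lists, and all names are hypothetical; how GPT Fast shards each specific layer is not shown here.

```python
from typing import List

Matrix = List[List[float]]


def matvec(w: Matrix, x: List[float]) -> List[float]:
    """y = W x for W stored as rows (out_features x in_features)."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def shard_rows(w: Matrix, world_size: int) -> List[Matrix]:
    """Split the output dimension of W evenly across `world_size` ranks."""
    if len(w) % world_size != 0:
        raise ValueError("out_features must be divisible by world_size")
    step = len(w) // world_size
    return [w[r * step:(r + 1) * step] for r in range(world_size)]


def tensor_parallel_matvec(w: Matrix, x: List[float], world_size: int) -> List[float]:
    """Each simulated rank computes its slice of y; concatenation mimics all-gather."""
    partials = [matvec(shard, x) for shard in shard_rows(w, world_size)]
    return [y for part in partials for y in part]
```

The result matches the unsharded product exactly. Each rank stores and multiplies only 1/world_size of the weight, which is what lets a model too large for one GPU fit across four.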
Related Pages