Implementation:Pytorch Serve GptHandler

From Leeroopedia
Revision as of 13:45, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/Pytorch_Serve_GptHandler.md)
Knowledge Sources
Domains: LLM_Serving, Inference, Speculative_Decoding
Last Updated: 2026-02-13 18:52 GMT

Overview

GptHandler is a TorchServe handler for serving GPT Fast models. It extends BaseHandler and implements the full inference pipeline with support for torch.compile optimization, SentencePiece tokenization, tensor parallelism, streaming token generation, and speculative decoding. The handler is designed for high-performance autoregressive text generation.

Description

The GptHandler class (lines 28-363 of handler.py) is a complete handler for GPT-style large language models built on the GPT Fast architecture. It overrides the four BaseHandler pipeline methods (initialize, preprocess, inference, and postprocess) and adds custom generation logic with optional speculative decoding.

Key Responsibilities

  • Model Loading: During initialize(), loads the GPT Fast model and applies torch.compile() to optimize inference
  • Tokenization: Uses SentencePiece tokenizer for text-to-token and token-to-text conversion
  • Text Generation: Implements autoregressive token generation with configurable temperature and top-k sampling
  • Speculative Decoding: Optionally uses a smaller draft model, via speculative_decode(), to speed up generation
  • Tensor Parallelism: Supports distributed inference across multiple GPUs
  • Streaming Output: Supports streaming token-by-token responses back to the client
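
The temperature and top-k sampling mentioned in the list above can be illustrated with a minimal, pure-Python sketch. It is not the handler's code: the function name sample_top_k and the use of plain lists instead of torch tensors are assumptions made so the example runs without PyTorch.

```python
import math
import random


def sample_top_k(logits, temperature=1.0, top_k=None, rng=random):
    """Sample one token index from raw logits (illustrative only).

    Returns (token_index, probabilities) so the distribution can be inspected.
    """
    # Divide by the temperature; clamp it to avoid division by zero.
    scaled = [x / max(temperature, 1e-5) for x in logits]

    if top_k is not None:
        k = min(top_k, len(scaled))
        # Value of the k-th largest logit; everything below it is masked out.
        # Ties at the threshold are all kept, so more than k may survive.
        threshold = sorted(scaled, reverse=True)[k - 1]
        scaled = [x if x >= threshold else float("-inf") for x in scaled]

    # Numerically stable softmax over the surviving logits.
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Draw a single index according to the resulting probabilities.
    index = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return index, probs
```

A lower temperature sharpens the distribution toward the largest logit, while top_k=1 reduces sampling to greedy decoding.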

Core Methods

  • initialize(context): Loads model weights, initializes SentencePiece tokenizer, applies torch.compile(), and configures tensor parallelism
  • preprocess(data): Tokenizes input text using SentencePiece
  • inference(data): Calls generate() or, when speculative decoding is enabled, speculative_decode() to produce output tokens
  • postprocess(data): Decodes generated tokens back to text
  • generate(prompt_tokens, max_new_tokens, ...): Core autoregressive generation loop
  • speculative_decode(prompt_tokens, draft_model, ...): Speculative decoding with draft model verification

Usage

# The handler is typically configured in model-config.yaml:
# handler:
#   model_path: "path/to/gpt_fast_model"
#   tokenizer_path: "path/to/tokenizer.model"
#   compile: true
#   speculative: false

# Create a model archive with the GPT handler.
# torch-model-archiver requires --version; 1.0 is a placeholder.
torch-model-archiver --model-name gpt_fast \
    --version 1.0 \
    --handler examples/large_models/gpt_fast/handler.py \
    --config-file model-config.yaml \
    --archive-format no-archive

Code Reference

Source Location

File | Lines | Repository / Scope
examples/large_models/gpt_fast/handler.py | L1-362 | pytorch/serve
examples/large_models/gpt_fast/handler.py | L28-363 | GptHandler class definition

Signature

class GptHandler(BaseHandler):
    """
    TorchServe handler for GPT Fast model inference with optional
    speculative decoding and tensor parallelism support.

    Attributes:
        model: The compiled GPT Fast model.
        tokenizer: SentencePiece tokenizer instance.
        device (torch.device): Target compute device.
        draft_model: Optional draft model for speculative decoding.
        max_new_tokens (int): Maximum number of tokens to generate.
        temperature (float): Sampling temperature.
        top_k (int): Top-k sampling parameter.
    """

    def initialize(self, context):
        """
        Load model, tokenizer, and configure torch.compile.

        Sets up tensor parallelism if multiple GPUs are available.
        Applies torch.compile() for optimized inference.

        Args:
            context: TorchServe context with system_properties and model_yaml_config.
        """
        ...

    def preprocess(self, data):
        """
        Tokenize input text using SentencePiece.

        Args:
            data (list): List of request dicts with text input.

        Returns:
            torch.Tensor: Tokenized input tensor on device.
        """
        ...

    def inference(self, data, *args, **kwargs):
        """
        Generate text using autoregressive decoding.

        Calls generate() or speculative_decode() depending on configuration.

        Args:
            data (torch.Tensor): Tokenized prompt tensor.

        Returns:
            torch.Tensor: Generated token IDs.
        """
        ...

    def postprocess(self, data):
        """
        Decode generated token IDs back to text.

        Args:
            data (torch.Tensor): Generated token IDs.

        Returns:
            list[str]: Decoded text strings.
        """
        ...

    def generate(self, prompt_tokens, max_new_tokens, temperature=1.0,
                 top_k=None, callback=None):
        """
        Core autoregressive generation loop.

        Generates tokens one at a time using the compiled model,
        sampling from the logits distribution with temperature and top-k.

        Args:
            prompt_tokens (torch.Tensor): Input token IDs.
            max_new_tokens (int): Maximum tokens to generate.
            temperature (float): Sampling temperature (default 1.0).
            top_k (int|None): Top-k sampling filter.
            callback (callable|None): Optional callback for streaming.

        Returns:
            torch.Tensor: Full sequence including prompt and generated tokens.
        """
        ...

    def speculative_decode(self, prompt_tokens, draft_model, speculate_k,
                           temperature=1.0, top_k=None):
        """
        Speculative decoding with a smaller draft model.

        The draft model generates k candidate tokens, then the target
        model verifies them in a single forward pass. Accepted tokens
        are kept; the first rejected token is resampled from the target.

        Args:
            prompt_tokens (torch.Tensor): Input token IDs.
            draft_model: Smaller draft model for candidate generation.
            speculate_k (int): Number of speculative tokens per step.
            temperature (float): Sampling temperature.
            top_k (int|None): Top-k sampling filter.

        Returns:
            torch.Tensor: Generated sequence with verified tokens.
        """
        ...
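
The draft-then-verify step described in the speculative_decode() docstring can be sketched as follows. This uses the standard speculative sampling acceptance rule (accept a draft token with probability min(1, p/q), otherwise resample from the normalized positive part of p - q); the source does not show the handler's exact rule, so treat verify_draft and its arguments as assumptions.

```python
import random


def verify_draft(draft_tokens, draft_probs, target_probs, rng=random):
    """Verify k draft tokens against the target model (illustrative only).

    draft_tokens: k token ids proposed by the draft model.
    draft_probs:  k distributions q_i (lists over the vocabulary).
    target_probs: k + 1 distributions p_i from one target forward pass.
    Returns the accepted tokens plus one token sampled from the target.
    """
    accepted = []
    for i, token in enumerate(draft_tokens):
        p, q = target_probs[i], draft_probs[i]
        # q[token] > 0 because the draft model sampled this token from q.
        accept_prob = min(1.0, p[token] / q[token])
        if rng.random() < accept_prob:
            accepted.append(token)
            continue
        # First rejection: resample from the residual max(0, p - q) and stop.
        residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
        vocab = range(len(residual))
        accepted.append(rng.choices(vocab, weights=residual, k=1)[0])
        return accepted
    # Every draft token was accepted: take a bonus token from the target.
    last = target_probs[len(draft_tokens)]
    accepted.append(rng.choices(range(len(last)), weights=last, k=1)[0])
    return accepted
```

Each call yields between 1 and k + 1 tokens, which is where the speed-up comes from when the draft model agrees with the target often.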

Import

# Handler is loaded by TorchServe from the model archive.
# Internal imports used by the handler:
import torch
import sentencepiece as spm
from ts.torch_handler.base_handler import BaseHandler

I/O Contract

Method | Input | Output | Notes
initialize(context) | context: Context with system_properties, model_yaml_config | None (sets self.model, self.tokenizer, self.device) | Applies torch.compile(); supports tensor parallelism
preprocess(data) | data: list[dict] with text in "data" or "body" key | torch.Tensor of token IDs on device | Uses SentencePiece tokenizer
inference(data) | data: torch.Tensor of prompt token IDs | torch.Tensor of generated token IDs | Calls generate() or speculative_decode()
postprocess(data) | data: torch.Tensor of token IDs | list[str] decoded text strings | SentencePiece decode
generate(prompt_tokens, max_new_tokens, ...) | Prompt tensor, max tokens, temperature, top_k | torch.Tensor full sequence | Core autoregressive loop
speculative_decode(prompt_tokens, draft_model, speculate_k, ...) | Prompt tensor, draft model, speculation depth | torch.Tensor verified sequence | Draft-then-verify strategy
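
The generate() contract above (prompt in, full sequence out, optional streaming callback) can be sketched in plain Python. The next_token_fn argument and the eos_id stop condition are assumptions standing in for the compiled model and whatever stopping logic the handler applies.

```python
def generate(next_token_fn, prompt_tokens, max_new_tokens,
             eos_id=None, callback=None):
    """Toy autoregressive loop (illustrative only).

    next_token_fn: hypothetical callable mapping the current token list
                   to the next token id (stands in for model + sampling).
    callback:      optional callable invoked with each new token,
                   which is how streaming can be layered on top.
    """
    tokens = list(prompt_tokens)
    for _ in range(max_new_tokens):
        token = next_token_fn(tokens)
        tokens.append(token)
        if callback is not None:
            callback(token)  # emit the token as soon as it exists
        if eos_id is not None and token == eos_id:
            break  # stop early at end-of-sequence
    # Full sequence: prompt followed by generated tokens.
    return tokens
```

Passing a list's append method as the callback collects the streamed tokens, while a serving handler would instead forward each token to the client.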

Request Format

{
    "data": "Once upon a time in a land far away",
    "max_new_tokens": 256,
    "temperature": 0.8,
    "top_k": 50
}

Response Format

{
    "generated_text": "Once upon a time in a land far away, there lived a wise old dragon..."
}
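
A minimal sketch of handling these two payload shapes is shown below. The helper names parse_request and build_response are hypothetical, and the fallback values are simply copied from the example request; the source does not state the handler's real defaults.

```python
import json

# Assumed fallbacks, copied from the example request above.
DEFAULTS = {"max_new_tokens": 256, "temperature": 0.8, "top_k": 50}


def parse_request(body):
    """Return (prompt, generation_params) from a raw request body."""
    if isinstance(body, (bytes, bytearray)):
        body = body.decode("utf-8")
    if isinstance(body, str):
        body = json.loads(body)
    if "data" not in body:
        raise ValueError('request must contain a "data" field')
    # Use the request's value when present, otherwise the fallback.
    params = {key: body.get(key, default) for key, default in DEFAULTS.items()}
    return body["data"], params


def build_response(generated_text):
    """Serialize generated text into the response shape shown above."""
    return json.dumps({"generated_text": generated_text})
```

Only "data" is treated as mandatory here; the remaining fields are optional overrides.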

Usage Examples

Example 1: Serving the GPT Fast model

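# Start TorchServe with the GPT Fast model.
# With --archive-format no-archive the archiver writes a gpt_fast directory
# (not a .mar file), which must be placed inside model_store.
torchserve --start --ncs --model-store model_store \
    --models gpt_fast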

# Send an inference request
curl -X POST http://localhost:8080/predictions/gpt_fast \
    -H "Content-Type: application/json" \
    -d '{"data": "The future of artificial intelligence is", "max_new_tokens": 128}'
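
The same request can be issued from Python with only the standard library. This is a sketch: the helper names are hypothetical, and the URL and model name are taken from the curl command above.

```python
import json
import urllib.request


def build_request(prompt, max_new_tokens=128,
                  base_url="http://localhost:8080", model_name="gpt_fast"):
    """Build a POST request equivalent to the curl example."""
    payload = json.dumps(
        {"data": prompt, "max_new_tokens": max_new_tokens}
    ).encode("utf-8")
    return urllib.request.Request(
        url=f"{base_url}/predictions/{model_name}",
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send(request, timeout=60.0):
    """Send the request and return the response body as text."""
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return response.read().decode("utf-8")


# Usage, with TorchServe running locally:
# print(send(build_request("The future of artificial intelligence is")))
```

Building the request separately from sending it keeps the payload easy to inspect without a running server.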

Example 2: Configuration for speculative decoding

# model-config.yaml for speculative decoding
minWorkers: 1
maxWorkers: 1
handler:
    model_path: "path/to/gpt_fast_7b"
    tokenizer_path: "path/to/tokenizer.model"
    compile: true
    speculative: true
    draft_model_path: "path/to/gpt_fast_68m"
    speculate_k: 5
    max_new_tokens: 256
    temperature: 0.8
    top_k: 200
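
One way a handler could read and validate this handler section is sketched below. It assumes the parsed YAML is exposed as context.model_yaml_config (as noted in the initialize() docstring); the function name, the required-key checks, and the default values are hypothetical.

```python
from types import SimpleNamespace


def read_handler_config(context):
    """Extract and validate the 'handler' section (illustrative only)."""
    cfg = dict(context.model_yaml_config.get("handler", {}))
    for key in ("model_path", "tokenizer_path"):
        if key not in cfg:
            raise ValueError(f"handler.{key} is required")
    # Assumed defaults; the real handler's defaults are not documented here.
    cfg.setdefault("compile", False)
    cfg.setdefault("speculative", False)
    if cfg["speculative"]:
        # Speculative decoding needs a draft model to propose tokens.
        if "draft_model_path" not in cfg:
            raise ValueError(
                "handler.draft_model_path is required when speculative is true"
            )
        cfg.setdefault("speculate_k", 5)
    return cfg


# SimpleNamespace stands in for the TorchServe context object.
example_context = SimpleNamespace(model_yaml_config={
    "handler": {
        "model_path": "path/to/gpt_fast_7b",
        "tokenizer_path": "path/to/tokenizer.model",
        "speculative": True,
        "draft_model_path": "path/to/gpt_fast_68m",
    }
})
example_config = read_handler_config(example_context)
```

Failing fast on a missing draft_model_path surfaces a misconfiguration at load time instead of at the first request.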

Example 3: Tensor-parallel multi-GPU configuration

# model-config.yaml for tensor parallelism across 4 GPUs
parallelType: "tp"
deviceType: "gpu"
deviceIds: [0, 1, 2, 3]
minWorkers: 1
maxWorkers: 1
handler:
    model_path: "path/to/gpt_fast_70b"
    tokenizer_path: "path/to/tokenizer.model"
    compile: true
    speculative: false
    max_new_tokens: 512
