Implementation:Pytorch Serve TransformersSeqClassifierHandler

Field | Value
Page Type | Implementation
Title | TransformersSeqClassifierHandler
Type | API Doc
Short Description | Generalized TorchServe handler for HuggingFace Transformers that supports sequence classification, token classification, question answering, and text generation through configuration-driven mode branching
Domains | NLP, Model_Serving
Source | examples/Huggingface_Transformers/Transformer_handler_generalized.py:L24-513
Knowledge Sources | TorchServe
Workflow | HuggingFace_Transformer_Serving
Last Updated | 2026-02-13 00:00 GMT

Overview

TransformersSeqClassifierHandler is a unified TorchServe handler class that extends BaseHandler to serve HuggingFace Transformer models across four NLP tasks. The handler reads task configuration from the model's YAML config, loads the appropriate model class and tokenizer, and routes preprocessing, inference, and explanation logic based on the configured mode. It supports pretrained and TorchScript model loading, BetterTransformer optimization, torch.compile, model parallelism, and Captum-based explainability.

Description

The handler implements the standard TorchServe lifecycle methods (initialize, preprocess, inference, postprocess) plus a get_insights method for model explainability. The initialize, preprocess, inference, and get_insights methods branch on self.setup_config["mode"] to handle the four supported NLP tasks; postprocess returns its input unchanged.
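The call order of the lifecycle methods can be illustrated with a small stand-in class. This is a minimal sketch, not the handler's code: EchoHandler and run_once are hypothetical names, a plain dict stands in for the TorchServe context, and the "inference" step only tags each item with the active mode.

```python
class EchoHandler:
    """Hypothetical stub that mimics the handler's method layout."""

    def __init__(self):
        self.setup_config = None
        self.initialized = False
        self.calls = []  # records the order in which methods run

    def initialize(self, setup_config):
        # The real handler receives a TorchServe context; a dict stands in here.
        self.setup_config = setup_config
        self.initialized = True
        self.calls.append("initialize")

    def preprocess(self, requests):
        self.calls.append("preprocess")
        # Mode-dependent branch: question answering needs two strings per item.
        if self.setup_config["mode"] == "question_answering":
            return [(r["question"], r["context"]) for r in requests]
        return list(requests)

    def inference(self, input_batch):
        self.calls.append("inference")
        # Placeholder result: tag each item with the active mode.
        return [(self.setup_config["mode"], item) for item in input_batch]

    def postprocess(self, inference_output):
        self.calls.append("postprocess")
        return inference_output  # pass-through, as in the documented handler


def run_once(handler, requests):
    """Chain the three per-request stages in lifecycle order."""
    return handler.postprocess(handler.inference(handler.preprocess(requests)))


handler = EchoHandler()
handler.initialize({"mode": "sequence_classification"})
result = run_once(handler, ["some text"])
```

Here initialize runs once per worker, while the other three stages run once per batch of requests.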

Usage

The handler is specified when creating the model archive:

torch-model-archiver --model-name bert_seq_classification \
  --version 1.0 \
  --handler Transformer_handler_generalized.py \
  --config-file model-config.yaml \
  --extra-files "index_to_name.json" \
  --serialized-file Transformer_model/pytorch_model.bin \
  --export-path model_store

Code Reference

Source Location

Field | Value
Repository | pytorch/serve
File | examples/Huggingface_Transformers/Transformer_handler_generalized.py
Lines | L24-513

Class Definition

class TransformersSeqClassifierHandler(BaseHandler):
    """
    Transformers handler class for sequence, token classification and question answering.
    """

Method Signatures

__init__(self) (L29-32)

def __init__(self):
    super(TransformersSeqClassifierHandler, self).__init__()
    self.setup_config = None
    self.initialized = False

Initializes the handler with setup_config set to None and initialized flag set to False.

initialize(self, ctx) (L34-154)

def initialize(self, ctx):
    """In this initialize function, the BERT model is loaded and
    the Layer Integrated Gradients Algorithm for Captum Explanations
    is initialized here.
    Args:
        ctx (context): It is a JSON Object containing information
        pertaining to the model artifacts parameters.
    """

Loads the model and tokenizer based on configuration. Key operations:

  • Reads handler config from ctx.model_yaml_config
  • Loads model based on save_mode: TorchScript via torch.jit.load() or pretrained via AutoModelFor* classes
  • Optionally applies BetterTransformer via BetterTransformer.transform(self.model)
  • Optionally enables GPT-2 model parallelism via self.model.parallelize()
  • Loads tokenizer: GPT2TokenizerFast for GPT-2, AutoTokenizer otherwise
  • Sets model to eval mode
  • Optionally applies torch.compile() with configured backend and mode
  • Loads index_to_name.json label mapping for classification tasks
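The model-selection step can be sketched without loading any weights. The following is an illustration under assumptions: plan_model_load and MODE_TO_MODEL_CLASS are hypothetical names, the mode-to-class pairing is inferred from the classes in the Import section, and the literal save_mode values "torchscript" and "pretrained" are assumed spellings. Class names are kept as strings so the sketch runs without transformers installed.

```python
# Assumed pairing of each mode with a class listed in the Import section.
MODE_TO_MODEL_CLASS = {
    "sequence_classification": "AutoModelForSequenceClassification",
    "question_answering": "AutoModelForQuestionAnswering",
    "token_classification": "AutoModelForTokenClassification",
    "text_generation": "AutoModelForCausalLM",
}


def plan_model_load(setup_config):
    """Describe how the model would be loaded for a given handler config."""
    save_mode = setup_config["save_mode"]
    if save_mode == "torchscript":
        # TorchScript archives are loaded directly; no model class is needed.
        return {"loader": "torch.jit.load", "model_class": None}
    if save_mode == "pretrained":
        mode = setup_config["mode"]
        if mode not in MODE_TO_MODEL_CLASS:
            raise ValueError(f"Unsupported mode: {mode}")
        return {
            "loader": "from_pretrained",
            "model_class": MODE_TO_MODEL_CLASS[mode],
        }
    raise ValueError(f"Unsupported save_mode: {save_mode}")


plan = plan_model_load({"save_mode": "pretrained", "mode": "question_answering"})
```

Failing fast on an unknown mode or save_mode is a design choice of this sketch; it surfaces configuration mistakes at load time rather than at the first request.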

preprocess(self, requests) (L156-229)

def preprocess(self, requests):
    """Basic text preprocessing, based on the user's choice of application mode.
    Args:
        requests (str): The Input data in the form of text is passed on to the preprocess
        function.
    Returns:
        list : The preprocess function returns a list of Tensor for the size of the word tokens.
    """

Tokenizes input text based on task mode:

  • sequence_classification, token_classification, text_generation: Encodes single text input with tokenizer.encode_plus()
  • question_answering: Parses input as dict with "question" and "context" keys, encodes as pair
  • Handles Captum explanation input format: {"text": "...", "target": N}
  • Batches multiple requests by concatenating input_ids and attention_mask tensors

Returns a tuple (input_ids_batch, attention_mask_batch).
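The request-parsing part of preprocessing, before any tokenizer call, can be sketched in plain Python. This is a simplified illustration: extract_text and parse_request are hypothetical helpers, and parsing structured input with ast.literal_eval is an assumption suggested by the ast import in the Import section.

```python
import ast


def extract_text(request):
    """Pull the raw text out of one request dict ("data" first, then "body")."""
    text = request.get("data")
    if text is None:
        text = request.get("body")
    if isinstance(text, (bytes, bytearray)):
        text = text.decode("utf-8")
    return text


def parse_request(request, mode, captum_explanation=False):
    """Return the string arguments that would be passed to the tokenizer."""
    text = extract_text(request)
    if mode == "question_answering":
        # Structured input: a dict literal with "question" and "context".
        qa = ast.literal_eval(text)
        return (qa["question"], qa["context"])
    if captum_explanation:
        # Captum input format: {"text": "...", "target": N}
        payload = ast.literal_eval(text)
        return (payload["text"],)
    return (text,)


qa_args = parse_request(
    {"data": '{"question": "Who is CEO of Tesla?", '
             '"context": "Elon Musk is the CEO of Tesla."}'},
    "question_answering",
)
```

Note that ast.literal_eval accepts Python literals, so a JSON body containing true, false, or null would not parse this way.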

inference(self, input_batch) (L231-346)

@torch.inference_mode
def inference(self, input_batch):
    """Predict the class (or classes) of the received text using the
    serialized transformers checkpoint.
    Args:
        input_batch (list): List of Text Tensors from the pre-process function is passed here
    Returns:
        list : It returns a list of the predicted value for the input text
    """

Runs model inference with task-specific output processing:

  • sequence_classification: argmax of logits, maps to label via self.mapping
  • question_answering: Finds start/end positions via argmax, decodes answer span
  • token_classification: Per-token argmax, maps to labels from self.mapping["label_list"]
  • text_generation: Calls model.generate() with top_p=0.95, top_k=60, do_sample=True

Decorated with @torch.inference_mode to disable gradient computation and view tracking.
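The argmax-based post-processing for two of the modes can be sketched on plain lists. This is a sketch under assumptions: the helper names are hypothetical, Python lists stand in for logits tensors, and joining tokens with spaces is a simplification of decoding the answer span with the tokenizer.

```python
def argmax(values):
    """Index of the largest value (the first one wins on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def classify(logits_batch, mapping):
    """Map each row of logits to a label name via its argmax index."""
    return [mapping[str(argmax(row))] for row in logits_batch]


def answer_span(start_logits, end_logits, tokens):
    """Join the tokens between the argmax start and end positions (inclusive)."""
    start = argmax(start_logits)
    end = argmax(end_logits)
    return " ".join(tokens[start : end + 1])


# Label mapping in the shape of index_to_name.json (keys are strings).
mapping = {"0": "Not Accepted", "1": "Accepted"}
labels = classify([[0.1, 2.3], [1.7, -0.4]], mapping)

tokens = ["[CLS]", "who", "is", "ceo", "[SEP]", "elon", "musk", "is", "ceo", "[SEP]"]
start_logits = [0.0, 0.1, 0.0, 0.2, 0.0, 5.0, 1.0, 0.1, 0.3, 0.0]
end_logits = [0.0, 0.0, 0.1, 0.2, 0.0, 1.0, 6.0, 0.2, 0.4, 0.0]
answer = answer_span(start_logits, end_logits, tokens)
```

If the predicted end position precedes the start position, this sketch returns an empty string.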

postprocess(self, inference_output) (L348-355)

def postprocess(self, inference_output):
    """Post Process Function converts the predicted response into Torchserve readable format.
    Args:
        inference_output (list): It contains the predicted response of the input text.
    Returns:
        (list): Returns a list of the Predictions and Explanations.
    """
    return inference_output

Pass-through method that returns inference output unchanged.

get_insights(self, input_batch, text, target) (L357-427)

def get_insights(self, input_batch, text, target):
    """This function initialize and calls the layer integrated gradient to get word importance
    of the input text if captum explanation has been selected through setup_config
    Args:
        input_batch (int): Batches of tokens IDs of text
        text (str): The Text specified in the input request
        target (int): The Target can be set to any acceptable label under the user's discretion.
    Returns:
        (list): Returns a list of importances and words.
    """

Computes Captum Layer Integrated Gradients explanations. See Implementation:Pytorch_Serve_Captum_Explanations for full details.

Import

import ast
import json
import logging
import os

import torch
import transformers
from captum.attr import LayerIntegratedGradients
from transformers import (
    AutoModelForCausalLM,
    AutoModelForQuestionAnswering,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoTokenizer,
    GPT2TokenizerFast,
)

from ts.torch_handler.base_handler import BaseHandler

I/O Contract

Input

Method | Input | Format | Description
initialize | ctx | TorchServe Context | Contains model_yaml_config, manifest, system_properties
preprocess | requests | list of dicts | Each dict has "data" or "body" key with text input
inference | input_batch | tuple(Tensor, Tensor) | (input_ids_batch, attention_mask_batch) from preprocess
postprocess | inference_output | list | Results from inference
get_insights | input_batch, text, target | mixed | Token IDs, raw text, target class index

Input Text Formats by Mode:

  • sequence_classification / token_classification: Plain text string, or {"text": "...", "target": N} when Captum is enabled
  • question_answering: {"question": "...", "context": "..."}
  • text_generation: Plain text string (prompt)
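A client could build requests in these formats as follows. This is a minimal sketch: build_request is a hypothetical helper, and the base URL assumes TorchServe's inference API is listening on its default local address. The request object is only constructed, not sent.

```python
import json
from urllib.request import Request

BASE_URL = "http://localhost:8080"  # assumed local TorchServe inference address


def build_request(model_name, payload, explain=False):
    """Build a POST request for the predictions or explanations endpoint."""
    endpoint = "explanations" if explain else "predictions"
    # Plain strings are sent as-is; dict payloads are serialized to JSON.
    body = payload if isinstance(payload, str) else json.dumps(payload)
    return Request(
        f"{BASE_URL}/{endpoint}/{model_name}",
        data=body.encode("utf-8"),
        method="POST",
    )


predict = build_request(
    "bert_seq_classification", "This movie was absolutely wonderful"
)
explain = build_request(
    "bert_seq_classification",
    {"text": "This movie was great", "target": 1},
    explain=True,
)
```

Passing either object to urllib.request.urlopen would send it to a running server.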

Output

Method | Output | Format
preprocess | (input_ids_batch, attention_mask_batch) | tuple of 2D Tensors
inference (seq_class) | List of label strings | ["Accepted", "Not Accepted"]
inference (qa) | List of answer strings | ["the answer text"]
inference (token_class) | List of token-label pairs | ("token", "B-PER"), ...
inference (text_gen) | List of generated texts | ["The generated continuation..."]
postprocess | Same as inference output | list (pass-through)
get_insights | List with response dict | [{"words": [...], "importances": [...], "delta": float}]
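A consumer of the get_insights response might rank words by attribution magnitude. The sketch below assumes the response shape shown in the table; top_words is a hypothetical helper, and the sample values are the illustrative ones from the explanation example on this page.

```python
def top_words(insight, k=3):
    """Return the k (word, importance) pairs with the largest magnitude."""
    pairs = zip(insight["words"], insight["importances"])
    ranked = sorted(pairs, key=lambda pair: abs(pair[1]), reverse=True)
    return ranked[:k]


# Illustrative response entry, not real model output.
insight = {
    "words": ["[CLS]", "this", "movie", "was", "great", "[SEP]"],
    "importances": [0.01, 0.05, 0.15, 0.02, 0.77, 0.00],
    "delta": 0.0012,
}
strongest = top_words(insight, k=2)
```

Sorting by absolute value keeps strongly negative attributions in view alongside positive ones.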

Usage Examples

Example 1: Sequence Classification Request

# Input to TorchServe prediction endpoint
# POST /predictions/bert_seq_classification
# Body: "This movie was absolutely wonderful"

# Handler flow:
# preprocess: tokenizer.encode_plus("This movie was absolutely wonderful", ...)
# inference: model(input_ids, attention_mask) -> argmax -> self.mapping["1"] -> "Accepted"
# postprocess: ["Accepted"]
# Response: ["Accepted"]

Example 2: Question Answering Request

# Input:
# Body: '{"question": "Who is CEO of Tesla?", "context": "Elon Musk is the CEO of Tesla."}'

# Handler flow:
# preprocess: tokenizer.encode_plus("Who is CEO of Tesla?", "Elon Musk is the CEO of Tesla.", ...)
# inference: model(input_ids, attention_mask) -> start/end argmax -> decode tokens
# Response: ["Elon Musk"]

Example 3: Text Generation Request

# Input:
# Body: "Once upon a time"

# Handler flow:
# preprocess: tokenizer.encode_plus("Once upon a time", ...)
# inference: model.generate(input_ids, max_new_tokens=150, do_sample=True, top_p=0.95, top_k=60)
# Response: ["Once upon a time there was a small village..."]

Example 4: Explanation Request

# Input to TorchServe explanation endpoint
# POST /explanations/bert_seq_classification
# Body: '{"text": "This movie was great", "target": 1}'

# Handler flow:
# get_insights: LayerIntegratedGradients on embedding layer
# Response: [{"words": ["[CLS]", "this", "movie", "was", "great", "[SEP]"],
#             "importances": [0.01, 0.05, 0.15, 0.02, 0.77, 0.00],
#             "delta": 0.0012}]
