Implementation:Pytorch Serve Wav2Vec2 Handler

From Leeroopedia

Overview

Wav2VecHandler is a TorchServe handler for speech-to-text inference using the Wav2Vec2 model from HuggingFace Transformers. It is a standalone handler class (not extending BaseHandler) that accepts raw WAV audio bytes, resamples to 16kHz if necessary, runs CTC-based speech recognition via AutoModelForCTC, and returns the decoded transcription text.

Field | Value
Implementation Name | Wav2Vec2_Handler
Type | Example Handler
Workflow | Speech_To_Text_Serving
Domains | Speech_Recognition, Audio_Processing
Knowledge Sources | Pytorch_Serve
Last Updated | 2026-02-13 18:52 GMT

Description

The Wav2VecHandler class implements a complete speech-to-text pipeline in a single handle() method. Unlike most TorchServe handlers, it does not inherit from BaseHandler; it is a plain class deriving from object that defines its own initialize() and handle() methods. The handler loads an AutoModelForCTC and an AutoProcessor from the model directory, then processes WAV audio in four steps: resampling, feature extraction, CTC logit computation, and greedy argmax decoding.

Key Responsibilities

  • Model Loading: Loads AutoModelForCTC and AutoProcessor from the model directory via HuggingFace from_pretrained()
  • Audio Input: Accepts raw WAV bytes from request data, loaded via torchaudio.load() from an in-memory BytesIO buffer
  • Resampling: Automatically resamples audio to 16kHz using torchaudio.functional.resample() if the input sample rate differs
  • Feature Extraction: Processes audio waveform through AutoProcessor to produce model input values
  • CTC Decoding: Computes logits from the model, applies torch.argmax along the class axis, and decodes via the processor
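
The CTC decoding step can be illustrated without loading a model. The following is a minimal pure-Python sketch of what greedy CTC decoding does conceptually: take the highest-scoring class per frame, collapse consecutive repeats, then drop the blank token. The toy vocabulary, the blank index 0, and the function name greedy_ctc_decode are hypothetical and exist only for illustration; the handler itself delegates this work to torch.argmax and processor.decode(), and a real tokenizer handles further details (such as word boundaries) that this sketch omits.

```python
from typing import List, Sequence

# Hypothetical toy vocabulary; index 0 is assumed to be the CTC blank token.
VOCAB = ["<blank>", "c", "a", "t"]
BLANK_ID = 0


def greedy_ctc_decode(logits: Sequence[Sequence[float]]) -> str:
    """Greedy CTC decoding over a [frames][classes] score matrix."""
    # Step 1: argmax over the class axis for every frame.
    frame_ids: List[int] = [
        max(range(len(frame)), key=frame.__getitem__) for frame in logits
    ]
    # Step 2: collapse consecutive duplicates and drop blanks.
    chars: List[str] = []
    previous = None
    for token_id in frame_ids:
        if token_id != previous and token_id != BLANK_ID:
            chars.append(VOCAB[token_id])
        previous = token_id
    return "".join(chars)


# Six frames whose per-frame argmax is: c, c, blank, a, t, t
toy_logits = [
    [0.1, 0.7, 0.1, 0.1],
    [0.1, 0.6, 0.2, 0.1],
    [0.9, 0.0, 0.1, 0.0],
    [0.1, 0.1, 0.7, 0.1],
    [0.0, 0.1, 0.1, 0.8],
    [0.2, 0.1, 0.1, 0.6],
]
print(greedy_ctc_decode(toy_logits))  # prints "cat"
```

The blank token is what lets CTC emit a genuinely doubled letter: the frame sequence c, blank, c decodes to "cc", while c, c decodes to a single "c".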

Usage

from handler import Wav2VecHandler

The handler expects a pre-trained Wav2Vec2 model directory containing the model weights and processor configuration. The model is loaded during initialize().

# Create model archive (torch-model-archiver requires --version; 1.0 is a placeholder)
torch-model-archiver --model-name wav2vec2 \
    --version 1.0 \
    --handler handler.py \
    --extra-files "model_dir/" \
    --archive-format no-archive

Code Reference

Source Location

File | Lines | Description
examples/speech2text_wav2vec2/handler.py | L1-49 | Full handler module (48 lines)
examples/speech2text_wav2vec2/handler.py | L7-49 | Wav2VecHandler class definition
examples/speech2text_wav2vec2/handler.py | L8-16 | __init__(): instance variable initialization
examples/speech2text_wav2vec2/handler.py | L17-28 | initialize(context): model and processor loading
examples/speech2text_wav2vec2/handler.py | L30-48 | handle(data, context): full audio-to-text pipeline

Signature

class Wav2VecHandler(object):

    def __init__(self):
        self._context = None
        self.initialized = False
        self.model = None
        self.processor = None
        self.device = None
        self.expected_sampling_rate = 16_000

    def initialize(self, context):
        """
        Load AutoModelForCTC and AutoProcessor from model directory.

        Selects CUDA device if available, otherwise CPU. Loads both
        model and processor via from_pretrained().

        Args:
            context: TorchServe context with system_properties.
        """
        ...

    def handle(self, data, context):
        """
        Full speech-to-text pipeline in a single method.

        1. Extract WAV bytes from request data
        2. Load audio via torchaudio.load() from BytesIO
        3. Resample to 16kHz if sample rate differs
        4. Process through AutoProcessor for feature extraction
        5. Run model forward pass to get CTC logits
        6. Argmax decode predicted IDs via processor.decode()

        Args:
            data (list): Single-element list with dict containing
                         "data" or "body" key with WAV bytes.
            context: TorchServe context (unused in handle).

        Returns:
            list: Single-element list with transcribed text string.
        """
        ...

Import

# Handler imports
import torch
import torchaudio
from transformers import AutoProcessor, AutoModelForCTC
import io

I/O Contract

Method | Input | Output | Notes
__init__() | None | None | Sets expected_sampling_rate = 16_000
initialize(context) | Context with system_properties["model_dir"], optional "gpu_id" | None (sets self.model, self.processor, self.initialized = True) | CUDA or CPU device selection
handle(data, context) | list[dict] with "data"/"body" containing raw WAV bytes | list[str]: single-element list with transcribed text | Resamples to 16kHz if needed; uses greedy argmax decoding
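
To make the handle() input contract concrete, the sketch below builds a small WAV file in memory with the standard-library wave module and wraps it in the single-element list shape described above. The helper names make_silent_wav and extract_audio_bytes are hypothetical; only the "data"-then-"body" lookup mirrors the handler. It uses no torch or torchaudio, so it shows the payload shape and the sample-rate check rather than actual inference.

```python
import io
import wave
from typing import Dict, List, Optional


def make_silent_wav(sample_rate: int = 8_000, seconds: float = 0.5) -> bytes:
    """Build a mono 16-bit PCM WAV file in memory (all zeros, i.e. silence)."""
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)  # 2 bytes per sample = 16-bit PCM
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(b"\x00\x00" * int(sample_rate * seconds))
    return buffer.getvalue()


def extract_audio_bytes(data: List[Dict[str, bytes]]) -> Optional[bytes]:
    """Mirror the handler's lookup: prefer "data", fall back to "body"."""
    payload = data[0].get("data")
    if payload is None:
        payload = data[0].get("body")
    return payload


wav_bytes = make_silent_wav()
request_batch = [{"body": wav_bytes}]  # shape of the list passed to handle()
audio = extract_audio_bytes(request_batch)

# Read the WAV header to see whether the handler would need to resample.
with wave.open(io.BytesIO(audio), "rb") as wav_file:
    needs_resample = wav_file.getframerate() != 16_000
print(needs_resample)
```

Because the handler reads only data[0], any additional entries in a batched request would be ignored by the code shown on this page.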

Usage Examples

Example 1: Model Initialization

# From handler.py L17-28: initialize() loads model and processor
def initialize(self, context):
    self._context = context
    self.initialized = True
    properties = context.system_properties

    self.device = torch.device(
        "cuda:" + str(properties.get("gpu_id"))
        if torch.cuda.is_available()
        else "cpu"
    )

    model_dir = properties.get("model_dir")
    self.processor = AutoProcessor.from_pretrained(model_dir)
    self.model = AutoModelForCTC.from_pretrained(model_dir)
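
The device-selection expression in initialize() can be examined in isolation. The sketch below reproduces only the string-building logic as a pure function so that it runs without torch; select_device_name, fake_context, and the /tmp/model_dir path are hypothetical names for illustration, and the cuda_available flag stands in for torch.cuda.is_available().

```python
from types import SimpleNamespace


def select_device_name(system_properties: dict, cuda_available: bool) -> str:
    """Return the device string the handler would pass to torch.device()."""
    if cuda_available:
        return "cuda:" + str(system_properties.get("gpu_id"))
    return "cpu"


# Hypothetical stand-in for the context object TorchServe passes in.
fake_context = SimpleNamespace(
    system_properties={"model_dir": "/tmp/model_dir", "gpu_id": 0}
)

print(select_device_name(fake_context.system_properties, cuda_available=True))
print(select_device_name(fake_context.system_properties, cuda_available=False))

# If gpu_id is absent while CUDA is available, the expression yields
# "cuda:None", which is not a usable device name, so a guard may be
# worth adding in a custom handler.
print(select_device_name({"model_dir": "/tmp/model_dir"}, cuda_available=True))
```

Note that the excerpt above stores the result in self.device but does not show the model or inputs being moved to it with .to(); a handler intended to run on GPU would presumably need that step.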

Example 2: Full Audio-to-Text Pipeline

# From handler.py L30-48: handle() processes WAV bytes to text
def handle(self, data, context):
    input = data[0].get("data")
    if input is None:
        input = data[0].get("body")

    # Load WAV from bytes
    model_input, sample_rate = torchaudio.load(
        io.BytesIO(input), format="WAV"
    )

    # Resample to 16kHz if needed
    if sample_rate != self.expected_sampling_rate:
        model_input = torchaudio.functional.resample(
            model_input, sample_rate, self.expected_sampling_rate
        )

    # Feature extraction and inference
    model_input = self.processor(
        model_input,
        sampling_rate=self.expected_sampling_rate,
        return_tensors="pt"
    ).input_values[0]

    logits = self.model(model_input)[0]
    pred_ids = torch.argmax(logits, axis=-1)[0]
    output = self.processor.decode(pred_ids)

    return [output]
