Implementation:Pytorch Serve Inferentia2 OPT Handler

From Leeroopedia

Overview

LLMHandler is a TorchServe handler for serving OPT large language models on AWS Inferentia2 hardware using Neuron cores. It subclasses BaseHandler and ABC and provides tensor-parallel model compilation via transformers_neuronx, a check that enough Neuron cores are available, and a full inference pipeline from tokenization through text generation to decoded output.

Field | Value
Implementation Name | Inferentia2_OPT_Handler
Type | Example Handler
Workflow | Neuron_Accelerated_LLM_Serving
Domains | LLM_Serving, Hardware_Acceleration
Knowledge Sources | Pytorch_Serve
Last Updated | 2026-02-13 18:52 GMT

Description

The LLMHandler class implements the full inference lifecycle for OPT models compiled for AWS Inferentia2 chips. During initialization, it sets the NEURON_RT_NUM_CORES environment variable to the requested number of Neuron cores (tp_degree), verifies that at least that many cores are available, loads an AutoTokenizer for model_name from Hugging Face, and compiles the model for Neuron by loading the weights from model_dir with OPTForSampling.from_pretrained() and then calling model.to_neuron().
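The core-allocation step can be sketched in plain Python so it runs without Neuron hardware. `allocate_neuron_cores` is a hypothetical helper, not part of the handler, and its `cores_available` argument stands in for the value the handler obtains from `torch_neuronx.xla_impl.data_parallel.device_count()`.

```python
import os


def allocate_neuron_cores(tp_degree, cores_available):
    """Hypothetical helper mirroring the check in initialize().

    The environment variable is set first, then the requested
    tensor-parallel degree is compared with the cores that exist.
    `cores_available` is an assumption standing in for the count the
    real handler reads from torch_neuronx.
    """
    tp_degree = int(tp_degree)
    # Ask the Neuron runtime to reserve exactly tp_degree cores.
    os.environ["NEURON_RT_NUM_CORES"] = str(tp_degree)
    # Explicit check instead of `assert`, so it still runs under python -O.
    if cores_available < tp_degree:
        raise RuntimeError(
            f"Required number of neuron cores for tp_degree {tp_degree} "
            f"are not available: found {cores_available}"
        )
    return tp_degree
```

For example, `allocate_neuron_cores(2, 2)` returns 2, while `allocate_neuron_cores(4, 2)` raises `RuntimeError`, which in TorchServe would fail worker startup.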

Key Responsibilities

  • Neuron Core Allocation: Sets NEURON_RT_NUM_CORES and validates available cores against the requested tp_degree
  • Model Compilation: Loads OPT weights via OPTForSampling.from_pretrained() with tensor parallelism degree and AMP settings, then compiles to Neuron hardware
  • Tokenization: Uses AutoTokenizer with the pad token set to eos_token; inputs are padded and truncated to the configurable max_length
  • Dynamic Batch Padding: Pads partial batches up to the configured batchSize using nn.ConstantPad1d before inference, then discards the outputs generated for the padding rows
  • Text Generation: Calls model.sample() for autoregressive generation up to max_length tokens
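To illustrate the tokenization bullet, here is a toy sketch of what padding to max_length with the EOS token as pad token produces. The whitespace splitting, the `vocab` dict, the EOS id of 2 and right-side padding are all assumptions for illustration; the real handler calls `tokenizer.encode_plus` on OPT's BPE tokenizer, which may also add special tokens.

```python
def encode_padded(text, vocab, eos_id, max_length):
    """Toy stand-in for encode_plus(..., padding="max_length", truncation=True)
    with pad_token = eos_token. Tokenization and vocab are hypothetical."""
    ids = [vocab[word] for word in text.split()][:max_length]  # truncate
    mask = [1] * len(ids)             # 1 marks a real token
    n_pad = max_length - len(ids)
    ids = ids + [eos_id] * n_pad      # pad with the EOS id (right side assumed)
    mask = mask + [0] * n_pad         # 0 marks padding
    return ids, mask


vocab = {"hello": 10, "world": 11}
ids, mask = encode_padded("hello world", vocab, eos_id=2, max_length=4)
print(ids, mask)
```

Because the pad id equals the EOS id, the attention mask is the only thing that distinguishes padding from a genuine end-of-sequence token.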

Usage

from inf2_handler import LLMHandler

The handler is configured through a model YAML config:

# model-config.yaml for Inferentia2 OPT
handler:
    model_name: "facebook/opt-13b"
    tp_degree: 2
    manual_seed: 9
    amp: "f16"
    max_length: 50
batchSize: 2
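initialize() reads these keys from ctx.model_yaml_config, the parsed form of this YAML. The sketch below shows that mapping on a plain dict; `HandlerSettings` and `read_settings` are hypothetical names, and the dict literal simply restates the YAML above. Note that batchSize sits at the top level while the other keys live under handler, so nesting batchSize under handler would make this lookup (and the handler's own `ctx.model_yaml_config["batchSize"]`) raise KeyError.

```python
from dataclasses import dataclass


@dataclass
class HandlerSettings:
    """Hypothetical container for the values initialize() reads."""
    model_name: str
    tp_degree: int
    manual_seed: int
    amp: str
    max_length: int
    batch_size: int


def read_settings(model_yaml_config):
    # Handler-specific keys live under "handler"; batchSize is top-level.
    handler = model_yaml_config["handler"]
    return HandlerSettings(
        model_name=handler["model_name"],
        tp_degree=int(handler["tp_degree"]),
        manual_seed=int(handler["manual_seed"]),
        amp=handler["amp"],
        max_length=int(handler["max_length"]),
        batch_size=int(model_yaml_config["batchSize"]),
    )


# The YAML above as a dict, i.e. what a YAML parser would return for it.
config = {
    "handler": {
        "model_name": "facebook/opt-13b",
        "tp_degree": 2,
        "manual_seed": 9,
        "amp": "f16",
        "max_length": 50,
    },
    "batchSize": 2,
}
settings = read_settings(config)
```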

Code Reference

Source Location

File | Lines | Description
examples/large_models/inferentia2/opt/inf2_handler.py | L1-162 | Full handler module
examples/large_models/inferentia2/opt/inf2_handler.py | L17-162 | LLMHandler class definition
examples/large_models/inferentia2/opt/inf2_handler.py | L26-74 | initialize(ctx): Neuron core setup, tokenizer load, model compilation
examples/large_models/inferentia2/opt/inf2_handler.py | L76-94 | preprocess(requests): batch tokenization with padding
examples/large_models/inferentia2/opt/inf2_handler.py | L96-117 | encode_input_text(input_text): single-text encoding via AutoTokenizer
examples/large_models/inferentia2/opt/inf2_handler.py | L119-152 | inference(input_batch): padded batch generation and decoding
examples/large_models/inferentia2/opt/inf2_handler.py | L154-162 | postprocess(inference_output): identity passthrough

Signature

class LLMHandler(BaseHandler, ABC):

    def __init__(self):
        super(LLMHandler, self).__init__()
        self.initialized = False

    def initialize(self, ctx):
        """
        Load and compile OPT model for Inferentia2.

        Reads tp_degree, amp, model_name, manual_seed and max_length
        from ctx.model_yaml_config["handler"], and batchSize from the
        top level of ctx.model_yaml_config. Validates Neuron core
        availability, loads the AutoTokenizer, and compiles the model
        via OPTForSampling.from_pretrained() + to_neuron().

        Args:
            ctx: TorchServe context with system_properties, manifest,
                 and model_yaml_config.

        Raises:
            RuntimeError: If required Neuron cores are not available.
        """
        ...

    def preprocess(self, requests):
        """
        Tokenize a batch of input texts.

        Args:
            requests (list): List of dicts with "data" or "body" keys.

        Returns:
            tuple: (input_ids_batch, attention_mask_batch) as tensors.
        """
        ...

    def encode_input_text(self, input_text):
        """
        Encode a single text string using AutoTokenizer.

        Args:
            input_text (str|bytes): Raw input text.

        Returns:
            tuple: (input_ids, attention_mask) tensors.
        """
        ...

    def inference(self, input_batch):
        """
        Generate text using the compiled Neuron model.

        Pads partial batches with zero rows up to batchSize,
        calls model.sample(), decodes the generated tokens, and
        drops the outputs that belong to the padding rows.

        Args:
            input_batch (tuple): (input_ids_batch, attention_mask_batch).

        Returns:
            list: List of generated text strings.
        """
        ...

    def postprocess(self, inference_output):
        """
        Return inference output unchanged.

        Args:
            inference_output (list): Generated text strings.

        Returns:
            list: Same as input.
        """
        ...

Import

# Handler imports
import os
from abc import ABC

import torch
import torch_neuronx
from transformers import AutoTokenizer
from transformers_neuronx.opt.model import OPTForSampling
from ts.torch_handler.base_handler import BaseHandler

I/O Contract

Method | Input | Output | Notes
initialize(ctx) | Context whose model_yaml_config contains handler.tp_degree, handler.amp, handler.model_name, handler.manual_seed, handler.max_length and batchSize | None (sets self.model, self.tokenizer, self.batch_size, self.max_length and self.initialized = True) | Called once at worker startup; validates the Neuron core count
preprocess(requests) | list[dict] with "data" or "body" values (str or bytes) | tuple(Tensor, Tensor): (input_ids_batch, attention_mask_batch) | Calls encode_input_text for each request, then concatenates the results with torch.cat
encode_input_text(input_text) | str or bytes | tuple(Tensor, Tensor): (input_ids, attention_mask) | Uses tokenizer.encode_plus with max_length, padding and truncation
inference(input_batch) | tuple(Tensor, Tensor) from preprocess | list[str]: generated text strings | Pads the batch to self.batch_size with nn.ConstantPad1d when it is partial
postprocess(inference_output) | list[str] from inference | list[str] | Identity passthrough
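The request side of the preprocess contract can be sketched as follows: take the payload from "data" or, failing that, "body", and decode bytes before tokenization. `extract_text` is a hypothetical helper, and UTF-8 decoding is an assumption of this sketch; the `is None` check (rather than `or`) is also a choice made here, so an empty "data" string is passed through instead of falling back to "body".

```python
def extract_text(request):
    """Hypothetical helper: pull the prompt out of one TorchServe request.

    The payload is looked up under "data", then "body"; bytes payloads
    are decoded as UTF-8 (an assumption of this sketch).
    """
    payload = request.get("data")
    if payload is None:
        payload = request.get("body")
    if isinstance(payload, (bytes, bytearray)):
        payload = payload.decode("utf-8")
    return payload


requests = [{"data": "Today is a good day"}, {"body": b"Hello"}]
texts = [extract_text(r) for r in requests]
print(texts)
```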

Usage Examples

Example 1: Initialization with Tensor Parallelism

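# From inf2_handler.py L26-74: initialize() sets up Neuron cores and compiles model
import os

import torch
import torch_neuronx
from transformers import AutoTokenizer
from transformers_neuronx.opt.model import OPTForSampling


def initialize(self, ctx):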
    self.manifest = ctx.manifest
    properties = ctx.system_properties
    model_dir = properties.get("model_dir")

    seed = ctx.model_yaml_config["handler"]["manual_seed"]
    tp_degree = ctx.model_yaml_config["handler"]["tp_degree"]
    amp = ctx.model_yaml_config["handler"]["amp"]
    model_name = ctx.model_yaml_config["handler"]["model_name"]

    # Allocate Neuron cores
    os.environ["NEURON_RT_NUM_CORES"] = str(tp_degree)
    try:
        num_neuron_cores_available = (
            torch_neuronx.xla_impl.data_parallel.device_count()
        )
        assert num_neuron_cores_available >= int(tp_degree)
    except (RuntimeError, AssertionError) as error:
        # Chain the original error so the root cause stays in the traceback
        raise RuntimeError(
            "Required number of neuron cores for tp_degree "
            + str(tp_degree)
            + " are not available: "
            + str(error)
        ) from error

    torch.manual_seed(seed)
    self.tokenizer = AutoTokenizer.from_pretrained(model_name, return_tensors="pt")
    self.tokenizer.pad_token = self.tokenizer.eos_token

    self.batch_size = ctx.model_yaml_config["batchSize"]
    self.model = OPTForSampling.from_pretrained(
        model_dir, batch_size=self.batch_size, tp_degree=tp_degree, amp=amp
    )
    self.model.to_neuron()
    self.max_length = ctx.model_yaml_config["handler"]["max_length"]
    self.initialized = True

Example 2: Dynamic Batch Padding During Inference

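# From inf2_handler.py L119-152: inference() pads partial batches
import torch


def inference(self, input_batch):
    input_ids_batch = input_batch[0]  # shape [num_inferences, seq_len]

    # The model was compiled for a fixed batch_size, so a partial batch
    # is filled up with all-zero rows before calling sample().
    num_inferences = len(input_ids_batch)
    padding = self.batch_size - num_inferences
    if padding > 0:
        # (0, 0, 0, padding): nothing added along the last (sequence) dim,
        # `padding` rows appended after the existing batch rows
        pad = torch.nn.ConstantPad1d((0, 0, 0, padding), value=0)
        input_ids_batch = pad(input_ids_batch)

    outputs = self.model.sample(input_ids_batch, self.max_length)

    inferences = self.tokenizer.batch_decode(
        outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    # Keep only the outputs for real requests; drop the padding rows
    inferences = inferences[:num_inferences]
    return inferences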
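The padding-and-truncation round trip in Example 2 can be checked without Neuron hardware or PyTorch using plain lists. `pad_batch` mimics the effect of `ConstantPad1d((0, 0, 0, padding), value=0)` on a 2-D `[batch, seq_len]` tensor (zero rows appended along the batch dimension), and `fake_generate` is a hypothetical stand-in for `model.sample()` plus `batch_decode()`.

```python
def pad_batch(rows, batch_size, pad_value=0):
    """Append all-pad_value rows until the batch has batch_size rows.

    List-based analogue of the handler's ConstantPad1d((0, 0, 0, padding))
    on a [batch, seq_len] tensor: row length is untouched, rows are added
    after the existing ones.
    """
    num_inferences = len(rows)
    padding = batch_size - num_inferences
    padded = [list(row) for row in rows]
    if padding > 0:
        seq_len = len(rows[0])
        padded += [[pad_value] * seq_len for _ in range(padding)]
    return padded, num_inferences


def fake_generate(batch):
    # Hypothetical stand-in for model.sample() + batch_decode():
    # produces exactly one output string per input row.
    return ["tokens " + " ".join(str(t) for t in row) for row in batch]


batch, num_inferences = pad_batch([[5, 6, 7]], batch_size=2)
outputs = fake_generate(batch)[:num_inferences]
print(batch, outputs)
```

Slicing with `[:num_inferences]` is safe because padding rows are only ever appended at the end, so the first `num_inferences` outputs line up one-to-one with the real requests.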
