Implementation:Pytorch Serve Send Intermediate Response

From Leeroopedia
Field | Value
Page Type | Implementation (API Doc)
Title | Send Intermediate Response
Implements | Principle:Pytorch_Serve_Streaming_Inference
Source | ts/handler_utils/utils.py
Repository | TorchServe
Last Updated | 2026-02-13 00:00 GMT

Overview

send_intermediate_predict_response() is the API function that enables streaming inference in TorchServe handlers. It sends an intermediate inference result to the client while the handler continues generating more output. The function is rank-aware: in distributed inference scenarios, it only sends data from the rank 0 process, returning None on all other ranks. It uses the context's client socket (context.cl_socket) to send the response directly, with the ts_stream_next=True flag signaling that more data will follow.

Description

The send_intermediate_predict_response() function provides the core streaming mechanism:

1. Rank Guard: The function first reads os.getenv("LOCAL_RANK", 0) and converts the result to a string. If that string is not "0", it returns None immediately; an unset LOCAL_RANK therefore counts as rank 0. This prevents duplicate streaming output in multi-GPU distributed inference.

2. Message Creation: It calls create_predict_response() from the OTF (Open Transport Format) message handler module with ts_stream_next=True. This flag is embedded in the binary protocol message to tell the TorchServe frontend that this is an intermediate response and more data will follow.

3. Socket Transmission: The serialized message is sent directly via context.cl_socket.sendall(msg), the socket connection between the backend worker and the TorchServe frontend.

The function is designed to be called multiple times during a single inference request. Each call sends one intermediate result. After all intermediate results have been sent, the handler returns the final result through the normal return path (without the ts_stream_next flag), signaling the end of the stream.
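
The call sequence described above can be pictured with a small stand-alone simulation. It does not import TorchServe: RecordingSocket, FakeContext, fake_create_predict_response, and send_intermediate are hypothetical stand-ins, and the JSON framing is only a placeholder for the real binary message. The sketch aims to show the ordering only: some frames flagged as "more to follow", then one final unflagged frame.

```python
import json


class RecordingSocket:
    """Hypothetical stand-in for context.cl_socket; stores what is sent."""

    def __init__(self):
        self.frames = []

    def sendall(self, msg):
        self.frames.append(msg)


class FakeContext:
    """Hypothetical stand-in for the TorchServe Context object."""

    def __init__(self):
        self.cl_socket = RecordingSocket()
        self.request_ids = {0: "req-0"}


def fake_create_predict_response(ret, req_id_map, message, code, stream_next):
    # Placeholder framing (JSON), NOT the real TorchServe binary protocol.
    body = {"ret": ret, "ids": req_id_map, "msg": message,
            "code": code, "ts_stream_next": stream_next}
    return json.dumps(body).encode("utf-8")


def send_intermediate(ret, req_id_map, message, code, context):
    # Mirrors the documented flow: build a flagged message, then sendall().
    msg = fake_create_predict_response(ret, req_id_map, message, code, True)
    context.cl_socket.sendall(msg)


def handle(data, context):
    for token in data[:-1]:
        send_intermediate([token], context.request_ids, "ok", 200, context)
    return [data[-1]]  # the final result goes through the normal return path


context = FakeContext()
final = handle(["a", "b", "c"], context)
# In this simulation the caller plays the worker loop and sends the final
# result without the stream flag.
final_msg = fake_create_predict_response(final, context.request_ids, "ok", 200, False)
context.cl_socket.sendall(final_msg)

flags = [json.loads(frame)["ts_stream_next"] for frame in context.cl_socket.frames]
print(flags)  # [True, True, False]
```

The last frame is the only one without the flag, which is how the receiving side can tell that the stream has ended.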

The create_predict_response() function (in ts/protocol/otf_message_handler.py) also has its own rank guard -- it returns None if LOCAL_RANK != 0 -- providing a defense-in-depth approach to preventing duplicate responses.
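
The rank guard can be examined in isolation. The sketch below reproduces the same comparison in a hypothetical helper, is_sending_rank, which takes the environment as a plain mapping so it can be tried without launching a distributed job. It is an illustration, not TorchServe code.

```python
import os


def is_sending_rank(environ=None):
    """Return True when this process would pass the LOCAL_RANK guard."""
    env = os.environ if environ is None else environ
    # The default value 0 is an int, so the result is converted with str()
    # before being compared to the string "0".
    return str(env.get("LOCAL_RANK", 0)) == "0"


print(is_sending_rank({}))                    # True: unset counts as rank 0
print(is_sending_rank({"LOCAL_RANK": "0"}))   # True: rank 0
print(is_sending_rank({"LOCAL_RANK": "1"}))   # False: any other rank
```

Because the comparison is on strings, only the exact value "0" passes; a single-process worker with no LOCAL_RANK set passes as well.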

Usage

Code Reference

Source Location: ts/handler_utils/utils.py (lines 36-42)

Signature:

def send_intermediate_predict_response(
    ret, req_id_map, message, code, context: Context
):
    if str(os.getenv("LOCAL_RANK", 0)) != "0":
        return None
    msg = create_predict_response(ret, req_id_map, message, code, context, True)
    context.cl_socket.sendall(msg)

Import:

from ts.handler_utils.utils import send_intermediate_predict_response

Note: The previous import path, from ts.protocol.otf_message_handler import send_intermediate_predict_response, is slated for deprecation in TorchServe v1.0.0. Replace it with the current path, from ts.handler_utils.utils import send_intermediate_predict_response.

Internal Dependency:

from ts.protocol.otf_message_handler import create_predict_response

I/O Contract

Function Parameters:

Parameter | Type | Description
ret | list | List of response values (one per request in the batch). Each element is the intermediate result to send.
req_id_map | dict | Mapping of request IDs. Typically context.request_ids.
message | str | Status message (e.g., "Intermediate Prediction success").
code | int | HTTP status code (typically 200 for success).
context | Context | TorchServe context object. Must have cl_socket attribute set.
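
Since ret carries one element per request, a handler may want to confirm that ret and req_id_map line up before streaming. The helper below, check_batch_alignment, is hypothetical and not part of TorchServe. It assumes req_id_map is keyed by batch index (0, 1, ...); that assumption should be checked against the TorchServe version in use.

```python
def check_batch_alignment(ret, req_id_map):
    """Raise ValueError if ret and req_id_map do not describe the same batch.

    Hypothetical helper; assumes req_id_map maps batch index -> request ID.
    """
    if not isinstance(ret, list):
        raise ValueError("ret must be a list with one element per request")
    if len(ret) != len(req_id_map):
        raise ValueError(
            f"ret has {len(ret)} element(s) but req_id_map has {len(req_id_map)}"
        )
    missing = [idx for idx in range(len(ret)) if idx not in req_id_map]
    if missing:
        raise ValueError(f"req_id_map has no entry for batch index {missing}")


check_batch_alignment(["partial"], {0: "req-a"})  # passes silently

try:
    check_batch_alignment(["x", "y"], {0: "req-a"})
except ValueError as err:
    print(err)
```

Calling such a check once per request, before the first intermediate send, keeps the cost negligible compared with generation itself.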

Return Value:

  • None -- The function returns None in all cases (on rank 0, after sending; on other ranks, immediately).

Side Effects:

  • On rank 0: Sends binary protocol message via context.cl_socket.sendall().
  • On non-rank-0 processes: No side effects.

Protocol Details:

  • The message is created with ts_stream_next=True, which is encoded in the binary protocol as a flag indicating more data will follow.
  • The frontend translates this into HTTP/1.1 chunked transfer encoding (for HTTP clients) or gRPC stream messages (for gRPC clients).
  • The final response (returned normally by the handler) does not have ts_stream_next set, signaling stream completion.
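
For readers unfamiliar with chunked transfer encoding, the sketch below shows the standard HTTP/1.1 framing: each chunk is a hexadecimal length, CRLF, the data, CRLF, and a zero-length chunk ends the body. It illustrates the wire format only and is not taken from the TorchServe frontend; encode_chunked and decode_chunked are hypothetical names, and chunk extensions and trailers are ignored.

```python
def encode_chunked(parts):
    """Frame byte strings using HTTP/1.1 chunked transfer coding."""
    out = bytearray()
    for part in parts:
        if not part:
            continue  # a zero-length chunk would terminate the body early
        out += f"{len(part):X}\r\n".encode("ascii") + part + b"\r\n"
    out += b"0\r\n\r\n"  # last-chunk marker followed by an empty trailer
    return bytes(out)


def decode_chunked(body):
    """Split a chunked body back into its chunks (no extensions or trailers)."""
    chunks, pos = [], 0
    while True:
        eol = body.index(b"\r\n", pos)
        size = int(body[pos:eol], 16)
        if size == 0:
            return chunks
        start = eol + 2
        chunks.append(body[start:start + size])
        pos = start + size + 2  # skip the CRLF that ends the chunk data


wire = encode_chunked([b"hello ", b"world"])
print(wire)                  # b'6\r\nhello \r\n5\r\nworld\r\n0\r\n\r\n'
print(decode_chunked(wire))  # [b'hello ', b'world']
```

HTTP client libraries normally undo this framing themselves, so application code sees only the chunk payloads.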

Usage Examples

Basic streaming handler:

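from ts.handler_utils.utils import send_intermediate_predict_response

def handle(data, context):
    # Stream only for batched (list) input; any other input falls through
    # and the function returns None.
    if isinstance(data, list):
        for _ in range(3):
            # Each call pushes one intermediate result to the client.
            send_intermediate_predict_response(
                ["intermediate_response"],
                context.request_ids,
                "Intermediate Prediction success",
                200,
                context,
            )
        # The normal return value is the final result and ends the stream.
        return ["hello world "]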

Streaming LLM token generation:

import torch
from ts.handler_utils.utils import send_intermediate_predict_response
from ts.torch_handler.base_handler import BaseHandler

class StreamingLLMHandler(BaseHandler):
    # Assumes initialize() has set self.model, self.tokenizer and
    # self.max_new_tokens, and that the batch holds a single request.
    def inference(self, input_batch):
        input_ids = input_batch

        # Greedy decoding: stream every token except the last one.
        for _ in range(self.max_new_tokens - 1):
            outputs = self.model(input_ids)
            next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token], dim=-1)

            token_text = self.tokenizer.decode(next_token[0])

            # Send intermediate token
            send_intermediate_predict_response(
                [token_text],
                self.context.request_ids,
                "Intermediate Prediction success",
                200,
                self.context,
            )

        # Generate the final token and return it through the normal path,
        # which ends the stream.
        outputs = self.model(input_ids)
        final_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        final_text = self.tokenizer.decode(final_token[0])
        return [final_text]
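
The handler above streams every token but the last and returns the last one. That pattern can be factored into a small helper, sketched here under the assumption that results come from any iterable, including a generator. stream_all_but_last is a hypothetical name, and send is any one-argument callable.

```python
def stream_all_but_last(items, send):
    """Pass every item except the last to send(); return the last item.

    Hypothetical helper. In a handler, `send` could wrap
    send_intermediate_predict_response.
    """
    iterator = iter(items)
    try:
        pending = next(iterator)
    except StopIteration:
        return None  # nothing was produced
    for item in iterator:
        send(pending)   # a later item exists, so `pending` is intermediate
        pending = item
    return pending


sent = []
last = stream_all_but_last(iter(["The", " quick", " fox"]), sent.append)
print(sent, last)  # ['The', ' quick'] ' fox'
```

Holding one item back means the helper never needs to know the total length in advance, which suits generation loops that stop on an end-of-sequence token.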

Client-side streaming consumption (HTTP):

import requests

response = requests.post(
    "http://localhost:8080/predictions/my_model",
    data="input text",
    stream=True,
)
assert response.headers['Transfer-Encoding'] == 'chunked'

for chunk in response.iter_content(chunk_size=None):
    if chunk:
        print(chunk.decode("utf-8"), end="", flush=True)
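
Calling chunk.decode("utf-8") on each chunk can fail if a multi-byte character happens to be split across two chunks. Whether that occurs depends on the handler output and the transport, so the following is a defensive sketch rather than a required step. decode_stream is a hypothetical helper built on the standard library's incremental decoder.

```python
import codecs


def decode_stream(chunks, encoding="utf-8"):
    """Yield text from byte chunks, tolerating characters split across chunks."""
    decoder = codecs.getincrementaldecoder(encoding)()
    for chunk in chunks:
        text = decoder.decode(chunk)
        if text:
            yield text
    tail = decoder.decode(b"", final=True)  # raises if bytes are left dangling
    if tail:
        yield tail


# "é" is two bytes in UTF-8 (0xC3 0xA9); here it is split across two chunks.
pieces = [b"caf\xc3", b"\xa9 au lait"]
print(list(decode_stream(pieces)))  # ['caf', 'é au lait']
```

With the client above, the loop could become: for text in decode_stream(response.iter_content(chunk_size=None)): print(text, end="", flush=True).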

Client-side streaming consumption (gRPC):

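import grpc
# Assumes the Python stubs generated from TorchServe's inference.proto
# are importable from the ts_scripts package.
from ts_scripts import inference_pb2, inference_pb2_grpc

channel = grpc.insecure_channel("localhost:7070")
stub = inference_pb2_grpc.InferenceAPIsServiceStub(channel)

request = inference_pb2.PredictionsRequest(
    model_name="my_model",
    input={"data": b"input text"},
)

# StreamPredictions is a server-streaming call: each intermediate result
# and the final result arrive as separate messages.
for response in stub.StreamPredictions(request):
    print(response.prediction.decode("utf-8"), end="", flush=True)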
