
Implementation:Lm sys FastChat ModelWorker Load And Generate

From Leeroopedia


Field Value
Page Type Implementation (API Doc)
Repository lm-sys/FastChat
Domain Machine Learning Inference, Distributed Systems, Streaming Generation
Knowledge Sources Source code analysis of fastchat/serve/model_worker.py, fastchat/serve/base_model_worker.py, fastchat/serve/inference.py
Last Updated 2026-02-07 14:00 GMT
Implements Principle:Lm_sys_FastChat_Model_Worker_Inference

Overview

This page documents the ModelWorker class, its BaseModelWorker parent, and the generate_stream function that together implement model inference in the FastChat distributed serving system. The ModelWorker loads a model, registers with the controller, maintains heartbeats, and serves inference requests via a FastAPI HTTP interface. The generate_stream function provides the core autoregressive token generation loop with logits processing and streaming output.

Description

The ModelWorker extends BaseModelWorker with concrete model loading and generation logic for HuggingFace Transformers models. On initialization, it loads the model and tokenizer via the adapter pattern, obtains the appropriate generation function for the model backend, and starts heartbeat communication with the controller. Inference requests are handled through the generate_stream_gate method, which wraps the underlying generation function with error handling and output formatting.

The generate_stream function in inference.py implements the token-by-token autoregressive generation loop. It uses KV-cache for efficient decoding, applies a configurable logits processor chain (temperature, repetition penalty, top-p, top-k), and yields partial results at configurable intervals for streaming.
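The interval-based yielding can be illustrated with a minimal, dependency-free sketch. This is a simplification for illustration only, not FastChat's actual loop: real decoding operates on token IDs with a KV-cache and a full logits-processor chain, and the function name here is hypothetical.

```python
from typing import Dict, Iterator, List

def simulate_stream(
    pieces: List[str],
    stream_interval: int = 2,
    max_new_tokens: int = 256,
) -> Iterator[Dict]:
    """Yield partial outputs every `stream_interval` tokens, mimicking
    the shape of generate_stream's yielded dicts (greatly simplified)."""
    text = ""
    for i, piece in enumerate(pieces):
        text += piece
        last = i == len(pieces) - 1
        if i % stream_interval == 0 or last:
            yield {
                "text": text,
                "usage": {"completion_tokens": i + 1},
                # finish_reason stays None until the final chunk
                "finish_reason": (
                    ("length" if i + 1 >= max_new_tokens else "stop")
                    if last
                    else None
                ),
            }

chunks = list(simulate_stream(["Par", "is", " is", " nice"], stream_interval=2))
```

Each intermediate chunk repeats the full decoded text so far, which is why clients can simply overwrite their display with the latest chunk.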

Usage

Start a model worker from the command line:

python3 -m fastchat.serve.model_worker \
    --model-path lmsys/vicuna-7b-v1.5 \
    --controller-address http://localhost:21001 \
    --port 21002 \
    --device cuda \
    --num-gpus 1

Use programmatically:

from fastchat.serve.model_worker import ModelWorker

worker = ModelWorker(
    controller_addr="http://localhost:21001",
    worker_addr="http://localhost:21002",
    worker_id="worker-abc123",
    model_path="lmsys/vicuna-7b-v1.5",
    model_names=["vicuna-7b-v1.5"],
    limit_worker_concurrency=5,
    no_register=False,
    device="cuda",
    num_gpus=1,
    max_gpu_memory=None,
)

Code Reference

Source Location

Component File Lines
BaseModelWorker class fastchat/serve/base_model_worker.py L27-177
ModelWorker class fastchat/serve/model_worker.py L38-300
create_model_worker factory fastchat/serve/model_worker.py L303-410
generate_stream function fastchat/serve/inference.py L61-316
prepare_logits_processor fastchat/serve/inference.py L45-58
FastAPI worker endpoints fastchat/serve/base_model_worker.py L196-241

Signature

class BaseModelWorker:
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        conv_template: str = None,
        multimodal: bool = False,
    ) -> None: ...

    def init_heart_beat(self) -> None: ...
    def register_to_controller(self) -> None: ...
    def send_heart_beat(self) -> None: ...
    def get_queue_length(self) -> int: ...
    def get_status(self) -> dict: ...
    def count_token(self, params: dict) -> dict: ...
    def get_conv_template(self) -> dict: ...

class ModelWorker(BaseModelWorker):
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        no_register: bool,
        device: str,
        num_gpus: int,
        max_gpu_memory: str,
        revision: str = None,
        dtype: Optional[torch.dtype] = None,
        load_8bit: bool = False,
        cpu_offloading: bool = False,
        gptq_config: Optional[GptqConfig] = None,
        awq_config: Optional[AWQConfig] = None,
        exllama_config: Optional[ExllamaConfig] = None,
        xft_config: Optional[XftConfig] = None,
        stream_interval: int = 2,
        conv_template: Optional[str] = None,
        embed_in_truncate: bool = False,
        seed: Optional[int] = None,
        debug: bool = False,
        **kwargs,
    ) -> None: ...

    def generate_stream_gate(self, params: dict) -> Generator[bytes, None, None]: ...
    def generate_gate(self, params: dict) -> dict: ...
    def get_embeddings(self, params: dict) -> dict: ...

@torch.inference_mode()
def generate_stream(
    model,
    tokenizer,
    params: Dict,
    device: str,
    context_len: int,
    stream_interval: int = 2,
    judge_sent_end: bool = False,
) -> Generator[Dict, None, None]: ...

def prepare_logits_processor(
    temperature: float,
    repetition_penalty: float,
    top_p: float,
    top_k: int,
) -> LogitsProcessorList: ...

Import

from fastchat.serve.model_worker import ModelWorker
from fastchat.serve.base_model_worker import BaseModelWorker
from fastchat.serve.inference import generate_stream, prepare_logits_processor
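The processor chain built by prepare_logits_processor is composed of HuggingFace transformers classes (TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor, TopPLogitsWarper, TopKLogitsWarper). The toy function below sketches two of those transforms on plain Python lists so the effect is visible without a model; it is a stand-in for illustration, not the library code, and the exact skip thresholds are assumptions.

```python
import math
from typing import List

def apply_temperature_top_k(
    logits: List[float], temperature: float = 1.0, top_k: int = -1
) -> List[float]:
    """Toy logits processing: scale by temperature, then mask everything
    outside the top_k highest logits with -inf (so softmax gives them
    zero probability). Simplified stand-in for the processor chain."""
    if temperature > 1e-5:  # temperature ~0 is treated as greedy upstream
        logits = [x / temperature for x in logits]
    if top_k > 0:
        threshold = sorted(logits, reverse=True)[top_k - 1]
        logits = [x if x >= threshold else float("-inf") for x in logits]
    return logits

out = apply_temperature_top_k([2.0, 1.0, 0.5], temperature=2.0, top_k=2)
```

In the real chain each processor receives (input_ids, scores) tensors and the processors are only appended when their parameter is not a no-op (e.g. temperature exactly 1.0 adds nothing).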

I/O Contract

CLI Parameters

Parameter Type Default Description
--model-path str (required) Path or HuggingFace model ID for the model to load
--host str "localhost" Host address to bind the worker server
--port int 21002 Port number for the worker server
--worker-address str "http://localhost:21002" The externally reachable address of this worker
--controller-address str "http://localhost:21001" Address of the controller to register with
--device str "cuda" Device to load the model onto (cuda, cpu, mps, xpu, npu)
--num-gpus int 1 Number of GPUs to distribute the model across
--max-gpu-memory str None Maximum GPU memory per device (e.g., "13GiB")
--load-8bit flag False Load model in 8-bit quantization
--limit-worker-concurrency int 5 Maximum concurrent requests to prevent OOM
--stream-interval int 2 Yield streaming output every N tokens
--model-names str None Comma-separated display names for the model
--conv-template str None Conversation prompt template name
--no-register flag False Skip controller registration (for standalone use)
--seed int None Fixed random seed for reproducible generation
--ssl flag False Enable SSL (requires SSL_KEYFILE and SSL_CERTFILE env vars)

FastAPI Worker Endpoints

Method Endpoint Description
POST /worker_generate_stream Stream token generation, returns StreamingResponse with \0-delimited JSON chunks
POST /worker_generate Non-streaming generation, returns complete JSON response
POST /worker_get_embeddings Compute text embeddings
POST /worker_get_status Return worker status (model_names, speed, queue_length)
POST /count_token Count tokens in a prompt
POST /worker_get_conv_template Return the conversation template for the model
POST /model_details Return model details including context_length
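Since /worker_generate_stream delimits its JSON chunks with NUL bytes rather than newlines, a client that buffers the whole response can recover the chunks with a simple split. A minimal sketch (the error_code field shown in the sample payload follows the worker's streamed output, 0 on success; treat the exact payload keys as illustrative):

```python
import json
from typing import Dict, List

def parse_worker_stream(raw: bytes) -> List[Dict]:
    """Split a \\0-delimited byte stream (the wire format of
    /worker_generate_stream) into parsed JSON dicts, skipping the
    empty segment after the trailing delimiter."""
    return [json.loads(chunk.decode()) for chunk in raw.split(b"\0") if chunk]

raw = b'{"text": "Par", "error_code": 0}\0{"text": "Paris", "error_code": 0}\0'
chunks = parse_worker_stream(raw)
```

For incremental consumption, requests' iter_lines with delimiter=b"\0" performs the same split on the fly, as shown in the client example below.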

generate_stream Input Parameters (via params dict)

Key Type Default Description
prompt str (required) The input prompt text
temperature float 1.0 Sampling temperature (0 = greedy)
repetition_penalty float 1.0 Penalty for repeated tokens (>1.0 to penalize)
top_p float 1.0 Nucleus sampling threshold
top_k int -1 Top-k sampling (-1 to disable)
max_new_tokens int 256 Maximum number of tokens to generate
logprobs int or None None Number of top log probabilities to return per token (None to disable)
echo bool True Whether to include the prompt in the output
stop str or List[str] or None None Stop string(s) to terminate generation
stop_token_ids List[int] or None None Stop token IDs (EOS is always included)
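The defaults in the table are applied by reading the params dict with per-key fallbacks. A minimal sketch of that pattern (field names and defaults match the table; the helper name is hypothetical):

```python
from typing import Any, Dict

def read_generation_params(params: Dict[str, Any]) -> Dict[str, Any]:
    """Resolve the documented defaults for any key absent from params."""
    return {
        "prompt": params["prompt"],  # required, no default
        "temperature": float(params.get("temperature", 1.0)),
        "repetition_penalty": float(params.get("repetition_penalty", 1.0)),
        "top_p": float(params.get("top_p", 1.0)),
        "top_k": int(params.get("top_k", -1)),
        "max_new_tokens": int(params.get("max_new_tokens", 256)),
        "echo": bool(params.get("echo", True)),
        "stop": params.get("stop", None),
    }

resolved = read_generation_params({"prompt": "Hi", "temperature": 0.7})
```

Explicitly coercing with float()/int() keeps the generation loop robust to clients that send numeric parameters as strings.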

generate_stream Output (yielded dicts)

{
    "text": str,          # Decoded output text so far
    "logprobs": {         # Optional, present when logprobs requested
        "text_offset": List[int],
        "tokens": List[str],
        "token_logprobs": List[Optional[float]],
        "top_logprobs": List[Dict[str, float]],
    },
    "usage": {
        "prompt_tokens": int,
        "completion_tokens": int,
        "total_tokens": int,
    },
    "finish_reason": None | "stop" | "length",
}

Usage Examples

Starting a Model Worker

# Basic single-GPU worker
python3 -m fastchat.serve.model_worker \
    --model-path lmsys/vicuna-7b-v1.5

# Multi-GPU worker with 8-bit quantization
python3 -m fastchat.serve.model_worker \
    --model-path lmsys/vicuna-13b-v1.5 \
    --num-gpus 2 \
    --load-8bit \
    --port 21003

# Worker with custom model names and concurrency limit
python3 -m fastchat.serve.model_worker \
    --model-path lmsys/vicuna-7b-v1.5 \
    --model-names "vicuna-7b,vicuna" \
    --limit-worker-concurrency 3

Sending Requests to a Worker Directly

import requests
import json

WORKER_URL = "http://localhost:21002"

# Streaming generation
params = {
    "model": "vicuna-7b-v1.5",
    "prompt": "What is the capital of France?",
    "temperature": 0.7,
    "max_new_tokens": 128,
    "stop": None,
}

response = requests.post(
    f"{WORKER_URL}/worker_generate_stream",
    json=params,
    stream=True,
)

for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode())
        print(data["text"], end="\r")

# Get worker status
status = requests.post(f"{WORKER_URL}/worker_get_status").json()
print(f"Models: {status['model_names']}, Queue: {status['queue_length']}")
