Implementation:Lm_sys_FastChat_ModelWorker_Load_And_Generate
| Field | Value |
|---|---|
| Page Type | Implementation (API Doc) |
| Repository | lm-sys/FastChat |
| Domain | Machine Learning Inference, Distributed Systems, Streaming Generation |
| Knowledge Sources | Source code analysis of fastchat/serve/model_worker.py, fastchat/serve/base_model_worker.py, fastchat/serve/inference.py |
| Last Updated | 2026-02-07 14:00 GMT |
| Implements | Principle:Lm_sys_FastChat_Model_Worker_Inference |
Overview
This page documents the ModelWorker class, its BaseModelWorker parent, and the generate_stream function that together implement model inference in the FastChat distributed serving system. The ModelWorker loads a model, registers with the controller, maintains heartbeats, and serves inference requests via a FastAPI HTTP interface. The generate_stream function provides the core autoregressive token generation loop with logits processing and streaming output.
Description
The ModelWorker extends BaseModelWorker with concrete model loading and generation logic for HuggingFace Transformers models. On initialization, it loads the model and tokenizer via the adapter pattern, obtains the appropriate generation function for the model backend, and starts heartbeat communication with the controller. Inference requests are handled through the generate_stream_gate method, which wraps the underlying generation function with error handling and output formatting.
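The wrap-and-format pattern of the gate can be sketched as follows. This is an illustrative approximation, not the actual FastChat code: `generate_stream_gate_sketch` and `fake_generate` are hypothetical names, and the real gate also maps CUDA out-of-memory and other exceptions to specific error codes. What the sketch preserves is the wire contract: each partial result is serialized as JSON and terminated with a NUL byte.

```python
import json
from typing import Dict, Generator, Iterable


def generate_stream_gate_sketch(
    generate_fn, params: Dict
) -> Generator[bytes, None, None]:
    """Wrap a token generator with error handling and encode each
    partial result as NUL-delimited JSON bytes (illustrative sketch)."""
    try:
        for output in generate_fn(params):
            ret = {"text": output["text"], "error_code": 0}
            yield json.dumps(ret).encode() + b"\0"
    except Exception as e:  # the real worker distinguishes OOM etc.
        ret = {"text": f"Error: {e}", "error_code": 1}
        yield json.dumps(ret).encode() + b"\0"


def fake_generate(params: Dict) -> Iterable[Dict]:
    # Stand-in for generate_stream: emit growing partial text.
    words = params["prompt"].split()
    for i in range(1, len(words) + 1):
        yield {"text": " ".join(words[:i])}


chunks = list(generate_stream_gate_sketch(fake_generate, {"prompt": "a b c"}))
# each chunk is a JSON object terminated by b"\0"
```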
The generate_stream function in inference.py implements the token-by-token autoregressive generation loop. It uses KV-cache for efficient decoding, applies a configurable logits processor chain (temperature, repetition penalty, top-p, top-k), and yields partial results at configurable intervals for streaming.
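To make the processor chain concrete, here is a dependency-free sketch of two of its stages, temperature scaling followed by top-k filtering, applied to a small logits vector. The real implementation composes HuggingFace `LogitsProcessorList` entries and also applies repetition penalty and top-p; the helper names below are illustrative only.

```python
import math
from typing import List


def apply_temperature(logits: List[float], temperature: float) -> List[float]:
    # Divide logits by temperature; lower values sharpen the distribution.
    return [x / temperature for x in logits]


def apply_top_k(logits: List[float], top_k: int) -> List[float]:
    # Keep the k largest logits and mask the rest to -inf.
    if top_k <= 0:
        return logits
    threshold = sorted(logits, reverse=True)[top_k - 1]
    return [x if x >= threshold else float("-inf") for x in logits]


def softmax(logits: List[float]) -> List[float]:
    m = max(x for x in logits if x != float("-inf"))
    exps = [math.exp(x - m) if x != float("-inf") else 0.0 for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.5, -1.0]
probs = softmax(apply_top_k(apply_temperature(logits, 0.7), 2))
# only the two largest logits retain nonzero probability
```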
Usage
Start a model worker from the command line:
python3 -m fastchat.serve.model_worker \
    --model-path lmsys/vicuna-7b-v1.5 \
    --controller-address http://localhost:21001 \
    --port 21002 \
    --device cuda \
    --num-gpus 1
Use programmatically:
from fastchat.serve.model_worker import ModelWorker
worker = ModelWorker(
    controller_addr="http://localhost:21001",
    worker_addr="http://localhost:21002",
    worker_id="worker-abc123",
    model_path="lmsys/vicuna-7b-v1.5",
    model_names=["vicuna-7b-v1.5"],
    limit_worker_concurrency=5,
    no_register=False,
    device="cuda",
    num_gpus=1,
    max_gpu_memory=None,
)
Code Reference
Source Location
| Component | File | Lines |
|---|---|---|
| BaseModelWorker class | fastchat/serve/base_model_worker.py | L27-177 |
| ModelWorker class | fastchat/serve/model_worker.py | L38-300 |
| create_model_worker factory | fastchat/serve/model_worker.py | L303-410 |
| generate_stream function | fastchat/serve/inference.py | L61-316 |
| prepare_logits_processor | fastchat/serve/inference.py | L45-58 |
| FastAPI worker endpoints | fastchat/serve/base_model_worker.py | L196-241 |
Signature
class BaseModelWorker:
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        conv_template: str = None,
        multimodal: bool = False,
    ) -> None: ...

    def init_heart_beat(self) -> None: ...
    def register_to_controller(self) -> None: ...
    def send_heart_beat(self) -> None: ...
    def get_queue_length(self) -> int: ...
    def get_status(self) -> dict: ...
    def count_token(self, params: dict) -> dict: ...
    def get_conv_template(self) -> dict: ...

class ModelWorker(BaseModelWorker):
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        no_register: bool,
        device: str,
        num_gpus: int,
        max_gpu_memory: str,
        revision: str = None,
        dtype: Optional[torch.dtype] = None,
        load_8bit: bool = False,
        cpu_offloading: bool = False,
        gptq_config: Optional[GptqConfig] = None,
        awq_config: Optional[AWQConfig] = None,
        exllama_config: Optional[ExllamaConfig] = None,
        xft_config: Optional[XftConfig] = None,
        stream_interval: int = 2,
        conv_template: Optional[str] = None,
        embed_in_truncate: bool = False,
        seed: Optional[int] = None,
        debug: bool = False,
        **kwargs,
    ) -> None: ...

    def generate_stream_gate(self, params: dict) -> Generator[bytes, None, None]: ...
    def generate_gate(self, params: dict) -> dict: ...
    def get_embeddings(self, params: dict) -> dict: ...

@torch.inference_mode()
def generate_stream(
    model,
    tokenizer,
    params: Dict,
    device: str,
    context_len: int,
    stream_interval: int = 2,
    judge_sent_end: bool = False,
) -> Generator[Dict, None, None]: ...

def prepare_logits_processor(
    temperature: float,
    repetition_penalty: float,
    top_p: float,
    top_k: int,
) -> LogitsProcessorList: ...
Import
from fastchat.serve.model_worker import ModelWorker
from fastchat.serve.base_model_worker import BaseModelWorker
from fastchat.serve.inference import generate_stream, prepare_logits_processor
I/O Contract
CLI Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| --model-path | str | (required) | Path or HuggingFace model ID for the model to load |
| --host | str | "localhost" | Host address to bind the worker server |
| --port | int | 21002 | Port number for the worker server |
| --worker-address | str | "http://localhost:21002" | The externally reachable address of this worker |
| --controller-address | str | "http://localhost:21001" | Address of the controller to register with |
| --device | str | "cuda" | Device to load the model onto (cuda, cpu, mps, xpu, npu) |
| --num-gpus | int | 1 | Number of GPUs to distribute the model across |
| --max-gpu-memory | str | None | Maximum GPU memory per device (e.g., "13GiB") |
| --load-8bit | flag | False | Load the model in 8-bit quantization |
| --limit-worker-concurrency | int | 5 | Maximum concurrent requests, to prevent OOM |
| --stream-interval | int | 2 | Yield streaming output every N tokens |
| --model-names | str | None | Comma-separated display names for the model |
| --conv-template | str | None | Conversation prompt template name |
| --no-register | flag | False | Skip controller registration (for standalone use) |
| --seed | int | None | Fixed random seed for reproducible generation |
| --ssl | flag | False | Enable SSL (requires SSL_KEYFILE and SSL_CERTFILE env vars) |
FastAPI Worker Endpoints
| Method | Endpoint | Description |
|---|---|---|
| POST | /worker_generate_stream | Stream token generation; returns a StreamingResponse with \0-delimited JSON chunks |
| POST | /worker_generate | Non-streaming generation; returns a complete JSON response |
| POST | /worker_get_embeddings | Compute text embeddings |
| POST | /worker_get_status | Return worker status (model_names, speed, queue_length) |
| POST | /count_token | Count tokens in a prompt |
| POST | /worker_get_conv_template | Return the conversation template for the model |
| POST | /model_details | Return model details, including context_length |
generate_stream Input Parameters (via params dict)
| Key | Type | Default | Description |
|---|---|---|---|
| prompt | str | (required) | The input prompt text |
| temperature | float | 1.0 | Sampling temperature (0 = greedy) |
| repetition_penalty | float | 1.0 | Penalty for repeated tokens (>1.0 penalizes) |
| top_p | float | 1.0 | Nucleus sampling threshold |
| top_k | int | -1 | Top-k sampling (-1 to disable) |
| max_new_tokens | int | 256 | Maximum number of tokens to generate |
| logprobs | int or None | None | Number of top log probabilities to return (None disables) |
| echo | bool | True | Whether to include the prompt in the output |
| stop | str or List[str] or None | None | Stop string(s) that terminate generation |
| stop_token_ids | List[int] or None | None | Stop token IDs (the EOS token is always included) |
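The defaults above can be resolved into a complete request as in the following sketch; `PARAM_DEFAULTS` and `resolve_params` are illustrative helpers, not part of the FastChat API.

```python
from typing import Any, Dict

# Defaults as documented in the table above (illustrative helper).
PARAM_DEFAULTS: Dict[str, Any] = {
    "temperature": 1.0,
    "repetition_penalty": 1.0,
    "top_p": 1.0,
    "top_k": -1,
    "max_new_tokens": 256,
    "logprobs": None,
    "echo": True,
    "stop": None,
    "stop_token_ids": None,
}


def resolve_params(params: Dict[str, Any]) -> Dict[str, Any]:
    # 'prompt' has no default; everything else falls back to the table.
    if "prompt" not in params:
        raise ValueError("'prompt' is required")
    return {**PARAM_DEFAULTS, **params}


p = resolve_params({"prompt": "Hello", "temperature": 0.0})
# temperature 0.0 selects greedy decoding; other keys keep their defaults
```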
generate_stream Output (yielded dicts)
{
    "text": str,           # Decoded output text so far
    "logprobs": {          # Optional; present when logprobs requested
        "text_offset": List[int],
        "tokens": List[str],
        "token_logprobs": List[Optional[float]],
        "top_logprobs": List[Dict[str, float]],
    },
    "usage": {
        "prompt_tokens": int,
        "completion_tokens": int,
        "total_tokens": int,
    },
    "finish_reason": None | "stop" | "length",
}
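A consumer can reconstruct these dicts by splitting the response body on the NUL delimiter; the chunk values below are synthetic and only illustrate the documented shape.

```python
import json

# Simulated worker stream: NUL-delimited JSON chunks, as produced by
# /worker_generate_stream (the values here are illustrative).
raw = (
    b'{"text": "Paris", "usage": {"prompt_tokens": 7, '
    b'"completion_tokens": 1, "total_tokens": 8}, "finish_reason": null}\0'
    b'{"text": "Paris.", "usage": {"prompt_tokens": 7, '
    b'"completion_tokens": 2, "total_tokens": 9}, "finish_reason": "stop"}\0'
)


def parse_stream(data: bytes):
    # Split on the NUL delimiter, dropping the trailing empty piece.
    return [json.loads(chunk) for chunk in data.split(b"\0") if chunk]


chunks = parse_stream(raw)
final = chunks[-1]
# "text" is cumulative, so the last chunk carries the full output
```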
Usage Examples
Starting a Model Worker
# Basic single-GPU worker
python3 -m fastchat.serve.model_worker \
--model-path lmsys/vicuna-7b-v1.5
# Multi-GPU worker with 8-bit quantization
python3 -m fastchat.serve.model_worker \
--model-path lmsys/vicuna-13b-v1.5 \
--num-gpus 2 \
--load-8bit \
--port 21003
# Worker with custom model names and concurrency limit
python3 -m fastchat.serve.model_worker \
--model-path lmsys/vicuna-7b-v1.5 \
--model-names "vicuna-7b,vicuna" \
--limit-worker-concurrency 3
Sending Requests to a Worker Directly
import requests
import json
WORKER_URL = "http://localhost:21002"
# Streaming generation
params = {
    "model": "vicuna-7b-v1.5",
    "prompt": "What is the capital of France?",
    "temperature": 0.7,
    "max_new_tokens": 128,
    "stop": None,
}
response = requests.post(
    f"{WORKER_URL}/worker_generate_stream",
    json=params,
    stream=True,
)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode())
        # "text" is cumulative, so overwrite the current line on each update
        print(data["text"], end="\r", flush=True)

# Get worker status
status = requests.post(f"{WORKER_URL}/worker_get_status").json()
print(f"Models: {status['model_names']}, Queue: {status['queue_length']}")
Related Pages
- Principle:Lm_sys_FastChat_Model_Worker_Inference -- The principle this implementation realizes
- Implementation:Lm_sys_FastChat_Controller_Dispatch -- The controller that dispatches requests to this worker
- Implementation:Lm_sys_FastChat_OpenAI_API_Server -- The API server that forwards requests to workers
- Environment:Lm_sys_FastChat_GPU_CUDA_Inference
- Heuristic:Lm_sys_FastChat_GPU_Memory_Allocation_Strategy
- Heuristic:Lm_sys_FastChat_Greedy_Decoding_Temperature_Threshold