Implementation:Mit han lab Llm awq Model Worker
| Knowledge Sources | |
|---|---|
| Domains | Serving, Inference |
| Last Updated | 2026-02-15 00:00 GMT |
Overview
The ModelWorker class loads a LLaVA model with optional AWQ 4-bit quantization, serves streaming inference requests through a FastAPI server, and registers with a central controller for distributed orchestration.
Description
This module implements the primary model worker for the TinyChat serving infrastructure. The ModelWorker class handles the full model lifecycle: loading a LlavaLlamaForCausalLM model from a specified path, initializing the vision tower and image processor, applying either W16A16 full-precision loading (via load_checkpoint_and_dispatch) or W4A16 AWQ quantization (via load_awq_model with the make_quant_attn and make_quant_norm optimizations), and registering with the controller. The tokenizer is loaded from the model path using AutoTokenizer.

A background thread sends periodic heartbeats to the controller via send_heart_beat(), which re-registers the worker if the controller reports it as unknown. Concurrency is managed through an asyncio Semaphore initialized with the --limit-model-concurrency parameter (default 5).

The generate_stream(params) method processes incoming requests by decoding base64 images, applying image processing via process_images, replacing image token placeholders, and invoking LlavaStreamGenerator to produce a token-by-token streaming response. The generate_stream_gate() wrapper provides error handling for ValueError, CudaError, and general exceptions. The FastAPI app exposes two POST endpoints: /worker_generate_stream for streaming text generation and /worker_get_status for reporting model name, speed, and queue length.
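The semaphore-based concurrency limit described above can be illustrated with a minimal, standalone sketch. It is not the worker's actual code: handle_request and run_demo are hypothetical names, the sleep stands in for streaming generation, and the real worker may acquire and release its semaphore at different points in the request lifecycle.

```python
import asyncio

LIMIT_MODEL_CONCURRENCY = 5  # mirrors the --limit-model-concurrency default


async def handle_request(semaphore, request_id, active, peak):
    """Hypothetical handler: holds one semaphore slot while 'generating'."""
    async with semaphore:
        active[0] += 1
        peak[0] = max(peak[0], active[0])
        await asyncio.sleep(0.01)  # stand-in for streaming generation
        active[0] -= 1
    return request_id


async def run_demo(num_requests=12):
    """Fire more requests than there are slots and record peak concurrency."""
    semaphore = asyncio.Semaphore(LIMIT_MODEL_CONCURRENCY)
    active, peak = [0], [0]
    results = await asyncio.gather(
        *(handle_request(semaphore, i, active, peak) for i in range(num_requests))
    )
    return results, peak[0]


if __name__ == "__main__":
    results, peak = asyncio.run(run_demo())
    print(f"served {len(results)} requests, peak concurrency {peak}")
```

Requests beyond the limit wait for a free slot rather than being rejected, which is why a queue length is a meaningful status field for a worker built this way.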
Usage
Run this module as a standalone FastAPI server process. Each worker loads one model and registers with the controller. Multiple workers can run in parallel on different ports or machines.
Code Reference
Source Location
- Repository: Mit_han_lab_Llm_awq
- File: tinychat/serve/model_worker.py
- Lines: 1-434
Signature
class ModelWorker:
    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        no_register: bool,
        model_type: str,
        model_path: str,
        model_name: str,
        quant_path: str,
        precision: str,
        device: str,
    ): ...
    def register_to_controller(self) -> None: ...
    def send_heart_beat(self) -> None: ...
    def get_queue_length(self) -> int: ...
    def get_status(self) -> dict: ...
    @torch.inference_mode()
    def generate_stream(self, params: dict) -> Generator[bytes, None, None]: ...
    def generate_stream_gate(self, params: dict) -> Generator[bytes, None, None]: ...
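The interplay between send_heart_beat and register_to_controller can be sketched as follows. This is an illustration under stated assumptions, not the module's implementation: the HeartbeatSketch class, the injected post callable, the endpoint paths, the "exist" reply key, and the 15-second interval are all hypothetical and are not taken from the source.

```python
import threading

HEART_BEAT_INTERVAL = 15  # seconds; assumed value, not taken from the source


class HeartbeatSketch:
    """Hypothetical stand-in for the heartbeat part of ModelWorker."""

    def __init__(self, post, worker_name):
        # post(path, payload) -> dict is an injected transport (assumption),
        # so the sketch runs without a live controller.
        self.post = post
        self.worker_name = worker_name
        self.register_calls = 0

    def register_to_controller(self):
        self.register_calls += 1
        self.post("/register_worker", {"worker_name": self.worker_name})

    def send_heart_beat(self):
        reply = self.post("/receive_heart_beat", {"worker_name": self.worker_name})
        # Assumed reply shape: {"exist": bool}. Re-register if unknown.
        if not reply.get("exist", False):
            self.register_to_controller()


def heart_beat_loop(worker, stop_event, interval=HEART_BEAT_INTERVAL):
    """Send a heartbeat every `interval` seconds until stop_event is set."""
    while not stop_event.wait(interval):
        worker.send_heart_beat()


def start_heart_beat(worker, interval=HEART_BEAT_INTERVAL):
    """Run the loop in a daemon thread and return the event that stops it."""
    stop_event = threading.Event()
    thread = threading.Thread(
        target=heart_beat_loop, args=(worker, stop_event, interval), daemon=True
    )
    thread.start()
    return stop_event
```

Running the loop in a daemon thread keeps heartbeats independent of request handling, so a busy worker still reports in on schedule.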
Import
# Run as a standalone model worker:
# python -m tinychat.serve.model_worker \
# --model-path /path/to/llava-model \
# --quant-path /path/to/awq-weights.pt \
# --precision W4A16 \
# --controller-address http://localhost:21001 \
# --port 21002
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| controller_addr | str | Yes | URL of the controller (default: http://localhost:21001) |
| worker_addr | str | Yes | This worker's URL for the controller to reach back (default: http://localhost:21002) |
| model_type | str | Yes | Base language model type, e.g. "LLaMa" |
| model_path | str | Yes | Path or HuggingFace ID of the model to load |
| quant_path | str | No | Path to AWQ quantized weights (required for W4A16 precision) |
| precision | str | Yes | Quantization precision: "W16A16" (full) or "W4A16" (AWQ 4-bit) |
| device | str | Yes | Target device, e.g. "cuda" |
| params.prompt | str | Yes | The text prompt including image token placeholders |
| params.images | List[str] | No | Base64-encoded images corresponding to <image> tokens in the prompt |
| params.temperature | float | No | Sampling temperature (default: 1.0) |
| params.top_p | float | No | Top-p nucleus sampling parameter (default: 1.0) |
| params.max_new_tokens | int | No | Maximum tokens to generate (default: 256, capped at 1024) |
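The defaults and the token cap listed for the params fields can be summarized in a small sketch. normalize_params is a hypothetical helper written for illustration; the worker itself may read and validate these fields differently.

```python
def normalize_params(params: dict) -> dict:
    """Hypothetical helper: apply the defaults and cap from the Inputs table."""
    return {
        "prompt": params["prompt"],  # required; raises KeyError if missing
        "images": params.get("images") or [],  # optional base64 strings
        "temperature": float(params.get("temperature", 1.0)),
        "top_p": float(params.get("top_p", 1.0)),
        # Default 256, never more than 1024.
        "max_new_tokens": min(int(params.get("max_new_tokens", 256)), 1024),
    }


if __name__ == "__main__":
    print(normalize_params({"prompt": "USER: Hello ASSISTANT:", "max_new_tokens": 4096}))
```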
Outputs
| Name | Type | Description |
|---|---|---|
| streaming_response | StreamingResponse | Stream of JSON chunks, each with "text" and "error_code" fields and terminated by a null byte (b"\0") |
| status | dict | Worker status with keys "model_names" (List[str]), "speed" (int), "queue_length" (int) |
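The framing of the streaming response can be sketched without a running server. encode_chunk and decode_stream are hypothetical helpers, and the use of 0 as the success value for "error_code" is an assumption rather than something stated here.

```python
import json

DELIMITER = b"\0"  # each JSON chunk is followed by a null byte


def encode_chunk(text: str, error_code: int = 0) -> bytes:
    """Serialize one chunk the way a worker might emit it (0 = assumed success)."""
    return json.dumps({"text": text, "error_code": error_code}).encode() + DELIMITER


def decode_stream(raw: bytes) -> list:
    """Split a raw byte stream on the delimiter and parse each JSON chunk."""
    return [json.loads(part.decode()) for part in raw.split(DELIMITER) if part]


if __name__ == "__main__":
    raw = encode_chunk("A cat") + encode_chunk("A cat on a sofa")
    for chunk in decode_stream(raw):
        print(chunk["error_code"], chunk["text"])
```

A null byte is a convenient delimiter because it cannot appear in JSON text, whereas generated text may contain newlines.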
Usage Examples
Launching a W4A16 AWQ Worker
# python -m tinychat.serve.model_worker \
# --model-path /models/llava-v1.5-7b \
# --quant-path /models/llava-v1.5-7b-awq.pt \
# --precision W4A16 \
# --controller-address http://localhost:21001 \
# --worker-address http://localhost:21002 \
# --port 21002
Launching a Full-Precision Worker
# python -m tinychat.serve.model_worker \
# --model-path /models/llava-v1.5-7b \
# --precision W16A16 \
# --controller-address http://localhost:21001 \
# --port 21003
Sending a Request to the Worker
import requests
import json
response = requests.post(
    "http://localhost:21002/worker_generate_stream",
    json={
        "model": "llava-v1.5-7b-4bit-AWQ",
        "prompt": "USER: What is in this image?\n<image> ASSISTANT:",
        "images": ["<base64-encoded-image-string>"],
        "temperature": 0.2,
        "top_p": 1.0,
        "max_new_tokens": 512,
    },
    stream=True,
)
# Chunks are separated by null bytes, so split on b"\0" rather than newlines.
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode())
        print(data["text"])