
Implementation:Mit han lab Llm awq LlavaStreamGenerator

From Leeroopedia
Knowledge Sources
Domains: NLP, Inference
Last Updated: 2026-02-15 00:00 GMT

Overview

Token-by-token streaming text generation for LLaVA multimodal models with image support, logits processing, and timing metrics.

Description

This module provides streaming inference for LLaVA (Large Language and Vision Assistant) models, enabling real-time token generation with multimodal (text + image) input.

prepare_logits_processor builds a LogitsProcessorList (HuggingFace Transformers) that chains TemperatureLogitsWarper, TopPLogitsWarper, and TopKLogitsWarper according to the provided generation parameters. The temperature warper is skipped in two cases: 1.0, where scaling would be a no-op, and 0.0, which TemperatureLogitsWarper does not accept because it requires a strictly positive value.
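To make the chaining concrete, here is a minimal pure-Python sketch of what such a processor chain does to a list of logits. It is not the Transformers implementation: the names build_processors, apply_processors, and the three *_warper helpers are hypothetical, and the exact skip thresholds are assumptions chosen to match the description above.

```python
import math
from typing import Callable, List

Logits = List[float]
Processor = Callable[[Logits], Logits]
NEG_INF = float("-inf")


def temperature_warper(temperature: float) -> Processor:
    # Divide every logit by the temperature (sharper when < 1, flatter when > 1).
    return lambda logits: [x / temperature for x in logits]


def top_p_warper(p: float) -> Processor:
    def apply(logits: Logits) -> Logits:
        peak = max(logits)
        weights = [math.exp(x - peak) for x in logits]  # unnormalised softmax
        total = sum(weights)
        order = sorted(range(len(logits)), key=lambda i: weights[i], reverse=True)
        kept, mass = set(), 0.0
        # Keep the most probable tokens until their cumulative mass reaches p.
        for i in order:
            kept.add(i)
            mass += weights[i] / total
            if mass >= p:
                break
        return [x if i in kept else NEG_INF for i, x in enumerate(logits)]
    return apply


def top_k_warper(k: int) -> Processor:
    def apply(logits: Logits) -> Logits:
        if k >= len(logits):
            return list(logits)
        threshold = sorted(logits, reverse=True)[k - 1]
        # Mask everything below the k-th largest logit.
        return [x if x >= threshold else NEG_INF for x in logits]
    return apply


def build_processors(temperature: float, top_p: float, top_k: int) -> List[Processor]:
    processors: List[Processor] = []
    # Assumed thresholds: skip temperature at 0.0 and 1.0, top-p at 1.0, top-k at 0.
    if temperature > 0.0 and temperature != 1.0:
        processors.append(temperature_warper(temperature))
    if 0.0 < top_p < 1.0:
        processors.append(top_p_warper(top_p))
    if top_k > 0:
        processors.append(top_k_warper(top_k))
    return processors


def apply_processors(processors: List[Processor], logits: Logits) -> Logits:
    # Each processor receives the output of the previous one.
    for processor in processors:
        logits = processor(logits)
    return logits
```

With temperature=1.0, top_p=1.0, and top_k=0, build_processors returns an empty list, so the logits pass through unchanged.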

tokenizer_image_token splits a prompt string on the literal <image> placeholder, tokenizes each chunk separately, and inserts the IMAGE_TOKEN_INDEX marker between the chunks. The marker tells the model where image embeddings should be injected in the token sequence. The function preserves the BOS token if present and returns either a plain list of token IDs or, when return_tensors="pt", a PyTorch tensor.
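The split-and-insert step can be sketched without a real tokenizer. In the example below, toy_tokenize, insert_image_tokens, BOS_ID, and the value -200 for IMAGE_TOKEN_INDEX are all hypothetical stand-ins (the actual default comes from tinychat.utils.constants.LLAVA_DEFAULT_IMAGE_TOKEN_IDX), and tensor conversion is left out.

```python
from typing import Callable, List

IMAGE_TOKEN_INDEX = -200  # hypothetical sentinel value, not taken from tinychat
BOS_ID = 1                # hypothetical BOS id used by the toy tokenizer


def toy_tokenize(text: str) -> List[int]:
    """Hypothetical tokenizer: BOS followed by one id per whitespace-separated word."""
    return [BOS_ID] + [100 + len(word) for word in text.split()]


def insert_image_tokens(
    prompt: str,
    tokenize: Callable[[str], List[int]] = toy_tokenize,
    image_token_index: int = IMAGE_TOKEN_INDEX,
    bos_id: int = BOS_ID,
) -> List[int]:
    # Tokenize the text on each side of every <image> placeholder separately.
    chunks = [tokenize(chunk) for chunk in prompt.split("<image>")]

    input_ids: List[int] = []
    # Keep a single leading BOS if the tokenizer emits one per chunk.
    has_bos = bool(chunks[0]) and chunks[0][0] == bos_id
    if has_bos:
        input_ids.append(bos_id)

    for i, chunk in enumerate(chunks):
        if i > 0:
            # One image marker between each pair of neighbouring chunks.
            input_ids.append(image_token_index)
        # Drop the per-chunk BOS so it is not repeated mid-sequence.
        input_ids.extend(chunk[1:] if has_bos else chunk)
    return input_ids
```

For the prompt "<image>\nDescribe this image", the first chunk is empty, so the result is the BOS id, the image marker, and then the ids of the three words.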

LlavaStreamGenerator is the main generator function, decorated with @torch.inference_mode(). It accepts a model, tokenizer, text input, image tensor, and generation parameters. On the first iteration (context stage), it processes the full input sequence including image tokens; on subsequent iterations, it feeds only the last generated token. It supports two code paths: one for LLaMA/LLaVA-class models using start_pos and chunk_prefilling, and another for other architectures using HuggingFace-style past_key_values caching.

At each step it applies logits processing (temperature scaling, top-p, top-k filtering), performs either greedy argmax or multinomial sampling, and yields a dictionary containing the decoded text, token usage statistics, and finish reason.

Timing metrics are tracked via global variables: context_tokens, context_time, and generation_time_list, which are reported in the final yield. After generation completes, GPU memory is explicitly freed via gc.collect() and torch.cuda.empty_cache().
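The control flow described above (full prompt on the first step, one token per step afterwards, greedy or sampled selection, one dictionary per yield) can be sketched in plain Python. This is a toy illustration under stated assumptions, not the tinychat code: toy_model, decode, stream_generate, EOS_ID, and the 1e-5 greedy threshold are hypothetical, timing is kept in local variables rather than globals, and the meaning given to timing["total_tokens"] is a guess.

```python
import math
import random
import time
from typing import Callable, Dict, Iterator, List

EOS_ID = 0  # hypothetical end-of-sequence id for the toy vocabulary


def toy_model(step_input: List[int]) -> List[float]:
    """Hypothetical stand-in for a model: logits over 5 tokens that favour
    (last input token + 1) % 5. A real model would also consult its KV cache."""
    logits = [0.0] * 5
    logits[(step_input[-1] + 1) % 5] = 5.0
    return logits


def decode(ids: List[int]) -> str:
    """Toy detokenizer that skips the EOS token."""
    return " ".join(f"t{i}" for i in ids if i != EOS_ID)


def stream_generate(
    prompt_ids: List[int],
    model: Callable[[List[int]], List[float]] = toy_model,
    max_new_tokens: int = 8,
    temperature: float = 0.0,
    stream_interval: int = 1,
    seed: int = 0,
) -> Iterator[Dict]:
    rng = random.Random(seed)
    output_ids: List[int] = []
    context_time = 0.0
    generation_time_list: List[float] = []
    finish_reason = "length"

    def usage() -> Dict[str, int]:
        return {
            "prompt_tokens": len(prompt_ids),
            "completion_tokens": len(output_ids),
            "total_tokens": len(prompt_ids) + len(output_ids),
        }

    for step in range(max_new_tokens):
        start = time.perf_counter()
        # Context stage feeds the whole prompt; later steps feed one token.
        step_input = list(prompt_ids) if step == 0 else [output_ids[-1]]
        logits = model(step_input)
        elapsed = time.perf_counter() - start
        if step == 0:
            context_time = elapsed
        else:
            generation_time_list.append(elapsed)

        if temperature < 1e-5:
            # Greedy decoding: take the highest-scoring token.
            token = max(range(len(logits)), key=logits.__getitem__)
        else:
            # Multinomial sampling from the temperature-scaled softmax.
            scaled = [x / temperature for x in logits]
            peak = max(scaled)
            weights = [math.exp(x - peak) for x in scaled]
            token = rng.choices(range(len(logits)), weights=weights)[0]
        output_ids.append(token)

        if token == EOS_ID:
            finish_reason = "stop"
            break
        if step % stream_interval == 0:
            yield {"text": decode(output_ids), "usage": usage(),
                   "finish_reason": None, "timing": None}

    # Final yield carries the finish reason and the timing summary.
    yield {
        "text": decode(output_ids),
        "usage": usage(),
        "finish_reason": finish_reason,
        "timing": {
            "context_tokens": len(prompt_ids),
            "context_time": context_time,
            "total_tokens": len(output_ids),
            "generation_time_list": generation_time_list,
        },
    }
```

Calling list(stream_generate([1, 2])) with the toy model yields two intermediate results followed by a final one whose finish_reason is "stop", because the third generated token is the toy EOS id.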

Usage

Use this generator when performing streaming inference with LLaVA models that combine text and image inputs. It is invoked from the VILA 1.0 demo (vila10_demo.py) and VILA 1.5 demo (vila15_demo.py) scripts. Pass the generator output to stream_output for real-time console display.
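As an illustration of consuming the stream, the sketch below prints only the newly added text on each step. It is not tinychat's stream_output: print_stream is a hypothetical helper, and it assumes each yielded "text" holds the full decoded text so far, which this page does not state explicitly.

```python
import sys
from typing import Dict, Iterable, Optional, TextIO


def print_stream(outputs: Iterable[Dict], out: TextIO = sys.stdout) -> Optional[Dict]:
    """Write only the unseen suffix of each cumulative "text" value.

    Returns the last yielded dictionary (or None if the stream was empty),
    so the caller can still inspect finish_reason and timing.
    """
    printed = 0
    last: Optional[Dict] = None
    for last in outputs:
        text = last["text"]
        out.write(text[printed:])  # emit just the part not yet shown
        out.flush()
        printed = len(text)
    out.write("\n")
    return last
```

Returning the final dictionary keeps display and bookkeeping separate: the helper handles console output, while the caller decides what to do with the usage and timing fields.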

Code Reference

Source Location

Signature

def prepare_logits_processor(
    temperature: float,
    repetition_penalty: float,
    top_p: float,
    top_k: int,
    min_tokens_to_keep: int = 1,
) -> LogitsProcessorList:

def tokenizer_image_token(
    prompt,
    tokenizer,
    image_token_index=tinychat.utils.constants.LLAVA_DEFAULT_IMAGE_TOKEN_IDX,
    return_tensors=None,
):

@torch.inference_mode()
def LlavaStreamGenerator(
    model,
    tokenizer,
    input: str,
    start_pos: int,
    gen_params: dict,
    device: str = "cuda:0",
    stream_interval: int = 1,
    echo: bool = False,
    stop_token_ids=[],
    image_tensor: Optional[torch.FloatTensor] = None,
    chunk_prefilling: bool = False,
):

Import

from tinychat.stream_generators.llava_stream_gen import (
    prepare_logits_processor,
    tokenizer_image_token,
    LlavaStreamGenerator,
)

I/O Contract

prepare_logits_processor

| Parameter | Type | Description |
|---|---|---|
| temperature | float | Sampling temperature; 0.0 and 1.0 add no temperature warper (values near 0 lead to greedy decoding in the generator) |
| repetition_penalty | float | Penalty for repeated tokens (currently disabled for VILA) |
| top_p | float | Nucleus sampling threshold (0.0-1.0) |
| top_k | int | Top-k filtering count; 0 disables |
| min_tokens_to_keep | int | Minimum tokens to retain during top-k (default: 1) |

| Returns | Type | Description |
|---|---|---|
| processor_list | LogitsProcessorList | Chained logits processors for sampling |

tokenizer_image_token

| Parameter | Type | Description |
|---|---|---|
| prompt | str | Input prompt containing <image> placeholders |
| tokenizer | PreTrainedTokenizer | HuggingFace tokenizer instance |
| image_token_index | int | Token ID to insert at image positions |
| return_tensors | str or None | If "pt", returns a PyTorch LongTensor |

| Returns | Type | Description |
|---|---|---|
| input_ids | list[int] or torch.Tensor | Token IDs with image markers inserted |

LlavaStreamGenerator

| Parameter | Type | Description |
|---|---|---|
| model | nn.Module | LLaVA or VILA model instance |
| tokenizer | PreTrainedTokenizer | Tokenizer for encoding/decoding |
| input | str | Text prompt (may contain <image> tokens) |
| start_pos | int | Starting position for KV cache (chunk prefilling) |
| gen_params | dict-like | Generation params with attrs: temp, repeat_penalty, top_p, top_k, n_vocab, n_predict |
| device | str | CUDA device string (default: "cuda:0") |
| stream_interval | int | Yield decoded text every N tokens (default: 1) |
| echo | bool | If True, include input tokens in output |
| stop_token_ids | list[int] | Additional stop token IDs beyond EOS |
| image_tensor | Optional[torch.FloatTensor] | Preprocessed image tensor |
| chunk_prefilling | bool | Enable chunk prefilling for multi-turn speedup |

| Yields | Type | Description |
|---|---|---|
| result | dict | Keys: "text" (str), "usage" (dict with prompt_tokens, completion_tokens, total_tokens), "finish_reason" (None/"stop"/"length"), "timing" (dict or None with context_tokens, context_time, total_tokens, generation_time_list) |
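For readers who want the yielded dictionary spelled out as types, the following sketch restates the contract above with TypedDict. The class names Usage, Timing, StreamResult and the helper is_final are hypothetical, and the field types inside Timing (int, float, list of floats) are assumptions, since the page lists only the key names.

```python
from typing import List, Optional, TypedDict


class Usage(TypedDict):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class Timing(TypedDict):
    # Field types are assumed; the page documents only the key names.
    context_tokens: int
    context_time: float
    total_tokens: int
    generation_time_list: List[float]


class StreamResult(TypedDict):
    text: str
    usage: Usage
    finish_reason: Optional[str]  # None while streaming, then "stop" or "length"
    timing: Optional[Timing]      # None on intermediate yields


def is_final(result: StreamResult) -> bool:
    """A result is final once finish_reason has been set."""
    return result["finish_reason"] is not None
```

Checking finish_reason rather than timing to detect the last item follows the contract above, where timing may be None.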

Usage Examples

from tinychat.stream_generators.llava_stream_gen import LlavaStreamGenerator

# Basic streaming generation with a LLaVA model.
# llava_model, tokenizer, gen_params, and image_tensor are assumed to have been
# created beforehand (model loading and image preprocessing are not shown here).
for output in LlavaStreamGenerator(
    model=llava_model,
    tokenizer=tokenizer,
    input="<image>\nDescribe this image in detail.",
    start_pos=0,
    gen_params=gen_params,
    device="cuda:0",
    stop_token_ids=[],
    image_tensor=image_tensor,
    chunk_prefilling=False,
):
    if output["finish_reason"] is None:
        print(output["text"], end="", flush=True)
    else:
        # Final output with timing info ("timing" may be None)
        timing = output["timing"]
        if timing is not None:
            print(f"\nContext: {timing['context_tokens']} tokens in {timing['context_time']:.2f}s")
