Overview
Token-by-token streaming text generation for LLaVA multimodal models with image support, logits processing, and timing metrics.
Description
This module provides streaming inference for LLaVA (Large Language and Vision Assistant) models, enabling real-time token generation with multimodal (text + image) input.
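prepare_logits_processor builds a LogitsProcessorList from the HuggingFace Transformers library, chaining TemperatureLogitsWarper, TopPLogitsWarper, and TopKLogitsWarper according to the supplied generation parameters. Two temperature values are skipped. A temperature of 0.0 is skipped because TemperatureLogitsWarper only accepts strictly positive values; greedy decoding is handled by the generator instead. A temperature of 1.0 is skipped because it leaves the logits unchanged. The repetition_penalty argument is accepted but is currently not applied for VILA.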
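The skip rules and the chaining order can be shown without Transformers. The following is a dependency-free stand-in, not the module's code. Plain functions play the role of the three warpers, and prepare_logits_processor_sketch and apply_chain are hypothetical names. The cutoffs 1e-5 (temperature) and 1e-8 (top_p) are assumptions and are not taken from the source.

```python
import math
from typing import Callable, List

Logits = List[float]
Warper = Callable[[Logits], Logits]
NEG_INF = float("-inf")


def temperature_warper(temperature: float) -> Warper:
    # Divide every logit by the temperature, as TemperatureLogitsWarper does.
    return lambda logits: [x / temperature for x in logits]


def top_p_warper(top_p: float) -> Warper:
    # Keep the smallest set of most likely tokens whose probability mass reaches top_p.
    def apply(logits: Logits) -> Logits:
        peak = max(logits)
        exps = [math.exp(x - peak) for x in logits]
        total = sum(exps)
        order = sorted(range(len(logits)), key=lambda i: exps[i], reverse=True)
        keep, mass = set(), 0.0
        for i in order:
            if mass >= top_p:
                break
            keep.add(i)
            mass += exps[i] / total
        return [x if i in keep else NEG_INF for i, x in enumerate(logits)]

    return apply


def top_k_warper(top_k: int, min_tokens_to_keep: int = 1) -> Warper:
    # Mask every logit strictly below the k-th largest one (ties are kept).
    k = max(top_k, min_tokens_to_keep)

    def apply(logits: Logits) -> Logits:
        if k >= len(logits):
            return list(logits)
        threshold = sorted(logits, reverse=True)[k - 1]
        return [x if x >= threshold else NEG_INF for x in logits]

    return apply


def prepare_logits_processor_sketch(
    temperature: float, top_p: float, top_k: int, min_tokens_to_keep: int = 1
) -> List[Warper]:
    chain: List[Warper] = []
    # ~0.0 is invalid for a temperature warper and 1.0 is a no-op, so both are skipped.
    if temperature >= 1e-5 and temperature != 1.0:
        chain.append(temperature_warper(temperature))
    if 1e-8 <= top_p < 1.0:
        chain.append(top_p_warper(top_p))
    if top_k > 0:
        chain.append(top_k_warper(top_k, min_tokens_to_keep))
    return chain


def apply_chain(chain: List[Warper], logits: Logits) -> Logits:
    # Run the warpers in order, like calling a LogitsProcessorList.
    for warper in chain:
        logits = warper(logits)
    return logits


chain = prepare_logits_processor_sketch(temperature=1.0, top_p=1.0, top_k=2)
print(apply_chain(chain, [1.0, 3.0, 2.0, 0.0]))  # [-inf, 3.0, 2.0, -inf]
```

With temperature=1.0 and top_p=1.0, only the top-k warper ends up in the chain. This is why the example keeps only the two largest logits.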
tokenizer_image_token splits a prompt string on the literal <image> placeholder and tokenizes each chunk separately. It then inserts image_token_index (by default LLAVA_DEFAULT_IMAGE_TOKEN_IDX) between the chunks, marking where image embeddings are injected into the token sequence. If the tokenizer prepends a BOS token, the function keeps a single BOS at the start and drops the copies added to the later chunks. It returns a plain list of token IDs, or a PyTorch tensor when return_tensors="pt".
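Below is a minimal sketch of the same splitting logic. It assumes a tokenizer that prepends BOS on every call. tokenizer_image_token_sketch and toy_tokenize are hypothetical stand-ins for the real function and a HuggingFace tokenizer, and the return_tensors branch is omitted. The marker value -200 is the one LLaVA conventionally uses; check LLAVA_DEFAULT_IMAGE_TOKEN_IDX for the value tinychat actually uses.

```python
from typing import Callable, Dict, List

BOS_ID = 1
IMAGE_TOKEN_INDEX = -200  # LLaVA's conventional value; tinychat's constant may differ

_vocab: Dict[str, int] = {}


def toy_tokenize(text: str) -> List[int]:
    # Hypothetical tokenizer: BOS, then one ID per whitespace-separated word.
    return [BOS_ID] + [_vocab.setdefault(w, 100 + len(_vocab)) for w in text.split()]


def tokenizer_image_token_sketch(
    prompt: str,
    tokenize: Callable[[str], List[int]],
    bos_token_id: int,
    image_token_index: int,
) -> List[int]:
    # Tokenize each text chunk between <image> placeholders separately.
    chunks = [tokenize(chunk) for chunk in prompt.split("<image>")]
    input_ids: List[int] = []
    offset = 0
    # Keep one leading BOS and strip the BOS the tokenizer added to every chunk.
    if chunks and chunks[0] and chunks[0][0] == bos_token_id:
        offset = 1
        input_ids.append(bos_token_id)
    for i, chunk in enumerate(chunks):
        if i > 0:
            # One marker per <image>; the model later swaps it for image embeddings.
            input_ids.append(image_token_index)
        input_ids.extend(chunk[offset:])
    return input_ids


ids = tokenizer_image_token_sketch(
    "<image>\nDescribe this image.", toy_tokenize, BOS_ID, IMAGE_TOKEN_INDEX
)
print(ids)  # [1, -200, 100, 101, 102]
```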
LlavaStreamGenerator is the main generator function and is decorated with @torch.inference_mode(). It accepts a model, a tokenizer, the text input, an image tensor, and generation parameters.
On the first iteration (the context stage), it processes the full input sequence, including image tokens. On every later iteration it feeds only the most recently generated token. There are two code paths. LLaMA/LLaVA-class models use start_pos and chunk_prefilling, while other architectures use HuggingFace-style past_key_values caching.
At each step the generator applies logits processing (temperature scaling, top-p and top-k filtering). It then selects the next token by greedy argmax or multinomial sampling and yields a dictionary containing the decoded text, token usage statistics, and the finish reason.
Timing metrics are tracked in the global variables context_tokens, context_time, and generation_time_list, and are reported in the final yield. After generation completes, the generator calls gc.collect() and torch.cuda.empty_cache() to reclaim unreferenced objects and release cached GPU memory held by PyTorch's allocator.
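The generator's control flow can be sketched with a toy model: one context pass, then one token per step, greedy or sampled selection, and a final yield carrying the finish reason. Everything in the sketch is hypothetical. toy_stream_generate and toy_step are illustrative names, and the "model" is a plain function that returns logits for a given position rather than a network with a KV cache. Only temperature is applied before sampling, and image-embedding expansion is ignored, so positions count text tokens only.

```python
import math
import random
from typing import Callable, Dict, Iterator, List, Optional

# Hypothetical model interface: (tokens to feed, position of the first one) -> next-token logits.
StepFn = Callable[[List[int], int], List[float]]


def toy_stream_generate(
    step: StepFn,
    input_ids: List[int],
    max_new_tokens: int,
    stop_token_ids: List[int],
    temperature: float = 0.0,
    stream_interval: int = 1,
    rng: Optional[random.Random] = None,
) -> Iterator[Dict]:
    rng = rng or random.Random(0)
    output_ids: List[int] = []
    finish_reason = "length"
    for i in range(max_new_tokens):
        # Context stage: feed the whole prompt once; afterwards feed only the last token.
        feed = input_ids if i == 0 else [output_ids[-1]]
        pos = 0 if i == 0 else len(input_ids) + i - 1
        logits = step(feed, pos)
        if temperature < 1e-5:
            # Greedy decoding: take the argmax.
            token = max(range(len(logits)), key=logits.__getitem__)
        else:
            # Multinomial sampling from softmax(logits / temperature).
            peak = max(logits)
            weights = [math.exp((x - peak) / temperature) for x in logits]
            token = rng.choices(range(len(logits)), weights=weights)[0]
        output_ids.append(token)
        stopped = token in stop_token_ids
        if i % stream_interval == 0 or stopped or i == max_new_tokens - 1:
            yield {"output_ids": list(output_ids), "finish_reason": None}
        if stopped:
            finish_reason = "stop"
            break
    # Final yield reports why generation ended.
    yield {"output_ids": list(output_ids), "finish_reason": finish_reason}


def toy_step(feed: List[int], pos: int) -> List[float]:
    # Toy "model" over a 5-token vocabulary: always favors (last token + 1) mod 5.
    nxt = (feed[-1] + 1) % 5
    return [5.0 if t == nxt else 0.0 for t in range(5)]


for out in toy_stream_generate(toy_step, [0, 1], max_new_tokens=10, stop_token_ids=[4]):
    print(out)
```

The toy run yields [2], [2, 3], [2, 3, 4] while streaming. It then emits a final record with finish_reason "stop" because token 4 is a stop token. With a smaller max_new_tokens, the final record reports "length" instead.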
Usage
Use this generator when performing streaming inference with LLaVA models that combine text and image inputs. It is invoked from the VILA 1.0 demo (vila10_demo.py) and VILA 1.5 demo (vila15_demo.py) scripts. Pass the generator output to stream_output for real-time console display.
Code Reference
Source Location
Signature
```python
def prepare_logits_processor(
    temperature: float,
    repetition_penalty: float,
    top_p: float,
    top_k: int,
    min_tokens_to_keep: int = 1,
) -> LogitsProcessorList:
    ...


def tokenizer_image_token(
    prompt,
    tokenizer,
    image_token_index=tinychat.utils.constants.LLAVA_DEFAULT_IMAGE_TOKEN_IDX,
    return_tensors=None,
):
    ...


@torch.inference_mode()
def LlavaStreamGenerator(
    model,
    tokenizer,
    input: str,
    start_pos: int,
    gen_params: dict,
    device: str = "cuda:0",
    stream_interval: int = 1,
    echo: bool = False,
    stop_token_ids=[],
    image_tensor: Optional[torch.FloatTensor] = None,
    chunk_prefilling: bool = False,
):
    ...
```
Import
```python
from tinychat.stream_generators.llava_stream_gen import (
    prepare_logits_processor,
    tokenizer_image_token,
    LlavaStreamGenerator,
)
```
I/O Contract
prepare_logits_processor
| Parameter | Type | Description |
|---|---|---|
| temperature | float | Sampling temperature; skipped when near 0.0 or equal to 1.0 (LlavaStreamGenerator decodes greedily when it is near 0) |
| repetition_penalty | float | Penalty for repeated tokens; accepted but currently not applied for VILA |
| top_p | float | Nucleus sampling threshold (0.0-1.0) |
| top_k | int | Top-k filtering count; 0 disables |
| min_tokens_to_keep | int | Minimum number of tokens retained by top-k filtering (default: 1) |

| Returns | Type | Description |
|---|---|---|
| processor_list | LogitsProcessorList | Chained logits processors for sampling |
tokenizer_image_token
| Parameter | Type | Description |
|---|---|---|
| prompt | str | Input prompt containing <image> placeholders |
| tokenizer | PreTrainedTokenizer | HuggingFace tokenizer instance |
| image_token_index | int | Token ID to insert at image positions (default: LLAVA_DEFAULT_IMAGE_TOKEN_IDX) |
| return_tensors | str or None | If "pt", returns a PyTorch LongTensor |

| Returns | Type | Description |
|---|---|---|
| input_ids | list[int] or torch.Tensor | Token IDs with image markers inserted |
LlavaStreamGenerator
| Parameter | Type | Description |
|---|---|---|
| model | nn.Module | LLaVA or VILA model instance |
| tokenizer | PreTrainedTokenizer | Tokenizer for encoding and decoding |
| input | str | Text prompt (may contain <image> tokens) |
| start_pos | int | Starting position in the KV cache (used with chunk prefilling) |
| gen_params | object with attributes (annotated as dict) | Generation parameters, read as attributes: temp, repeat_penalty, top_p, top_k, n_vocab, n_predict |
| device | str | CUDA device string (default: "cuda:0") |
| stream_interval | int | Yield decoded text every N tokens (default: 1) |
| echo | bool | If True, include the input prompt in the returned text |
| stop_token_ids | list[int] | Additional stop token IDs beyond EOS |
| image_tensor | Optional[torch.FloatTensor] | Preprocessed image tensor |
| chunk_prefilling | bool | Enable chunk prefilling to speed up multi-turn conversations |

| Yields | Type | Description |
|---|---|---|
| result | dict | Keys: "text" (str, the decoded output so far), "usage" (dict with prompt_tokens, completion_tokens, total_tokens), "finish_reason" (None while streaming, then "stop" or "length"), "timing" (None except in the final yield, where it is a dict with context_tokens, context_time, total_tokens, generation_time_list) |
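The table lists the gen_params fields as attributes rather than dictionary keys. Assuming that is how the generator reads them, a types.SimpleNamespace is one lightweight way to build a compatible object. The field values and the missing_fields helper below are hypothetical; the demo scripts build their own parameter object.

```python
from types import SimpleNamespace

REQUIRED_FIELDS = ("temp", "repeat_penalty", "top_p", "top_k", "n_vocab", "n_predict")

# Hypothetical values for illustration only.
gen_params = SimpleNamespace(
    temp=0.2,
    repeat_penalty=1.0,
    top_p=0.95,
    top_k=40,
    n_vocab=32000,
    n_predict=512,
)


def missing_fields(params) -> list:
    # Attribute lookup, not key lookup: a plain dict reports every field as missing.
    return [name for name in REQUIRED_FIELDS if not hasattr(params, name)]


print(missing_fields(gen_params))  # []
print(missing_fields({"temp": 0.2}))  # ['temp', 'repeat_penalty', 'top_p', 'top_k', 'n_vocab', 'n_predict']
```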
Usage Examples
```python
from tinychat.stream_generators.llava_stream_gen import LlavaStreamGenerator

# Basic streaming generation with a LLaVA model.
# llava_model, tokenizer, gen_params and image_tensor are assumed to be
# loaded and preprocessed beforehand (for example, as in the VILA demo scripts).
printed = 0
for output in LlavaStreamGenerator(
    model=llava_model,
    tokenizer=tokenizer,
    input="<image>\nDescribe this image in detail.",
    start_pos=0,
    gen_params=gen_params,
    device="cuda:0",
    stop_token_ids=[],
    image_tensor=image_tensor,
    chunk_prefilling=False,
):
    # "text" holds the whole completion decoded so far, so print only the new part.
    text = output["text"]
    print(text[printed:], end="", flush=True)
    printed = len(text)
    if output["finish_reason"] is not None:
        # The final yield also carries timing info.
        timing = output["timing"]
        print(f"\nContext: {timing['context_tokens']} tokens in {timing['context_time']:.2f}s")
```
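Beyond printing the raw values, the final timing dictionary can be reduced to throughput figures. This sketch assumes that context_time is in seconds, as the format string in the example suggests. It also assumes that generation_time_list holds one wall-clock duration in seconds per decoding step. summarize_timing is a hypothetical helper, and the numbers passed to it are made up to show the arithmetic.

```python
from typing import Dict, Optional


def summarize_timing(timing: Optional[Dict]) -> Optional[Dict[str, float]]:
    # Streaming yields carry timing=None; only the final yield has metrics.
    if not timing:
        return None
    summary = {"prefill_tokens_per_s": timing["context_tokens"] / timing["context_time"]}
    steps = timing["generation_time_list"]
    if steps:
        # Assumes one duration (seconds) per decoding step.
        summary["avg_decode_s_per_token"] = sum(steps) / len(steps)
    return summary


# Made-up numbers, chosen only to show the arithmetic.
example = {
    "context_tokens": 500,
    "context_time": 0.25,
    "total_tokens": 502,
    "generation_time_list": [0.25, 0.25],
}
print(summarize_timing(example))
# {'prefill_tokens_per_s': 2000.0, 'avg_decode_s_per_token': 0.25}
```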