Implementation:Mit han lab Llm awq NVILAStreamGenerator
| Knowledge Sources | |
|---|---|
| Domains | NLP, Inference |
| Last Updated | 2026-02-15 00:00 GMT |
Overview
Token-by-token streaming text generation for NVILA multimodal models, delegating forward passes to the model's own stream_gen method and handling media and media configuration separately.
Description
NVILAStreamGenerator is a generator function decorated with @torch.inference_mode() that produces streaming text output for NVILA (NVIDIA VILA) models. Unlike the LLaVA stream generator, which manages model forward passes directly, this generator delegates each forward pass to model.stream_gen(). That method handles media embedding internally and returns both the logits and the consumed sequence length.
The generator reuses prepare_logits_processor from llava_stream_gen to build the logits processing pipeline (temperature, top-p, top-k). On the first iteration, the full tokenized input and media/media_cfg are passed to model.stream_gen(); on subsequent iterations, only the single last-generated token is passed with media set to None. The start_pos is incremented by the returned length each step, enabling efficient KV cache reuse.
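The prefill/decode split described above can be sketched with a toy stand-in. This is a minimal illustration, not the repository's code: `ToyModel`, its `stream_gen` argument order, and the list-based logits are assumptions made for the example, and only the call pattern (full prompt plus media once, then one token at a time with media set to None) follows the description.

```python
from typing import List, Optional, Tuple

VOCAB_SIZE = 8  # tiny illustrative vocabulary


class ToyModel:
    """Hypothetical stand-in; it only mimics the call pattern described above."""

    def __init__(self) -> None:
        self.calls: List[Tuple[int, int, bool]] = []

    def stream_gen(
        self,
        token_ids: List[int],
        media: Optional[dict],
        media_cfg: Optional[dict],
        start_pos: int,
    ) -> Tuple[List[float], int]:
        # Record (tokens fed, start_pos, media present) for inspection.
        self.calls.append((len(token_ids), start_pos, media is not None))
        # Fake logits that favour "last input token + 1".
        logits = [0.0] * VOCAB_SIZE
        logits[(token_ids[-1] + 1) % VOCAB_SIZE] = 1.0
        # Return the logits plus the number of cache positions consumed.
        return logits, len(token_ids)


def decode_loop(model, prompt_ids, media, media_cfg, max_new_tokens, start_pos=0):
    generated: List[int] = []
    for step in range(max_new_tokens):
        if step == 0:
            # Prefill: the whole prompt and the media go in once.
            logits, consumed = model.stream_gen(prompt_ids, media, media_cfg, start_pos)
        else:
            # Decode: only the last generated token, with media set to None.
            logits, consumed = model.stream_gen([generated[-1]], None, None, start_pos)
        # Advance the cache position by whatever the model reports it consumed.
        start_pos += consumed
        # Greedy pick, to keep the sketch deterministic.
        generated.append(max(range(len(logits)), key=logits.__getitem__))
    return generated, start_pos
```

In the toy, the consumed length equals the number of text tokens. With a real multimodal model the first call may consume more positions than the text prompt alone, which is one reason to trust the returned length rather than recomputing it.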
The function includes robustness checks for numerical stability: it detects Inf values in logits and prints diagnostic information, and it checks for Inf/NaN in the softmax probabilities, saving the problematic tensor to disk and exiting gracefully if detected.
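A plain-Python sketch of the two checks follows. It swaps tensors for lists and the `math` module so that it runs without PyTorch; the helper names are hypothetical, and the step of saving the tensor to disk is reduced to returning None.

```python
import math
from typing import List, Optional


def softmax(logits: List[float]) -> List[float]:
    # Subtract the maximum first, the usual trick for numerical stability.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def has_inf(values: List[float]) -> bool:
    return any(math.isinf(v) for v in values)


def has_inf_or_nan(values: List[float]) -> bool:
    return any(math.isinf(v) or math.isnan(v) for v in values)


def safe_probs(logits: List[float]) -> Optional[List[float]]:
    """Return probabilities, or None when the distribution is unusable."""
    if has_inf(logits):
        # Diagnostic only: an Inf logit is reported but not fatal by itself.
        print(f"warning: Inf in logits (max={max(logits)}, min={min(logits)})")
    probs = softmax(logits)
    if has_inf_or_nan(probs):
        # The caller would persist the offending values here and stop.
        return None
    return probs
```

The two checks are not redundant. A logit of negative infinity (a masked token, for example) yields a valid probability of zero, whereas a positive infinity turns the whole distribution into NaN, so only the second check has to abort generation.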
Token generation uses either greedy argmax (when temperature or top_p is near zero) or multinomial sampling from the softmax distribution. Results are yielded at configurable intervals (stream_interval, default 2), with each yield containing decoded text, usage statistics, and optional timing information. The final yield includes a timing dict with context_tokens, context_time, total_tokens, and generation_time_list for performance analysis.
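The selection rule and the yield cadence can be illustrated as follows. This is a sketch under stated assumptions: the `eps` threshold for "near zero", the helper names, and the rule that the last token always triggers a yield are illustrative choices, and `random.choices` stands in for multinomial sampling over a tensor.

```python
import random
from typing import List


def pick_token(
    probs: List[float],
    temperature: float,
    top_p: float,
    rng: random.Random,
    eps: float = 1e-5,  # assumed threshold for "near zero"
) -> int:
    if temperature < eps or top_p < eps:
        # Greedy: take the most probable token.
        return max(range(len(probs)), key=probs.__getitem__)
    # Sampling: draw one index with probability proportional to probs.
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


def yield_steps(n_tokens: int, stream_interval: int = 2) -> List[int]:
    """Token indices at which a partial result would be emitted."""
    steps = []
    for i in range(n_tokens):
        is_last = i == n_tokens - 1
        if i % stream_interval == 0 or is_last:
            steps.append(i)
    return steps
```

A larger stream_interval means fewer decode-and-yield round trips, at the cost of text arriving in coarser bursts.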
For multi-turn conversations with chunk_prefilling enabled, the input is prepended with <|im_start|> when start_pos is non-zero, aligning with the Qwen2 chat template format.
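Read literally, the rule above amounts to the small helper below. The function name is hypothetical and the real generator may apply the prefix at a different point (for instance after tokenization), so treat this as a sketch of the condition only.

```python
def prepare_turn_input(text: str, start_pos: int, chunk_prefilling: bool) -> str:
    # On a follow-up turn the KV cache already holds the earlier conversation,
    # so the new chunk has to open with the turn-start marker itself.
    if chunk_prefilling and start_pos != 0:
        return "<|im_start|>" + text
    return text
```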
Usage
Use this generator for streaming inference with NVILA models in the nvila_demo.py interactive demo. It is selected as the stream generator and invoked in the main chat loop.
Code Reference
Source Location
- Repository: Mit_han_lab_Llm_awq
- File: tinychat/stream_generators/NVILA_stream_gen.py
- Lines: 1-176
Signature
@torch.inference_mode()
def NVILAStreamGenerator(
    model,
    gen_params,
    input: str,
    media=None,
    media_cfg=None,
    start_pos: int = 0,
    device: str = "cuda:0",
    stream_interval: int = 2,
    echo: bool = False,
    stop_token_ids=[],
    image_tensor: Optional[torch.FloatTensor] = None,
    chunk_prefilling: bool = False,
    quant_llm: bool = False,
):
Import
from tinychat.stream_generators.NVILA_stream_gen import NVILAStreamGenerator
I/O Contract
Inputs
| Parameter | Type | Description |
|---|---|---|
| model | nn.Module | NVILA model with .stream_gen() and .tokenizer attributes |
| gen_params | object | Generation params with attrs: temp, repeat_penalty, top_p, top_k, n_vocab, n_predict |
| input | str | Text prompt string |
| media | dict or None | Media dictionary (e.g., preprocessed image/video tensors) |
| media_cfg | dict or None | Media configuration metadata |
| start_pos | int | Starting KV cache position (default: 0) |
| device | str | CUDA device string (default: "cuda:0") |
| stream_interval | int | Yield decoded text every N tokens (default: 2) |
| echo | bool | If True, include input tokens in output |
| stop_token_ids | list[int] | Additional stop token IDs beyond EOS |
| image_tensor | Optional[torch.FloatTensor] | Legacy parameter for compatibility (unused) |
| chunk_prefilling | bool | Enable chunk prefilling for multi-turn speedup |
| quant_llm | bool | Whether the LLM backbone is quantized |
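Since gen_params only needs to expose the listed attributes, a small dataclass is enough to experiment with. The class name and every default value below are illustrative assumptions, not the repository's configuration object or its defaults.

```python
from dataclasses import dataclass


@dataclass
class GenParams:
    """Hypothetical container exposing the attributes the generator reads."""

    temp: float = 0.2            # sampling temperature (illustrative value)
    repeat_penalty: float = 1.0  # 1.0 leaves logits unchanged (illustrative)
    top_p: float = 0.95          # nucleus sampling mass (illustrative)
    top_k: int = 40              # candidates kept per step (illustrative)
    n_vocab: int = 32000         # vocabulary size (illustrative)
    n_predict: int = 256         # maximum number of new tokens (illustrative)


# Override only what a given run needs.
short_run = GenParams(n_predict=64)
```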
Outputs
| Yields | Type | Description |
|---|---|---|
| result | dict | Keys: "text" (str), "usage" (dict with prompt_tokens, completion_tokens, total_tokens), "finish_reason" (None/"stop"/"length"), "timing" (dict or None with context_tokens, context_time, total_tokens, generation_time_list) |
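The timing dict from the final yield can be reduced to two throughput figures. The helper below is a sketch that assumes times are in seconds and that generation_time_list holds one entry per decode step; neither assumption is stated by the source.

```python
from typing import Dict


def summarize_timing(timing: Dict) -> Dict[str, float]:
    ctx_tokens = timing["context_tokens"]
    ctx_time = timing["context_time"]
    gen_times = timing["generation_time_list"]

    # Prefill throughput: prompt tokens processed per second.
    prefill_tps = ctx_tokens / ctx_time if ctx_time > 0 else 0.0

    # Decode throughput: decode steps per second, guarding against an
    # empty list or a zero total.
    gen_total = sum(gen_times)
    decode_tps = len(gen_times) / gen_total if gen_total > 0 else 0.0

    return {"prefill_tokens_per_s": prefill_tps, "decode_tokens_per_s": decode_tps}
```

Reporting the two phases separately is useful because prefill processes the whole prompt in one pass while decoding advances one token per step, so a single blended number hides which phase dominates.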
Usage Examples
from tinychat.stream_generators.NVILA_stream_gen import NVILAStreamGenerator

# Streaming generation with an NVILA model.
# nvila_model, gen_params, media, and media_cfg are assumed to be prepared beforehand.
for output in NVILAStreamGenerator(
    model=nvila_model,
    gen_params=gen_params,
    input="Describe this video in detail.",
    media=media,
    media_cfg=media_cfg,
    start_pos=0,
    device="cuda:0",
    stop_token_ids=[],
    chunk_prefilling=True,
    quant_llm=True,
):
    if output["finish_reason"] is None:
        print(output["text"], end="", flush=True)
    else:
        timing = output["timing"]
        total = timing["total_tokens"]
        gen_times = timing["generation_time_list"]
        # Guard against an empty list or a zero total time.
        elapsed = sum(gen_times)
        avg_speed = len(gen_times) / elapsed if elapsed > 0 else 0.0
        print(f"\nTotal tokens: {total}")
        print(f"Generation speed: {avg_speed:.1f} tokens/s")