
Implementation:Mit han lab Llm awq NVILAStreamGenerator

From Leeroopedia
Knowledge Sources
Domains NLP, Inference
Last Updated 2026-02-15 00:00 GMT

Overview

Token-by-token streaming text generation for NVILA multimodal models. The generator delegates forward passes to the model's own stream_gen method and takes media and media configuration as separate arguments.

Description

NVILAStreamGenerator is a generator function decorated with @torch.inference_mode() that produces streaming text output for NVILA (NVIDIA VILA) models. Unlike the LLaVA stream generator, which manages model forward passes directly, this generator delegates to model.stream_gen(), which handles media embedding internally and returns both the logits and the consumed sequence length.

The generator reuses prepare_logits_processor from llava_stream_gen to build the logits processing pipeline (temperature, top-p, top-k). On the first iteration, the full tokenized input and media/media_cfg are passed to model.stream_gen(); on subsequent iterations, only the single last-generated token is passed with media set to None. The start_pos is incremented by the returned length each step, enabling efficient KV cache reuse.
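The prefill-then-decode pattern described above can be sketched without a real model. In this minimal sketch, FakeModel, run_steps, the argument order of stream_gen, and the four extra media positions are all hypothetical stand-ins chosen for illustration; the real model.stream_gen() signature may differ.

```python
from typing import List, Optional, Tuple


class FakeModel:
    """Hypothetical stand-in for an NVILA model; it only mimics the
    (logits, consumed_length) return shape described above."""

    def __init__(self) -> None:
        self.calls: List[Tuple[int, bool, int]] = []

    def stream_gen(
        self,
        input_ids: List[int],
        media: Optional[dict],
        media_cfg: Optional[dict],
        start_pos: int,
    ) -> Tuple[List[float], int]:
        # Record (number of input tokens, media present?, start_pos).
        self.calls.append((len(input_ids), media is not None, start_pos))
        # Assumption: media embeddings occupy 4 extra cache positions.
        length = len(input_ids) + (4 if media is not None else 0)
        return [0.0, 1.0, 0.5], length


def run_steps(model, prompt_ids, media, media_cfg, n_steps, start_pos=0):
    """Run n_steps forward passes and return the final cache position."""
    token = None
    for step in range(n_steps):
        if step == 0:
            # Prefill: full prompt plus media.
            logits, length = model.stream_gen(prompt_ids, media, media_cfg, start_pos)
        else:
            # Decode: only the last generated token, media set to None.
            logits, length = model.stream_gen([token], None, media_cfg, start_pos)
        start_pos += length  # advance by the consumed length
        token = max(range(len(logits)), key=logits.__getitem__)  # greedy pick
    return start_pos
```

Because start_pos advances by the returned length rather than by the number of text tokens, positions consumed by media embeddings are accounted for in the cache offset.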

The function includes robustness checks for numerical stability: it detects Inf values in logits and prints diagnostic information, and it checks for Inf/NaN in the softmax probabilities, saving the problematic tensor to disk and exiting gracefully if detected.
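The two checks can be illustrated in plain Python on lists of floats. This is only a sketch of the idea: has_inf, softmax, and probs_are_finite are hypothetical helper names, and the real function operates on tensors and additionally saves the offending tensor to disk before exiting.

```python
import math
from typing import List


def has_inf(logits: List[float]) -> bool:
    """First check: report whether any logit is +/- infinity."""
    return any(math.isinf(x) for x in logits)


def softmax(logits: List[float]) -> List[float]:
    """Max-subtracted softmax; an Inf logit still yields NaN (inf - inf)."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def probs_are_finite(probs: List[float]) -> bool:
    """Second check: every probability must be neither Inf nor NaN."""
    return all(math.isfinite(p) for p in probs)
```

A single infinite logit is enough to turn the whole probability vector into NaN, which is why the check on the probabilities has to run before sampling.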

Token generation uses either greedy argmax (when temperature or top_p is near zero) or multinomial sampling from the softmax distribution. Results are yielded at configurable intervals (stream_interval, default 2), with each yield containing decoded text, usage statistics, and optional timing information. The final yield includes a timing dict with context_tokens, context_time, total_tokens, and generation_time_list for performance analysis.
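The choice between greedy decoding and sampling can be sketched as follows. The helper name select_token and the 1e-5 threshold for "near zero" are assumptions made for this illustration, and the standard library random.choices stands in for multinomial sampling over a tensor.

```python
import random
from typing import List, Optional


def select_token(
    probs: List[float],
    temperature: float,
    top_p: float,
    rng: Optional[random.Random] = None,
    eps: float = 1e-5,  # assumed threshold for "near zero"
) -> int:
    """Return a token index: argmax when sampling is effectively disabled,
    otherwise one draw weighted by the probability distribution."""
    if temperature < eps or top_p < eps:
        # Greedy path: index of the largest probability.
        return max(range(len(probs)), key=probs.__getitem__)
    rng = rng or random.Random()
    # Sampling path: one draw, weighted by probs.
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]
```

Passing an explicit random.Random instance makes the sampling path reproducible, which helps when comparing runs.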

For multi-turn conversations with chunk_prefilling enabled, the input is prepended with <|im_start|> when start_pos is non-zero, aligning with the Qwen2 chat template format.
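As a rough sketch of that rule, the helper below prepends the marker only on follow-up turns. The function name build_turn_input is hypothetical, and plain string concatenation is an assumption about how the prefix is applied.

```python
IM_START = "<|im_start|>"  # Qwen2 chat-template turn marker


def build_turn_input(text: str, start_pos: int, chunk_prefilling: bool) -> str:
    """Prepend the turn marker when continuing from a non-empty KV cache."""
    if chunk_prefilling and start_pos != 0:
        return IM_START + text
    return text
```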

Usage

Use this generator for streaming inference with NVILA models in the nvila_demo.py interactive demo. It is selected as the stream generator and invoked in the main chat loop.

Code Reference

Source Location

Signature

@torch.inference_mode()
def NVILAStreamGenerator(
    model,
    gen_params,
    input: str,
    media=None,
    media_cfg=None,
    start_pos: int = 0,
    device: str = "cuda:0",
    stream_interval: int = 2,
    echo: bool = False,
    stop_token_ids=[],
    image_tensor: Optional[torch.FloatTensor] = None,
    chunk_prefilling: bool = False,
    quant_llm: bool = False,
):

Import

from tinychat.stream_generators.NVILA_stream_gen import NVILAStreamGenerator

I/O Contract

Inputs

| Parameter | Type | Description |
|---|---|---|
| model | nn.Module | NVILA model with .stream_gen() and .tokenizer attributes |
| gen_params | object | Generation params with attrs: temp, repeat_penalty, top_p, top_k, n_vocab, n_predict |
| input | str | Text prompt string |
| media | dict or None | Media dictionary (e.g., preprocessed image/video tensors) |
| media_cfg | dict or None | Media configuration metadata |
| start_pos | int | Starting KV cache position (default: 0) |
| device | str | CUDA device string (default: "cuda:0") |
| stream_interval | int | Yield decoded text every N tokens (default: 2) |
| echo | bool | If True, include input tokens in output |
| stop_token_ids | list[int] | Additional stop token IDs beyond EOS |
| image_tensor | Optional[torch.FloatTensor] | Legacy parameter for compatibility (unused) |
| chunk_prefilling | bool | Enable chunk prefilling for multi-turn speedup |
| quant_llm | bool | Whether the LLM backbone is quantized |
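Since gen_params is described only by the attributes it must expose, any object carrying those six attributes should satisfy the contract. The dataclass below is a hypothetical container written for illustration; the class name GenParams and every default value are placeholders, not values taken from the library.

```python
from dataclasses import dataclass


@dataclass
class GenParams:
    """Hypothetical container exposing the attributes the generator reads.
    All defaults are illustrative placeholders."""

    temp: float = 0.2            # sampling temperature
    repeat_penalty: float = 1.0  # repetition penalty factor
    top_p: float = 0.95          # nucleus sampling threshold
    top_k: int = 40              # top-k cutoff
    n_vocab: int = 32000         # vocabulary size (placeholder)
    n_predict: int = 512         # maximum number of new tokens


# Override only what differs from the placeholders.
gen_params = GenParams(temp=0.0, n_predict=128)
```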

Outputs

| Yields | Type | Description |
|---|---|---|
| result | dict | Keys: "text" (str), "usage" (dict with prompt_tokens, completion_tokens, total_tokens), "finish_reason" (None/"stop"/"length"), "timing" (dict or None with context_tokens, context_time, total_tokens, generation_time_list) |
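To make the shape of each yielded dict concrete, the sketch below assembles one and decides when to emit it. Both make_result and should_yield are hypothetical helpers, and the yield condition (every stream_interval steps, on the last allowed step, or on a stop) is an assumption about the cadence rather than the confirmed implementation.

```python
from typing import Optional


def make_result(
    text: str,
    prompt_tokens: int,
    completion_tokens: int,
    finish_reason: Optional[str] = None,  # None, "stop", or "length"
    timing: Optional[dict] = None,
) -> dict:
    """Assemble a dict with the keys listed in the output contract."""
    return {
        "text": text,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
        "finish_reason": finish_reason,
        "timing": timing,
    }


def should_yield(step: int, stream_interval: int, max_new_tokens: int, stopped: bool) -> bool:
    """Assumed cadence: every stream_interval steps, plus the final step."""
    return step % stream_interval == 0 or step == max_new_tokens - 1 or stopped
```

Consumers can rely on "finish_reason" being None for intermediate yields and should treat "timing" as optional.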

Usage Examples

from tinychat.stream_generators.NVILA_stream_gen import NVILAStreamGenerator

# Streaming generation with NVILA model
for output in NVILAStreamGenerator(
    model=nvila_model,
    gen_params=gen_params,
    input="Describe this video in detail.",
    media=media,
    media_cfg=media_cfg,
    start_pos=0,
    device="cuda:0",
    stop_token_ids=[],
    chunk_prefilling=True,
    quant_llm=True,
):
    if output["finish_reason"] is None:
        # Intermediate yield: print the decoded text.
        print(output["text"], end="", flush=True)
    else:
        # Final yield: "timing" may be None, so guard before reading it.
        timing = output["timing"] or {}
        gen_times = timing.get("generation_time_list", [])
        elapsed = sum(gen_times)
        avg_speed = len(gen_times) / elapsed if elapsed > 0 else 0.0
        print(f"\nGeneration speed: {avg_speed:.1f} tokens/s")
