
Implementation:Haotian Liu LLaVA Model Generate Multimodal

From Leeroopedia

Overview

A concrete tool for generating text responses from multimodal inputs using LLaVA's fused vision-language model. It combines visual-embedding injection with autoregressive text decoding.

Sources

  • File: llava/eval/run_llava.py, Lines: L114-128 (generate call)
  • File: llava/model/llava_arch.py, Lines: L145-324 (prepare_inputs_labels_for_multimodal)

Signature

model.generate()

# Called on a LlavaLlamaForCausalLM instance:
output_ids = model.generate(
    input_ids: torch.Tensor,            # Tokenized prompt with IMAGE_TOKEN_INDEX
    images: torch.Tensor,               # Preprocessed image tensor
    image_sizes: List[Tuple[int, int]], # Original image dimensions
    do_sample: bool,                     # Whether to use sampling (True if temp > 0)
    temperature: float,                  # Sampling temperature
    top_p: float,                        # Nucleus sampling threshold
    num_beams: int,                      # Number of beams for beam search
    max_new_tokens: int,                 # Maximum tokens to generate
    use_cache: bool = True,              # Use KV cache for efficient generation
) -> torch.Tensor

prepare_inputs_labels_for_multimodal() (internal)

def prepare_inputs_labels_for_multimodal(
    self,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    past_key_values,
    labels: torch.Tensor,
    images: torch.Tensor,
    image_sizes: Optional[List[Tuple[int, int]]] = None
) -> Tuple[None, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Replace IMAGE_TOKEN_INDEX in input_ids with actual visual embeddings.

    Returns:
        Tuple of (None, position_ids, attention_mask, past_key_values, input_embeds, labels)
        where input_embeds contains the fused visual-text embeddings.
    """
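The splice performed inside prepare_inputs_labels_for_multimodal() can be sketched in simplified form. This is a minimal illustration, not the actual implementation: `splice_visual_embeddings` and `text_embed_fn` are hypothetical names, and batching, padding, and label handling are omitted.

```python
import torch

IMAGE_TOKEN_INDEX = -200  # sentinel LLaVA inserts for the <image> placeholder

def splice_visual_embeddings(input_ids, text_embed_fn, image_features):
    """Simplified sketch: split a 1-D input_ids tensor at the image sentinel
    and splice the projected visual features between the text segments."""
    segments, cur = [], []
    for tok in input_ids.tolist():
        if tok == IMAGE_TOKEN_INDEX:
            segments.append(cur)
            cur = []
        else:
            cur.append(tok)
    segments.append(cur)

    pieces = []
    for i, seg in enumerate(segments):
        if seg:
            pieces.append(text_embed_fn(torch.tensor(seg)))
        if i < len(segments) - 1:
            pieces.append(image_features)  # one image inserted per sentinel
    return torch.cat(pieces, dim=0)
```

The returned tensor is what the LLM actually attends over: the single `-200` placeholder has been replaced by hundreds of visual embedding rows.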

Import

# Model loaded via load_pretrained_model; generate is a method on the model instance
from llava.model.builder import load_pretrained_model

tokenizer, model, image_processor, context_len = load_pretrained_model(...)
output_ids = model.generate(...)

Inputs

| Parameter      | Type                  | Required   | Description                                                            |
|----------------|-----------------------|------------|------------------------------------------------------------------------|
| input_ids      | torch.Tensor          | Yes        | Tokenized prompt containing IMAGE_TOKEN_INDEX (-200) at image positions |
| images         | torch.Tensor          | Yes        | Preprocessed image tensor from process_images()                         |
| image_sizes    | List[Tuple[int,int]]  | For anyres | Original image dimensions (width, height)                               |
| do_sample      | bool                  | Yes        | True for temperature sampling, False for greedy                         |
| temperature    | float                 | Yes        | Sampling temperature (0.0 for greedy)                                   |
| top_p          | float                 | No         | Nucleus sampling threshold (default: None)                              |
| num_beams      | int                   | No         | Beam search width (default: 1)                                          |
| max_new_tokens | int                   | Yes        | Maximum number of tokens to generate                                    |
| use_cache      | bool                  | No         | Use KV cache (default: True)                                            |
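The `do_sample` flag is conventionally derived from the temperature setting, as in the evaluation script. A minimal sketch (`decoding_flags` is a hypothetical helper name):

```python
def decoding_flags(temperature: float) -> dict:
    """Map a temperature setting to generate() keyword arguments,
    following the convention in llava/eval/run_llava.py: greedy
    decoding when temperature is 0, sampling otherwise."""
    return {"do_sample": temperature > 0, "temperature": temperature}
```

For example, `decoding_flags(0.0)` selects greedy decoding, while `decoding_flags(0.2)` enables sampling at that temperature.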

Outputs

| Output     | Type         | Description                                  |
|------------|--------------|----------------------------------------------|
| output_ids | torch.Tensor | Generated token IDs (shape: [batch, seq_len]) |

The output is decoded to text via:

outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()

Usage Example

from llava.model.builder import load_pretrained_model
from llava.mm_utils import process_images, tokenizer_image_token
from llava.conversation import conv_templates
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from PIL import Image
import torch

# 1. Load model
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="liuhaotian/llava-v1.5-13b",
    model_base=None,
    model_name="llava-v1.5-13b"
)

# 2. Preprocess image
image = Image.open("photo.jpg").convert("RGB")
image_tensor = process_images([image], image_processor, model.config)
image_tensor = image_tensor.to(model.device, dtype=torch.float16)

# 3. Construct prompt
conv = conv_templates["llava_v1"].copy()
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nDescribe this image.")
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

# 4. Tokenize
input_ids = tokenizer_image_token(
    prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
).unsqueeze(0).cuda()

# 5. Generate
with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        image_sizes=[image.size],
        do_sample=True,
        temperature=0.2,
        top_p=None,
        num_beams=1,
        max_new_tokens=512,
        use_cache=True,
    )

# 6. Decode
output_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
print(output_text)

Description

model.generate() with the images parameter triggers the full multimodal generation pipeline:

Internal Flow

  1. Entry point -- The model's overridden generate() / prepare_inputs_for_generation() detects the images argument and calls prepare_inputs_labels_for_multimodal().
  2. Visual encoding -- Images pass through the CLIP vision tower (self.get_model().get_vision_tower()(images)) producing patch features.
  3. Projection -- Patch features are projected via self.get_model().mm_projector(image_features) into the LLM embedding space.
  4. Embedding fusion -- input_ids is split at IMAGE_TOKEN_INDEX positions. Text segments are embedded via the LLM's embedding layer. Visual embeddings are inserted between text segments.
  5. Sequence construction -- The fused input_embeds tensor, along with updated attention_mask and position_ids, is passed to the LLM for autoregressive generation.
  6. Token generation -- The LLM generates tokens autoregressively until max_new_tokens is reached or an EOS token is produced.
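Steps 2-3 above can be illustrated with stand-in modules. The dimensions are assumptions for illustration (CLIP ViT-L/14 patch features are 1024-d; LLaMA-13B's hidden size is 5120), and the untrained nn.Sequential below only mirrors the projector's shape (LLaVA-1.5 uses a two-layer MLP with GELU), not its learned weights.

```python
import torch
from torch import nn

vision_dim, llm_dim, num_patches = 1024, 5120, 576  # assumed dimensions

# Stand-in for mm_projector: a two-layer MLP mapping vision features
# into the LLM embedding space.
mm_projector = nn.Sequential(
    nn.Linear(vision_dim, llm_dim),
    nn.GELU(),
    nn.Linear(llm_dim, llm_dim),
)

patch_features = torch.randn(1, num_patches, vision_dim)  # vision-tower output
image_features = mm_projector(patch_features)             # projected into LLM space
```

After projection, `image_features` has the LLM's hidden size per token and can be spliced directly into the text embedding sequence (step 4).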

Embedding Expansion

Each <image> token expands to 576 visual tokens (for 336x336 input with 14x14 patch size, yielding a 24x24 grid). For anyres mode, the number of visual tokens scales with the number of patches plus the global view.
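The token count follows directly from the patch geometry:

```python
image_size, patch_size = 336, 14       # CLIP ViT-L/14 at 336px input
grid = image_size // patch_size        # patches per side
num_visual_tokens = grid * grid        # visual tokens per <image> placeholder
```

This yields a 24x24 grid, i.e. 576 visual tokens per image.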

Metadata

| Field             | Value                                                                  |
|-------------------|------------------------------------------------------------------------|
| Knowledge Sources | Paper - Visual Instruction Tuning - https://arxiv.org/abs/2304.08485    |
| Domains           | Multimodal_Inference, Text_Generation                                   |
| Last Updated      | 2026-02-13 14:00 GMT                                                    |
