Implementation: mlfoundations/open_flamingo, Flamingo.generate()
## Overview
`Flamingo.generate()` generates text conditioned on interleaved images and text; it is a method of the OpenFlamingo `Flamingo` class.
## Description
The `Flamingo.generate()` method encodes vision inputs through CLIP + Perceiver, conditions the language model's cross-attention layers based on `<image>` token positions in the text, then delegates to HuggingFace's `generate()` method for autoregressive decoding. After generation, the conditioned layers are cleared to avoid memory leaks. Beam search, sampling, and other HuggingFace generation strategies are supported via `**kwargs`.
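The control flow described above (encode, condition, delegate, clear) can be sketched with stand-in stubs. This is an illustrative mock, not the actual OpenFlamingo source: `StubLM`, `flamingo_generate`, and the mean-pooling "encoder" are all hypothetical placeholders for the real CLIP/Perceiver and cross-attention-augmented language model.

```python
import torch

class StubLM:
    """Stand-in for the cross-attention-augmented language model."""
    def __init__(self):
        self.vision_features = None  # conditioned cross-attention state

    def condition_vision_x(self, feats):
        self.vision_features = feats

    def clear_conditioned_layers(self):
        self.vision_features = None

    def generate(self, lang_x, attention_mask=None, **kwargs):
        # Pretend to decode: append max_new_tokens dummy ids to the prompt
        new = torch.zeros(lang_x.shape[0], kwargs.get("max_new_tokens", 1),
                          dtype=lang_x.dtype)
        return torch.cat([lang_x, new], dim=1)

def flamingo_generate(lm, vision_x, lang_x, attention_mask=None, **kwargs):
    # Stand-in for CLIP + Perceiver encoding of (B, T_img, F, C, H, W) input
    feats = vision_x.mean(dim=(-1, -2))
    lm.condition_vision_x(feats)          # condition cross-attention layers
    try:
        output_ids = lm.generate(lang_x, attention_mask=attention_mask, **kwargs)
    finally:
        lm.clear_conditioned_layers()     # always clear to avoid stale state
    return output_ids
```

The `try`/`finally` reflects the documented guarantee that conditioned layers are cleared after generation.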
## Usage
After preparing the `vision_x` tensor and tokenized text with `<image>` placeholders, call `model.generate()` to produce text.
## Code Reference
- Source
  - Repository: https://github.com/mlfoundations/open_flamingo
  - File: `open_flamingo/src/flamingo.py`, lines L124-L175
- Signature

```python
def generate(
    self,
    vision_x: torch.Tensor,
    lang_x: torch.Tensor,
    attention_mask: torch.Tensor = None,
    **kwargs,
) -> torch.Tensor
```

- Import

```python
from open_flamingo import create_model_and_transforms
```

(The model is created via this factory; `generate` is a method on the returned model.)
## I/O Contract

### Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| `vision_x` | `torch.Tensor` | Yes | Vision input, shape `(B, T_img, F, C, H, W)` with `F=1` for single-frame images |
| `lang_x` | `torch.Tensor` | Yes | Language input token IDs, shape `(B, T_txt)` |
| `attention_mask` | `torch.Tensor` | No | Attention mask, shape `(B, T_txt)` |
| `**kwargs` | `dict` | No | HuggingFace `generate` kwargs: `max_new_tokens`, `num_beams`, `temperature`, etc. |

### Outputs

| Name | Type | Description |
|---|---|---|
| `output_ids` | `torch.Tensor` | Token IDs: `lang_x` with generated tokens appended, shape `(B, T_txt + generated_length)` |
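Because the output echoes the prompt tokens before the generated ones, the continuation can be recovered by slicing past `T_txt`. A minimal sketch with toy tensors (no model involved; the id values are arbitrary):

```python
import torch

# Toy stand-ins for the contract above
T_txt, generated_length = 5, 3
lang_x = torch.arange(T_txt).unsqueeze(0)            # (1, T_txt) prompt ids
generated = torch.full((1, generated_length), 99)    # pretend generated ids
output_ids = torch.cat([lang_x, generated], dim=1)   # (1, T_txt + generated_length)

# Prompt is echoed at the front; slice it off to get only new tokens
continuation = output_ids[:, lang_x.shape[1]:]
```

This is the same slicing pattern used in the usage example below before decoding.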
## Usage Examples
The following example demonstrates few-shot image captioning with 2 demonstration images and 1 query image:

```python
from open_flamingo import create_model_and_transforms
from PIL import Image
import torch

# Create model and get image processor / tokenizer
model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b",
    tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b",
    cross_attn_every_n_layers=1,
)

# Load and preprocess demo images and the query image: each (1, C, H, W)
demo_image_1 = image_processor(Image.open("demo1.jpg")).unsqueeze(0)
demo_image_2 = image_processor(Image.open("demo2.jpg")).unsqueeze(0)
query_image = image_processor(Image.open("query.jpg")).unsqueeze(0)

# Stack into vision_x: shape (B, T_img, F, C, H, W)
# T_img = 3 images, F = 1 frame each
vision_x = torch.stack([demo_image_1, demo_image_2, query_image], dim=1).unsqueeze(2)

# Tokenize the interleaved prompt with <image> placeholders
prompt = (
    "<image>A cat sitting on a windowsill.<|endofchunk|>"
    "<image>A dog playing in the park.<|endofchunk|>"
    "<image>"
)
lang_x = tokenizer(prompt, return_tensors="pt")

# Generate a caption for the query image
output_ids = model.generate(
    vision_x=vision_x,
    lang_x=lang_x["input_ids"],
    attention_mask=lang_x["attention_mask"],
    max_new_tokens=20,
    num_beams=3,
)

# Decode only the newly generated tokens
generated_text = tokenizer.decode(output_ids[0, lang_x["input_ids"].shape[1]:])
print(generated_text)
# Example output: "A bird perched on a tree branch."
```