Implementation: DeepSeek-AI Janus CFG Input Preparation (Autoregressive)
| Knowledge Sources | |
|---|---|
| Domains | Image_Generation, Guided_Generation |
| Last Updated | 2026-02-10 09:30 GMT |
Overview
Pattern for constructing paired conditional/unconditional input embeddings for classifier-free guidance in autoregressive image generation.
Description
This is a user-defined pattern (not a library API) that sets up the CFG input structure for the autoregressive image-generation loop. It duplicates the tokenized prompt into parallel_size × 2 rows, masks the content tokens of each odd row with the pad ID (keeping only the first and last tokens, so those rows serve as the unconditional branch), and converts all rows to embeddings via the language model's embedding layer. Even rows carry the full conditional prompt; odd rows are the pad-masked unconditional counterpart.
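The row-interleaving can be sketched without any model dependencies. This pure-Python mock (list-based, with made-up token IDs and pad_id) shows exactly which positions each row keeps:

```python
def build_cfg_rows(input_ids, parallel_size, pad_id):
    """Duplicate a token sequence into parallel_size * 2 rows.

    Even rows keep the full prompt (conditional); odd rows replace
    everything except the first and last token with pad_id
    (unconditional), mirroring tokens[i, 1:-1] = pad_id.
    """
    rows = []
    for i in range(parallel_size * 2):
        row = list(input_ids)
        if i % 2 != 0:
            row[1:-1] = [pad_id] * (len(row) - 2)
        rows.append(row)
    return rows

rows = build_cfg_rows([101, 7, 8, 9, 202], parallel_size=2, pad_id=0)
# rows[0] == [101, 7, 8, 9, 202]   (conditional: full prompt)
# rows[1] == [101, 0, 0, 0, 202]   (unconditional: ends kept, middle padded)
```

The real implementation does the same thing on a CUDA tensor; the structure of even/odd rows is identical.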
Usage
Implement this pattern after formatting and tokenizing the generation prompt, and before entering the autoregressive token generation loop.
Code Reference
Source Location
- Repository: Janus
- File: generation_inference.py
- Lines: L66-75
Pattern Implementation
# Reference pattern from generation_inference.py
input_ids = vl_chat_processor.tokenizer.encode(prompt)
input_ids = torch.LongTensor(input_ids)

# Duplicate for CFG: parallel_size * 2 rows
tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).cuda()
for i in range(parallel_size * 2):
    tokens[i, :] = input_ids
    if i % 2 != 0:
        # Odd rows: mask content tokens with pad_id (unconditional)
        tokens[i, 1:-1] = vl_chat_processor.pad_id

# Convert to embeddings
inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens)
Import
# Requires `import torch`; otherwise uses existing model and processor instances
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| prompt | str | Yes | Formatted prompt string with image_start_tag appended |
| parallel_size | int | Yes | Number of images to generate (e.g., 16) |
| pad_id | int | Yes | Pad token ID from vl_chat_processor.pad_id |
Outputs
| Name | Type | Description |
|---|---|---|
| inputs_embeds | torch.Tensor [parallel_size*2, seq_len, D] | Paired conditional/unconditional embeddings |
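Downstream, the decode loop consumes the paired rows by splitting each step's logits into conditional (even) and unconditional (odd) halves and blending them. This is a minimal sketch of the standard CFG combination on plain Python lists; the parameter name cfg_weight is an assumption for illustration, not defined by this pattern:

```python
def cfg_combine(logits, cfg_weight):
    """Combine interleaved conditional/unconditional logits.

    logits: per-row logit lists, alternating [cond, uncond, cond, ...]
    as produced by the paired embeddings. Returns one combined logit
    list per image using uncond + cfg_weight * (cond - uncond).
    """
    combined = []
    for i in range(0, len(logits), 2):
        cond, uncond = logits[i], logits[i + 1]
        combined.append([u + cfg_weight * (c - u) for c, u in zip(cond, uncond)])
    return combined

out = cfg_combine([[2.0, 0.0], [1.0, 1.0]], cfg_weight=5.0)
# out == [[6.0, -4.0]]  since 1 + 5*(2-1) = 6 and 1 + 5*(0-1) = -4
```

With cfg_weight = 1.0 the unconditional branch cancels out and only the conditional logits remain.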
Usage Examples
Standard CFG Setup
parallel_size = 16
input_ids = vl_chat_processor.tokenizer.encode(prompt)
input_ids = torch.LongTensor(input_ids)
tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).cuda()
for i in range(parallel_size * 2):
    tokens[i, :] = input_ids
    if i % 2 != 0:
        tokens[i, 1:-1] = vl_chat_processor.pad_id
inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
# Shape: [32, seq_len, 2048]
# Even rows: full prompt embeddings
# Odd rows: pad-masked embeddings
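The embedding step itself is a per-token table lookup. This mock (a hypothetical 4-entry table with D = 2, standing in for get_input_embeddings()) shows why conditional and unconditional rows agree at the boundary positions and diverge only where pad masking applied:

```python
def embed(tokens, table):
    """Mock of get_input_embeddings(): per-token lookup into an
    embedding table, producing shape [rows, seq_len, D]."""
    return [[table[t] for t in row] for row in tokens]

# Tiny hypothetical table: 4 token IDs, embedding dimension D = 2
table = {0: [0.0, 0.0], 1: [1.0, 0.5], 2: [0.2, 0.2], 3: [0.9, 0.1]}

# One conditional row and its pad-masked counterpart (pad_id = 0)
tokens = [[1, 2, 3], [1, 0, 3]]
emb = embed(tokens, table)
# emb[0] and emb[1] match at positions 0 and -1, differ in the middle
```

In the real pattern the same lookup runs over all parallel_size * 2 rows at once, yielding the [parallel_size*2, seq_len, D] tensor described above.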