Implementation: DeepSeek AI Janus AR Token Generation Loop
| Knowledge Sources | |
|---|---|
| Domains | Image_Generation, Autoregressive_Models |
| Last Updated | 2026-02-10 09:30 GMT |
Overview
Pattern for the autoregressive VQ token generation loop with classifier-free guidance, using the Janus model's gen_head and prepare_gen_img_embeds methods.
Description
This is a user-defined pattern that implements the token-by-token generation loop for autoregressive image generation. It uses three model components:
- language_model.model(): LLM backbone forward pass (without the LM head)
- gen_head(): Projects hidden states to VQ codebook logits (Linear→GELU→Linear at modeling_vlm.py:L36-51)
- prepare_gen_img_embeds(): Converts sampled VQ indices to embeddings for the next step (gen_aligner(gen_embed(ids)) at modeling_vlm.py:L262-263)
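The three components compose into a per-step cycle whose tensor shapes can be tracked with plain tuples. A minimal bookkeeping sketch (the sizes below are illustrative stand-ins, not values from the Janus config):

```python
# Illustrative sizes only: parallel_size, hidden dim D, VQ codebook size.
parallel_size, D, codebook_size = 4, 2048, 16384
batch = parallel_size * 2  # CFG doubles the batch: even rows cond, odd rows uncond

hidden_states = (batch, 1, D)            # language_model.model() output (seq_len 1 after step 0)
logits = (batch, codebook_size)          # gen_head() applied to the last position
guided = (parallel_size, codebook_size)  # cond/uncond rows merged by the CFG formula
next_token = (parallel_size, 1)          # one sampled VQ index per image
paired_token = (parallel_size * 2,)      # index duplicated for both CFG streams
next_embeds = (batch, 1, D)              # prepare_gen_img_embeds() output, next step's input

print(guided, next_embeds)
```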
Usage
Implement this pattern after CFG input preparation. The loop runs for image_token_num_per_image steps (default 576), producing a tensor of VQ codebook indices.
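The default of 576 follows from the image resolution and VQ patch size used in the reference `generate` signature (384 and 16), since each patch maps to one VQ token:

```python
img_size, patch_size = 384, 16
image_token_num_per_image = (img_size // patch_size) ** 2
print(image_token_num_per_image)  # 576 tokens: a 24x24 grid of patches
```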
Code Reference
Source Location
- Repository: Janus
- File: generation_inference.py
- Lines: L79-95 (generation loop)
- Supporting: janus/models/modeling_vlm.py:L36-51 (vision_head/gen_head), L262-263 (prepare_gen_img_embeds)
Pattern Implementation
```python
generated_tokens = torch.zeros(
    (parallel_size, image_token_num_per_image), dtype=torch.int
).cuda()

for i in range(image_token_num_per_image):
    # 1. LLM forward pass (backbone only, no LM head)
    outputs = mmgpt.language_model.model(
        inputs_embeds=inputs_embeds,
        use_cache=True,
        past_key_values=outputs.past_key_values if i != 0 else None,
    )
    hidden_states = outputs.last_hidden_state

    # 2. Project to VQ codebook logits via gen_head
    logits = mmgpt.gen_head(hidden_states[:, -1, :])
    logit_cond = logits[0::2, :]    # even rows: conditional
    logit_uncond = logits[1::2, :]  # odd rows: unconditional

    # 3. Apply CFG
    logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
    probs = torch.softmax(logits / temperature, dim=-1)

    # 4. Sample next token
    next_token = torch.multinomial(probs, num_samples=1)
    generated_tokens[:, i] = next_token.squeeze(dim=-1)

    # 5. Prepare embedding for next step (token duplicated for cond/uncond streams)
    next_token = torch.cat(
        [next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1
    ).view(-1)
    img_embeds = mmgpt.prepare_gen_img_embeds(next_token)
    inputs_embeds = img_embeds.unsqueeze(dim=1)
```
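The CFG step (3) can be illustrated without torch: the even/odd row interleaving and the guidance formula behave the same on plain lists. The logit values below are made up for demonstration:

```python
# Toy logits for batch = parallel_size*2 = 4 rows over a 3-entry codebook.
# Even rows are conditional, odd rows unconditional, matching the pattern above.
logits = [
    [2.0, 0.0, 1.0],  # row 0: conditional, image 0
    [1.0, 0.0, 1.0],  # row 1: unconditional, image 0
    [0.5, 3.0, 0.0],  # row 2: conditional, image 1
    [0.0, 1.0, 0.0],  # row 3: unconditional, image 1
]
cfg_weight = 5.0

logit_cond = logits[0::2]
logit_uncond = logits[1::2]

# uncond + w * (cond - uncond): amplifies the direction the prompt pulls in.
guided = [
    [u + cfg_weight * (c - u) for c, u in zip(cond_row, uncond_row)]
    for cond_row, uncond_row in zip(logit_cond, logit_uncond)
]
print(guided)  # [[6.0, 0.0, 1.0], [2.5, 11.0, 0.0]]
```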
Import
```python
# No separate import is needed; the pattern uses model instance methods:
#   mmgpt.language_model.model(...)
#   mmgpt.gen_head(...)
#   mmgpt.prepare_gen_img_embeds(...)
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| inputs_embeds | torch.Tensor [parallel_size*2, seq_len, D] | Yes | CFG-paired prompt embeddings from CFG preparation |
| temperature | float | No | Sampling temperature (default 1.0) |
| cfg_weight | float | No | CFG guidance scale (default 5.0) |
| image_token_num_per_image | int | No | Number of VQ tokens per image (default 576) |
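`temperature` divides the guided logits before the softmax; a pure-Python sketch with toy numbers shows how lower temperature sharpens the sampling distribution:

```python
import math

def softmax(logits, temperature):
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.5]
hot = softmax(logits, temperature=2.0)   # flatter: more diverse samples
cold = softmax(logits, temperature=0.5)  # sharper: closer to greedy decoding
print(max(cold) > max(hot))  # True: lower temperature concentrates probability mass
```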
Outputs
| Name | Type | Description |
|---|---|---|
| generated_tokens | torch.LongTensor [parallel_size, image_token_num_per_image] | VQ codebook indices for each generated image |
Usage Examples
Full Generation Loop
```python
@torch.inference_mode()
def generate(mmgpt, vl_chat_processor, prompt,
             temperature=1.0, parallel_size=16, cfg_weight=5,
             image_token_num_per_image=576, img_size=384, patch_size=16):
    input_ids = vl_chat_processor.tokenizer.encode(prompt)
    input_ids = torch.LongTensor(input_ids)

    # CFG setup: even rows keep the prompt, odd rows are padded (unconditional)
    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).cuda()
    for i in range(parallel_size * 2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = vl_chat_processor.pad_id
    inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens)

    # Autoregressive generation loop
    generated_tokens = torch.zeros(
        (parallel_size, image_token_num_per_image), dtype=torch.int
    ).cuda()
    for i in range(image_token_num_per_image):
        outputs = mmgpt.language_model.model(
            inputs_embeds=inputs_embeds, use_cache=True,
            past_key_values=outputs.past_key_values if i != 0 else None
        )
        hidden_states = outputs.last_hidden_state
        logits = mmgpt.gen_head(hidden_states[:, -1, :])
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]
        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
        probs = torch.softmax(logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        generated_tokens[:, i] = next_token.squeeze(dim=-1)
        next_token = torch.cat(
            [next_token.unsqueeze(1), next_token.unsqueeze(1)], dim=1
        ).view(-1)
        img_embeds = mmgpt.prepare_gen_img_embeds(next_token)
        inputs_embeds = img_embeds.unsqueeze(dim=1)
    return generated_tokens
```
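The control flow of the loop can be exercised end to end without torch or a GPU by swapping in toy stand-ins. Everything below (the fake `gen_head`, the codebook and batch sizes, sampling via `random.choices`) is illustrative scaffolding, not the Janus API:

```python
import math
import random

CODEBOOK, PARALLEL, STEPS = 8, 2, 5  # toy sizes, not Janus defaults

def fake_gen_head(batch_size):
    """Stand-in for mmgpt.gen_head: random logits per batch row."""
    return [[random.random() for _ in range(CODEBOOK)] for _ in range(batch_size)]

def toy_generate(cfg_weight=5.0, temperature=1.0, seed=0):
    random.seed(seed)
    generated = [[None] * STEPS for _ in range(PARALLEL)]
    for step in range(STEPS):
        logits = fake_gen_head(PARALLEL * 2)  # rows interleave cond/uncond
        cond, uncond = logits[0::2], logits[1::2]
        for img in range(PARALLEL):
            # CFG merge, then softmax with temperature, then multinomial sampling
            guided = [u + cfg_weight * (c - u)
                      for c, u in zip(cond[img], uncond[img])]
            m = max(guided)
            probs = [math.exp((g - m) / temperature) for g in guided]
            token = random.choices(range(CODEBOOK), weights=probs, k=1)[0]
            generated[img][step] = token
    return generated

tokens = toy_generate()
print(len(tokens), len(tokens[0]))  # 2 5: one row of STEPS indices per image
```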
Related Pages
Implements Principle
Requires Environment
Uses Heuristic