
Implementation:DeepSeek-AI Janus AR Token Generation Loop



Knowledge Sources

  • Domains: Image_Generation, Autoregressive_Models
  • Last Updated: 2026-02-10 09:30 GMT

Overview

A pattern for the autoregressive VQ token generation loop with classifier-free guidance (CFG), built on the Janus model's gen_head and prepare_gen_img_embeds methods.
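
At each step the two logit streams are blended with the standard CFG rule, which is exactly the line applied in the loop below:

logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)

With cfg_weight = 1 this reduces to plain conditional sampling; larger values extrapolate away from the unconditional distribution, trading sample diversity for prompt adherence.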

Description

This is a user-defined pattern that implements the token-by-token generation loop for autoregressive image generation. It uses three model components:

  • language_model.model(): LLM backbone forward pass (without the LM head)
  • gen_head(): Projects hidden states to VQ codebook logits (Linear→GELU→Linear at modeling_vlm.py:L36-51)
  • prepare_gen_img_embeds(): Converts sampled VQ indices to embeddings for the next step (gen_aligner(gen_embed(ids)) at modeling_vlm.py:L262-263)
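
As a rough sketch of the two generation-side projections described above: the module structure follows the descriptions, but the hidden size, codebook size, and embedding dimension are illustrative assumptions, and a plain Linear stands in for the actual gen_aligner module.

import torch
import torch.nn as nn

hidden_dim = 2048        # LLM hidden size — illustrative assumption
codebook_size = 16384    # VQ codebook size — illustrative assumption
embed_dim = 8            # codebook embedding dim — illustrative assumption

# gen_head: last hidden state -> VQ codebook logits (Linear -> GELU -> Linear)
gen_head = nn.Sequential(
    nn.Linear(hidden_dim, hidden_dim),
    nn.GELU(),
    nn.Linear(hidden_dim, codebook_size),
)

# prepare_gen_img_embeds: sampled VQ ids -> next-step input embeddings,
# i.e. gen_aligner(gen_embed(ids)) per modeling_vlm.py:L262-263
gen_embed = nn.Embedding(codebook_size, embed_dim)
gen_aligner = nn.Linear(embed_dim, hidden_dim)  # stand-in for the real aligner

def prepare_gen_img_embeds(ids: torch.LongTensor) -> torch.Tensor:
    return gen_aligner(gen_embed(ids))

ids = torch.randint(0, codebook_size, (4,))
print(prepare_gen_img_embeds(ids).shape)  # torch.Size([4, 2048])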

Usage

Implement this pattern after CFG input preparation. The loop runs for image_token_num_per_image steps (default 576), producing a tensor of VQ codebook indices.
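
The default of 576 follows from the image geometry used in the reference script (img_size=384, patch_size=16): the VQ tokenizer produces one token per patch.

img_size, patch_size = 384, 16
image_token_num_per_image = (img_size // patch_size) ** 2  # 24 * 24 = 576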

Code Reference

Source Location

  • Repository: Janus
  • File: generation_inference.py
  • Lines: L79-95 (generation loop)
  • Supporting: janus/models/modeling_vlm.py:L36-51 (vision_head/gen_head), L262-263 (prepare_gen_img_embeds)

Pattern Implementation

# Buffer for the sampled VQ codebook indices (one row per image)
generated_tokens = torch.zeros(
    (parallel_size, image_token_num_per_image), dtype=torch.int
).cuda()

for i in range(image_token_num_per_image):
    # 1. LLM forward pass (backbone only, no LM head). On step 0 there is
    #    no KV cache yet; the ternary short-circuits to None, so `outputs`
    #    is never read before it is first assigned.
    outputs = mmgpt.language_model.model(
        inputs_embeds=inputs_embeds,
        use_cache=True,
        past_key_values=outputs.past_key_values if i != 0 else None,
    )
    hidden_states = outputs.last_hidden_state

    # 2. Project to VQ codebook logits via gen_head
    logits = mmgpt.gen_head(hidden_states[:, -1, :])
    logit_cond = logits[0::2, :]     # even rows: conditional
    logit_uncond = logits[1::2, :]   # odd rows: unconditional

    # 3. Apply CFG
    logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
    probs = torch.softmax(logits / temperature, dim=-1)

    # 4. Sample next token
    next_token = torch.multinomial(probs, num_samples=1)
    generated_tokens[:, i] = next_token.squeeze(dim=-1)

    # 5. Prepare embedding for next step: duplicate each sampled token so
    #    the interleaved cond/uncond rows stay paired ([parallel_size*2] ids)
    next_token = torch.cat(
        [next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1
    ).view(-1)
    img_embeds = mmgpt.prepare_gen_img_embeds(next_token)
    inputs_embeds = img_embeds.unsqueeze(dim=1)

Import

import torch  # required for the tensor ops in the loop

# The model components are instance methods/attributes — no separate import:
# mmgpt.language_model.model(...)
# mmgpt.gen_head(...)
# mmgpt.prepare_gen_img_embeds(...)

I/O Contract

Inputs

  • inputs_embeds — torch.Tensor [parallel_size*2, seq_len, D], required. CFG-paired prompt embeddings from CFG preparation.
  • temperature — float, optional. Sampling temperature (default 1.0).
  • cfg_weight — float, optional. CFG guidance scale (default 5.0).
  • image_token_num_per_image — int, optional. Number of VQ tokens per image (default 576).

Outputs

  • generated_tokens — torch.IntTensor [parallel_size, image_token_num_per_image]. VQ codebook indices for each generated image. Note the buffer is allocated with dtype=torch.int (int32), not int64.
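
Downstream, the reference script decodes these indices back to pixels with the VQ decoder. A hedged sketch of that step follows; the decode_code call and the 8-channel latent shape are assumptions based on generation_inference.py, not verified against the current Janus source.

import numpy as np
import torch

# Decode VQ indices to images (assumed signature per the reference script)
dec = mmgpt.gen_vision_model.decode_code(
    generated_tokens.to(dtype=torch.int),
    shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size],
)
# dec: [parallel_size, 3, img_size, img_size], roughly in [-1, 1]; map to uint8
imgs = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
imgs = np.clip((imgs + 1) / 2 * 255, 0, 255).astype(np.uint8)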

Usage Examples

Full Generation Loop

import torch

@torch.inference_mode()
def generate(mmgpt, vl_chat_processor, prompt,
             temperature=1.0, parallel_size=16, cfg_weight=5,
             image_token_num_per_image=576, img_size=384, patch_size=16):

    input_ids = vl_chat_processor.tokenizer.encode(prompt)
    input_ids = torch.LongTensor(input_ids)

    # CFG setup: even rows keep the real prompt (conditional), odd rows are
    # masked to pad_id (unconditional), giving an interleaved cond/uncond batch
    tokens = torch.zeros((parallel_size*2, len(input_ids)), dtype=torch.int).cuda()
    for i in range(parallel_size*2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = vl_chat_processor.pad_id
    inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens)

    # Autoregressive generation loop
    generated_tokens = torch.zeros(
        (parallel_size, image_token_num_per_image), dtype=torch.int
    ).cuda()

    for i in range(image_token_num_per_image):
        # Backbone forward pass (KV cache reused after the first step)
        outputs = mmgpt.language_model.model(
            inputs_embeds=inputs_embeds, use_cache=True,
            past_key_values=outputs.past_key_values if i != 0 else None
        )
        hidden_states = outputs.last_hidden_state
        # Project to VQ logits and apply CFG
        logits = mmgpt.gen_head(hidden_states[:, -1, :])
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]
        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
        probs = torch.softmax(logits / temperature, dim=-1)
        # Sample and record the next VQ token
        next_token = torch.multinomial(probs, num_samples=1)
        generated_tokens[:, i] = next_token.squeeze(dim=-1)
        # Duplicate for the cond/uncond rows and embed for the next step
        next_token = torch.cat(
            [next_token.unsqueeze(1), next_token.unsqueeze(1)], dim=1
        ).view(-1)
        img_embeds = mmgpt.prepare_gen_img_embeds(next_token)
        inputs_embeds = img_embeds.unsqueeze(dim=1)

    return generated_tokens
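
A minimal invocation sketch follows. The import path and checkpoint name in the commented lines are assumptions taken from the usual Janus quickstart, not from this page.

# from janus.models import MultiModalityCausalLM, VLChatProcessor  # assumed
# vl_chat_processor = VLChatProcessor.from_pretrained("deepseek-ai/Janus-1.3B")
# mmgpt = MultiModalityCausalLM.from_pretrained("deepseek-ai/Janus-1.3B").cuda().eval()

tokens = generate(mmgpt, vl_chat_processor, prompt="a photo of a red fox in snow")
print(tokens.shape)  # torch.Size([16, 576]) with the defaults above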

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
