Implementation: DeepSeek-AI Janus ODE Denoising Loop
| Knowledge Sources | |
|---|---|
| Domains | Image_Generation, Diffusion_Models |
| Last Updated | 2026-02-10 09:30 GMT |
Overview
Composite pattern and API calls implementing the rectified flow ODE denoising loop, using ShallowUViTEncoder, ShallowUViTDecoder, linear aligners, and the LLM backbone.
Description
This implementation combines several model components in an iterative loop:
- ShallowUViTEncoder.forward() — encodes noisy latent + timestep (janus/janusflow/models/uvit.py:L628-641)
- vision_gen_enc_aligner — Linear 768→2048 (janus/janusflow/models/modeling_vlm.py:L154)
- language_model.model() — LLM backbone forward pass with KV-caching
- vision_gen_dec_aligner_norm — LlamaRMSNorm (modeling_vlm.py:L166-168)
- vision_gen_dec_aligner — Linear 2048→768 (modeling_vlm.py:L169)
- ShallowUViTDecoder.forward() — decodes to velocity prediction (uvit.py:L702-714)
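The two aligners bridge the ShallowUViT channel width (768) and the LLM hidden size (2048). A minimal sanity check of the dimension round trip, with the token count inferred from the reshape in the loop below (48x48 latent to 24x24 patches, i.e. a factor-2 patchify in the encoder; the patchify factor is an inference from the code, not a documented constant):

```python
# Dimension round trip per denoising step (widths from the aligner definitions,
# token count from the reshape in the loop; factor-2 patchify is inferred).
uvit_dim = 768     # ShallowUViT channel width
llm_dim = 2048     # LLM hidden size
latent_hw = 48     # spatial size of z
patch = 2          # inferred: 48 -> 24 tokens per side

tokens_per_side = latent_hw // patch
num_image_tokens = tokens_per_side ** 2   # matches hidden_states[:, -576:, :]

# enc_aligner -> LLM backbone -> dec_aligner
path = [uvit_dim, llm_dim, llm_dim, uvit_dim]
print(num_image_tokens)  # 576
```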
Usage
Run this loop after noise initialization. At step 0 the full prompt is processed and its key/value tensors are cached; later steps reuse the cached prompt KVs instead of recomputing attention over the prompt tokens.
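The timestep schedule is linear: step k maps to t = k/N * 1000, and the Euler step size is dt = 1/N in the unit-time parameterization (the demo builds dt as a tensor before the loop; the snippet below is a plain-Python sketch of that schedule, with N = 30 taken from the default in the I/O contract):

```python
# Plain-Python sketch of the linear rectified-flow schedule used in the loop.
num_inference_steps = 30            # default from the I/O contract
dt = 1.0 / num_inference_steps      # Euler step size in unit time

# t values fed to the encoder at each step, scaled to [0, 1000)
ts = [step / num_inference_steps * 1000.0 for step in range(num_inference_steps)]
print(ts[0], ts[1], dt)  # 0.0, then ~33.33, with dt ~0.0333
```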
Code Reference
Source Location
- Repository: Janus
- File: demo/app_janusflow.py
- Lines: L95-131 (loop logic)
- Supporting: janus/janusflow/models/uvit.py:L572-714 (ShallowUViTEncoder/Decoder)
- Supporting: janus/janusflow/models/modeling_vlm.py:L154-169 (aligners)
Pattern Implementation
for step in range(num_inference_steps):
    z_input = torch.cat([z, z], dim=0)  # Duplicate for CFG (cond + uncond)
    t = step / num_inference_steps * 1000.
    t = torch.tensor([t] * z_input.shape[0]).to(dt)

    # 1. Encode latent with ShallowUViTEncoder
    z_enc = vl_gpt.vision_gen_enc_model(z_input, t)
    z_emb, t_emb, hs = z_enc[0], z_enc[1], z_enc[2]

    # 2. Reshape and align to LLM dimension
    z_emb = z_emb.view(z_emb.shape[0], z_emb.shape[1], -1).permute(0, 2, 1)
    z_emb = vl_gpt.vision_gen_enc_aligner(z_emb)

    # 3. Concatenate with text embeddings
    llm_emb = torch.cat([inputs_embeds, t_emb.unsqueeze(1), z_emb], dim=1)

    # 4. LLM forward with KV-cache
    if step == 0:
        outputs = vl_gpt.language_model.model(
            inputs_embeds=llm_emb, use_cache=True,
            attention_mask=attention_mask, past_key_values=None
        )
        # Cache only prompt KVs for reuse
        past_key_values = []
        for kv_cache in outputs.past_key_values:
            k, v = kv_cache[0], kv_cache[1]
            past_key_values.append((
                k[:, :, :inputs_embeds.shape[1], :],
                v[:, :, :inputs_embeds.shape[1], :]
            ))
        past_key_values = tuple(past_key_values)
    else:
        outputs = vl_gpt.language_model.model(
            inputs_embeds=llm_emb, use_cache=True,
            attention_mask=attention_mask, past_key_values=past_key_values
        )
    hidden_states = outputs.last_hidden_state

    # 5. Align back from LLM dimension and decode velocity
    hidden_states = vl_gpt.vision_gen_dec_aligner(
        vl_gpt.vision_gen_dec_aligner_norm(hidden_states[:, -576:, :])
    )
    hidden_states = hidden_states.reshape(z_emb.shape[0], 24, 24, 768).permute(0, 3, 1, 2)
    v = vl_gpt.vision_gen_dec_model(hidden_states, hs, t_emb)

    # 6. Apply CFG
    v_cond, v_uncond = torch.chunk(v, 2)
    v = cfg_weight * v_cond - (cfg_weight - 1.) * v_uncond

    # 7. Euler step
    z = z + dt * v
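Steps 6 and 7 are simple elementwise arithmetic. A torch-free scalar sketch (with hypothetical values) of the CFG combination and the Euler update:

```python
def cfg_combine(v_cond, v_uncond, cfg_weight):
    # Classifier-free guidance: extrapolate from the unconditional velocity
    # toward the conditional one. With cfg_weight = 1.0 this is v_cond exactly.
    return cfg_weight * v_cond - (cfg_weight - 1.0) * v_uncond

def euler_step(z, v, dt):
    # One explicit Euler step along the predicted velocity field.
    return z + dt * v

v = cfg_combine(v_cond=1.0, v_uncond=0.5, cfg_weight=2.0)  # 2*1.0 - 1*0.5 = 1.5
z = euler_step(z=0.0, v=v, dt=1.0 / 30)                    # approximately 0.05
```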
Import
# Uses methods on the loaded model instance; no separate import is needed
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| z | torch.Tensor [parallel_size, 4, 48, 48] | Yes | Initial noise from noise initialization |
| inputs_embeds | torch.Tensor [parallel_size*2, seq_len, D] | Yes | CFG-paired prompt embeddings |
| attention_mask | torch.IntTensor [parallel_size*2, total_len] | Yes | CFG attention mask |
| num_inference_steps | int | No | ODE steps (default 30) |
| cfg_weight | float | No | CFG scale (default 2.0) |
| dt | torch.Tensor | Yes | Euler step size |
Outputs
| Name | Type | Description |
|---|---|---|
| z | torch.Tensor [parallel_size, 4, 48, 48] | Denoised latent ready for VAE decoding |
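The denoised z is handed to the VAE decoder for pixel reconstruction (the Janus demo uses an SDXL-style AutoencoderKL; that choice and its 8x spatial factor are assumptions here, not stated on this page). Shape bookkeeping for the output:

```python
# Shape bookkeeping for the denoised latent from the Outputs table.
# Assumption: SDXL-style AutoencoderKL with an 8x spatial downsampling factor.
parallel_size = 4                        # hypothetical batch size
latent_shape = (parallel_size, 4, 48, 48)

vae_factor = 8                           # assumed VAE spatial factor
image_hw = latent_shape[-1] * vae_factor
print(image_hw)  # 384 -> decoded images would be 384x384
```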
Usage Examples
See the Pattern Implementation above for the complete loop. The key components are:
- ShallowUViTEncoder at uvit.py:L572-641: Conv + optional UVitBlock + timestep embedding
- ShallowUViTDecoder at uvit.py:L644-714: RMSNorm + skip connection + optional UVitBlock + Unpatchify