Implementation:Huggingface Diffusers ControlNetModel Forward

From Leeroopedia
Property Value
Implementation Name ControlNetModel.forward
Type API Doc
Workflow ControlNet_Guided_Generation
Related Principle Huggingface_Diffusers_ControlNet_Residual_Injection
Source File src/diffusers/models/controlnets/controlnet.py
Lines L601-L803
Status Active
Implements Principle:Huggingface_Diffusers_ControlNet_Residual_Injection

API Signature

@apply_lora_scale("cross_attention_kwargs")
def forward(
    self,
    sample: torch.Tensor,
    timestep: torch.Tensor | float | int,
    encoder_hidden_states: torch.Tensor,
    controlnet_cond: torch.Tensor,
    conditioning_scale: float = 1.0,
    class_labels: torch.Tensor | None = None,
    timestep_cond: torch.Tensor | None = None,
    attention_mask: torch.Tensor | None = None,
    added_cond_kwargs: dict[str, torch.Tensor] | None = None,
    cross_attention_kwargs: dict[str, Any] | None = None,
    guess_mode: bool = False,
    return_dict: bool = True,
) -> ControlNetOutput | tuple[tuple[torch.Tensor, ...], torch.Tensor]:

Class: ControlNetModel

Import:

from diffusers import ControlNetModel
from diffusers.models.controlnets.controlnet import ControlNetOutput

Parameters

Parameter Type Default Description
sample torch.Tensor required Noisy latent tensor, shape (batch, 4, H/8, W/8).
timestep torch.Tensor | float | int required Current denoising timestep.
encoder_hidden_states torch.Tensor required Text encoder output, shape (batch, seq_len, hidden_dim).
controlnet_cond torch.Tensor required Prepared conditioning image tensor, shape (batch, 3, H, W).
conditioning_scale float 1.0 Multiplier for all output residuals.
class_labels torch.Tensor | None None Optional class labels for class-conditioned generation.
timestep_cond torch.Tensor | None None Additional timestep conditioning (e.g., guidance scale embedding).
attention_mask torch.Tensor | None None Attention mask for encoder hidden states.
added_cond_kwargs dict[str, torch.Tensor] | None None Additional conditioning for SDXL (text_embeds, time_ids).
cross_attention_kwargs dict[str, Any] | None None Kwargs passed to attention processors (e.g., LoRA scale).
guess_mode bool False Enables logarithmic scale ramp across resolution levels.
return_dict bool True Whether to return ControlNetOutput or a plain tuple.

Return Value

Type Description
ControlNetOutput Dataclass with down_block_res_samples: tuple[torch.Tensor] and mid_block_res_sample: torch.Tensor.
tuple When return_dict=False: (down_block_res_samples, mid_block_res_sample).
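
With return_dict=False, the two residual groups unpack positionally. A minimal sketch, assuming the controlnet and input tensors constructed in the Usage Example below:

down_block_res_samples, mid_block_res_sample = controlnet(
    sample=sample,
    timestep=timestep,
    encoder_hidden_states=encoder_hidden_states,
    controlnet_cond=controlnet_cond,
    return_dict=False,
)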

ControlNetOutput Dataclass

@dataclass
class ControlNetOutput(BaseOutput):
    down_block_res_samples: tuple[torch.Tensor]
    mid_block_res_sample: torch.Tensor

Source: src/diffusers/models/controlnets/controlnet.py, lines 45-62.

I/O Contract

Direction Tensor Shape (SD 1.5, 512x512)
Input sample (noisy latent) (B, 4, 64, 64)
Input controlnet_cond (conditioning image) (B, 3, 512, 512)
Input encoder_hidden_states (text embeddings) (B, 77, 768)
Output down_block_res_samples[0-2] (B, 320, 64, 64)
Output down_block_res_samples[3] (B, 320, 32, 32)
Output down_block_res_samples[4-5] (B, 640, 32, 32)
Output down_block_res_samples[6] (B, 640, 16, 16)
Output down_block_res_samples[7-8] (B, 1280, 16, 16)
Output down_block_res_samples[9-11] (B, 1280, 8, 8)
Output mid_block_res_sample (B, 1280, 8, 8)
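
Indices 3 and 6 are downsampler outputs, which keep the incoming channel count at halved resolution; the channel increase happens in the first resnet of the following block. The contract can be checked directly against a forward pass. A short sketch, assuming the output object produced in the Usage Example below:

for i, res in enumerate(output.down_block_res_samples):
    print(i, tuple(res.shape))
print("mid", tuple(output.mid_block_res_sample.shape))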

Source Code Analysis

Channel Order Check

channel_order = self.config.controlnet_conditioning_channel_order
if channel_order == "rgb":
    ...  # No-op, default
elif channel_order == "bgr":
    controlnet_cond = torch.flip(controlnet_cond, dims=[1])
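
torch.flip with dims=[1] reverses the channel axis, turning an RGB tensor into BGR. A toy illustration (not from the source):

import torch

x = torch.arange(6.0).reshape(1, 3, 1, 2)  # channel values 0-1, 2-3, 4-5
print(torch.flip(x, dims=[1])[0, :, 0, 0])  # tensor([4., 2., 0.]) -- channels reversed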

Time Embedding

timesteps = timestep
if not torch.is_tensor(timesteps):
    is_mps = sample.device.type == "mps"
    is_npu = sample.device.type == "npu"
    if isinstance(timestep, float):
        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
    else:
        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
    timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
    timesteps = timesteps[None].to(sample.device)

timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
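
The net effect is that a scalar timestep is normalized to a 1-D tensor and broadcast to one entry per batch element before projection. A standalone illustration (the batch size of 4 is hypothetical):

import torch

sample = torch.randn(4, 4, 64, 64)
timesteps = torch.tensor([500.0], device=sample.device)
timesteps = timesteps.expand(sample.shape[0])  # one timestep per batch element
print(timesteps.shape)  # torch.Size([4])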

Conditioning Injection

# Process through initial convolution
sample = self.conv_in(sample)

# Add conditioning embedding to sample
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
sample = sample + controlnet_cond
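
The addition is shape-valid because controlnet_cond_embedding downsamples the image-resolution hint to latent resolution with conv_in's channel count. A quick check, assuming the controlnet and controlnet_cond tensor from the Usage Example below:

with torch.no_grad():
    cond_emb = controlnet.controlnet_cond_embedding(controlnet_cond)
print(cond_emb.shape)  # expected: torch.Size([1, 320, 64, 64])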

Encoder Pass (Down Blocks)

down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
    if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
        sample, res_samples = downsample_block(
            hidden_states=sample, temb=emb,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            cross_attention_kwargs=cross_attention_kwargs,
        )
    else:
        sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
    down_block_res_samples += res_samples
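
For the SD 1.5 configuration this loop yields the 12 residuals listed in the I/O Contract above. A back-of-the-envelope check (block counts are specific to SD 1.5):

num_down_blocks = 4     # block_out_channels = (320, 640, 1280, 1280)
layers_per_block = 2    # resnets per down block
num_downsamplers = 3    # the final block has no downsampler
print(1 + num_down_blocks * layers_per_block + num_downsamplers)  # 12 (the 1 is the conv_in output)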

Mid Block

if self.mid_block is not None:
    if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
        sample = self.mid_block(
            sample, emb,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            cross_attention_kwargs=cross_attention_kwargs,
        )
    else:
        sample = self.mid_block(sample, emb)

Zero Convolution and Scaling

# Apply zero convolutions to down block residuals
controlnet_down_block_res_samples = ()
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
    down_block_res_sample = controlnet_block(down_block_res_sample)
    controlnet_down_block_res_samples += (down_block_res_sample,)
down_block_res_samples = controlnet_down_block_res_samples

# Apply zero convolution to mid block residual
mid_block_res_sample = self.controlnet_mid_block(sample)

# Scaling: guess mode uses logarithmic ramp
if guess_mode and not self.config.global_pool_conditions:
    scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)
    scales = scales * conditioning_scale
    down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
    mid_block_res_sample = mid_block_res_sample * scales[-1]
else:
    down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
    mid_block_res_sample = mid_block_res_sample * conditioning_scale

Source: src/diffusers/models/controlnets/controlnet.py, lines 770-803.
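
The guess-mode ramp is easy to inspect in isolation. For SD 1.5 there are 12 down residuals plus the mid residual, giving 13 scales that grow from 0.1 at the shallowest level to 1.0 at the mid block:

import torch

scales = torch.logspace(-1, 0, 13)
print(scales.min().item(), scales.max().item())  # ~0.1, 1.0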

Global Pool Conditions (Optional)

if self.config.global_pool_conditions:
    down_block_res_samples = [
        torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
    ]
    mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)

When enabled, spatial features are globally averaged, discarding spatial information while retaining semantic conditioning.
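
A standalone illustration of the pooling (the input shape matches the mid block residual above):

import torch

x = torch.randn(1, 1280, 8, 8)
pooled = torch.mean(x, dim=(2, 3), keepdim=True)
print(pooled.shape)  # torch.Size([1, 1280, 1, 1])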

Usage Example

import torch
from diffusers import ControlNetModel

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
).to("cuda")

# Simulate inputs
sample = torch.randn(1, 4, 64, 64, dtype=torch.float16, device="cuda")
timestep = torch.tensor([500], device="cuda")
encoder_hidden_states = torch.randn(1, 77, 768, dtype=torch.float16, device="cuda")
controlnet_cond = torch.randn(1, 3, 512, 512, dtype=torch.float16, device="cuda")

# Forward pass
output = controlnet(
    sample=sample,
    timestep=timestep,
    encoder_hidden_states=encoder_hidden_states,
    controlnet_cond=controlnet_cond,
    conditioning_scale=1.0,
    guess_mode=False,
    return_dict=True,
)

print(f"Number of down block residuals: {len(output.down_block_res_samples)}")
# Output: Number of down block residuals: 12
print(f"Mid block residual shape: {output.mid_block_res_sample.shape}")
# Output: Mid block residual shape: torch.Size([1, 1280, 8, 8])
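
The residuals are consumed by the paired UNet through its down_block_additional_residuals and mid_block_additional_residual arguments. A sketch of the hand-off; the checkpoint id is illustrative and must match the ControlNet's base model:

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
).to("cuda")

noise_pred = unet(
    sample,
    timestep,
    encoder_hidden_states=encoder_hidden_states,
    down_block_additional_residuals=output.down_block_res_samples,
    mid_block_additional_residual=output.mid_block_res_sample,
).sample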

Notes

  • The @apply_lora_scale decorator handles LoRA weight scaling automatically from cross_attention_kwargs.
  • The forward method supports SDXL-style additional conditioning via added_cond_kwargs containing text_embeds and time_ids.
  • In guess mode, the logarithmic scale ramp produces 13 scale values (0.1 to 1.0) via torch.logspace(-1, 0, 13), one per residual plus the mid block.

Related Pages

Requires Environment
