Implementation:Huggingface Diffusers ControlNetModel Forward
| Property | Value |
|---|---|
| Implementation Name | ControlNetModel.forward |
| Type | API Doc |
| Workflow | ControlNet_Guided_Generation |
| Related Principle | Huggingface_Diffusers_ControlNet_Residual_Injection |
| Source File | src/diffusers/models/controlnets/controlnet.py |
| Lines | L601-L803 |
| Status | Active |
| Implements | Principle:Huggingface_Diffusers_ControlNet_Residual_Injection |
API Signature
@apply_lora_scale("cross_attention_kwargs")
def forward(
self,
sample: torch.Tensor,
timestep: torch.Tensor | float | int,
encoder_hidden_states: torch.Tensor,
controlnet_cond: torch.Tensor,
conditioning_scale: float = 1.0,
class_labels: torch.Tensor | None = None,
timestep_cond: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
added_cond_kwargs: dict[str, torch.Tensor] | None = None,
cross_attention_kwargs: dict[str, Any] | None = None,
guess_mode: bool = False,
return_dict: bool = True,
) -> ControlNetOutput | tuple[tuple[torch.Tensor, ...], torch.Tensor]:
Class: ControlNetModel
Import:
from diffusers import ControlNetModel
from diffusers.models.controlnets.controlnet import ControlNetOutput
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| sample | torch.Tensor | required | Noisy latent tensor, shape (batch, 4, H/8, W/8). |
| timestep | torch.Tensor \| float \| int | required | Current denoising timestep. |
| encoder_hidden_states | torch.Tensor | required | Text encoder output, shape (batch, seq_len, hidden_dim). |
| controlnet_cond | torch.Tensor | required | Prepared conditioning image tensor, shape (batch, 3, H, W). |
| conditioning_scale | float | 1.0 | Multiplier for all output residuals. |
| class_labels | torch.Tensor \| None | None | Optional class labels for class-conditioned generation. |
| timestep_cond | torch.Tensor \| None | None | Additional timestep conditioning (e.g., guidance scale embedding). |
| attention_mask | torch.Tensor \| None | None | Attention mask for encoder hidden states. |
| added_cond_kwargs | dict[str, torch.Tensor] \| None | None | Additional conditioning for SDXL (text_embeds, time_ids). |
| cross_attention_kwargs | dict[str, Any] \| None | None | Kwargs passed to attention processors (e.g., LoRA scale). |
| guess_mode | bool | False | Enables logarithmic scale ramp across resolution levels. |
| return_dict | bool | True | Whether to return ControlNetOutput or a plain tuple. |
Return Value
| Type | Description |
|---|---|
| ControlNetOutput | Dataclass with down_block_res_samples: tuple[torch.Tensor] and mid_block_res_sample: torch.Tensor. |
| tuple | When return_dict=False: (down_block_res_samples, mid_block_res_sample). |
ControlNetOutput Dataclass
@dataclass
class ControlNetOutput(BaseOutput):
down_block_res_samples: tuple[torch.Tensor]
mid_block_res_sample: torch.Tensor
Source: src/diffusers/models/controlnets/controlnet.py, lines 45-62.
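ControlNetOutput subclasses diffusers' BaseOutput, which supports both attribute access and tuple-style unpacking. A minimal sketch of the access pattern using a plain dataclass stand-in (DummyOutput is illustrative, not the diffusers class):

```python
from dataclasses import dataclass, astuple

# Illustrative stand-in for ControlNetOutput; the real class subclasses
# diffusers' BaseOutput, which additionally supports dict-style access.
@dataclass
class DummyOutput:
    down_block_res_samples: tuple
    mid_block_res_sample: str

out = DummyOutput(down_block_res_samples=("r0", "r1"), mid_block_res_sample="mid")
down, mid = astuple(out)  # tuple-style unpacking, as with return_dict=False
print(down, mid)  # ('r0', 'r1') mid
```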
I/O Contract
| Direction | Tensor | Shape (SD 1.5, 512x512) |
|---|---|---|
| Input | sample (noisy latent) | (B, 4, 64, 64) |
| Input | controlnet_cond (conditioning image) | (B, 3, 512, 512) |
| Input | encoder_hidden_states (text embeddings) | (B, 77, 768) |
| Output | down_block_res_samples[0-2] | (B, 320, 64, 64) |
| Output | down_block_res_samples[3] | (B, 320, 32, 32) |
| Output | down_block_res_samples[4-5] | (B, 640, 32, 32) |
| Output | down_block_res_samples[6] | (B, 640, 16, 16) |
| Output | down_block_res_samples[7-8] | (B, 1280, 16, 16) |
| Output | down_block_res_samples[9-11] | (B, 1280, 8, 8) |
| Output | mid_block_res_sample | (B, 1280, 8, 8) |
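The residual shapes follow mechanically from the UNet configuration: one residual from conv_in, one per resnet layer, and one per stride-2 downsampler. A sketch deriving all 12 shapes, assuming the SD 1.5 defaults block_out_channels=(320, 640, 1280, 1280) and layers_per_block=2:

```python
# Derive down-block residual shapes for SD 1.5 at 512x512 (latent 64x64).
# Assumes block_out_channels=(320, 640, 1280, 1280), layers_per_block=2,
# with a stride-2 downsampler after each of the first three blocks.
channels = (320, 640, 1280, 1280)
layers_per_block = 2
spatial = 64

shapes = [(channels[0], spatial, spatial)]  # conv_in output is residual 0
for i, ch in enumerate(channels):
    for _ in range(layers_per_block):       # one residual per resnet layer
        shapes.append((ch, spatial, spatial))
    if i < len(channels) - 1:               # downsampler halves H and W
        spatial //= 2
        shapes.append((ch, spatial, spatial))

print(len(shapes))  # 12
print(shapes[0], shapes[-1])  # (320, 64, 64) (1280, 8, 8)
```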
Source Code Analysis
Channel Order Check
channel_order = self.config.controlnet_conditioning_channel_order
if channel_order == "rgb":
... # No-op, default
elif channel_order == "bgr":
controlnet_cond = torch.flip(controlnet_cond, dims=[1])
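The BGR branch only reverses the channel axis. A small sketch (tensor values are made up) of what torch.flip(..., dims=[1]) does to an NCHW tensor:

```python
import torch

# One "pixel" per channel plane: R=1, G=2, B=3, shape (1, 3, 1, 1).
cond = torch.tensor([1.0, 2.0, 3.0]).reshape(1, 3, 1, 1)
bgr = torch.flip(cond, dims=[1])  # reverse the channel axis: RGB -> BGR
print(bgr.flatten().tolist())  # [3.0, 2.0, 1.0]
```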
Time Embedding
timesteps = timestep
if not torch.is_tensor(timesteps):
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
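self.time_proj computes a sinusoidal embedding of the timestep, which self.time_embedding then passes through an MLP. A minimal pure-Python sketch of the sinusoidal projection (the frequency convention is simplified; diffusers' Timesteps module has additional flip/scale options):

```python
import math

def sinusoidal_embedding(t, dim=8, max_period=10000.0):
    # Half the dimensions carry sin, half carry cos, over a geometric
    # frequency ladder -- the idea behind self.time_proj (simplified).
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    return [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]

emb = sinusoidal_embedding(500.0, dim=8)
print(len(emb))  # 8
```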
Conditioning Injection
# Process through initial convolution
sample = self.conv_in(sample)
# Add conditioning embedding to sample
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
sample = sample + controlnet_cond
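The addition only works because self.controlnet_cond_embedding maps the full-resolution conditioning image to the same shape as conv_in's output, via a convolution stack containing three stride-2 steps (512 -> 64). A sketch of the spatial arithmetic; the exact stride pattern here is an assumption matching the default ControlNetConditioningEmbedding:

```python
def conv2d_out(size, kernel=3, stride=1, padding=1):
    # Standard conv output-size formula.
    return (size + 2 * padding - kernel) // stride + 1

size = 512
for stride in (1, 1, 2, 1, 2, 1, 2):  # assumed pattern: three stride-2 convs
    size = conv2d_out(size, stride=stride)
print(size)  # 64 -- matches the 64x64 latent, so the tensors can be added
```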
Encoder Pass (Down Blocks)
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample, temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
Mid Block
if self.mid_block is not None:
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
sample = self.mid_block(
sample, emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
)
else:
sample = self.mid_block(sample, emb)
Zero Convolution and Scaling
# Apply zero convolutions to down block residuals
controlnet_down_block_res_samples = ()
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
down_block_res_sample = controlnet_block(down_block_res_sample)
controlnet_down_block_res_samples += (down_block_res_sample,)
down_block_res_samples = controlnet_down_block_res_samples
# Apply zero convolution to mid block residual
mid_block_res_sample = self.controlnet_mid_block(sample)
# Scaling: guess mode uses logarithmic ramp
if guess_mode and not self.config.global_pool_conditions:
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)
scales = scales * conditioning_scale
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
mid_block_res_sample = mid_block_res_sample * scales[-1]
else:
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
mid_block_res_sample = mid_block_res_sample * conditioning_scale
Source: src/diffusers/models/controlnets/controlnet.py, lines 770-803.
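In guess mode the 13 scales (12 down-block residuals plus the mid block) ramp geometrically from 0.1 to 1.0, so deeper residuals contribute more. torch.logspace(-1, 0, n) is 10**linspace(-1, 0, n); a pure-Python sketch:

```python
n = 13  # 12 down-block residuals + 1 mid-block residual
# Equivalent of torch.logspace(-1, 0, n): geometric ramp from 0.1 to 1.0.
scales = [10 ** (-1 + i / (n - 1)) for i in range(n)]
print(round(scales[0], 4), round(scales[6], 4), round(scales[-1], 4))  # 0.1 0.3162 1.0
```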
Global Pool Conditions (Optional)
if self.config.global_pool_conditions:
down_block_res_samples = [
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
]
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
When enabled, spatial features are globally averaged, discarding spatial information while retaining semantic conditioning.
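A quick shape check of the pooling (values arbitrary): each (B, C, H, W) residual collapses to (B, C, 1, 1), one mean per channel.

```python
import torch

# Global average pooling over H and W, as in global_pool_conditions.
x = torch.arange(8.0).reshape(1, 2, 2, 2)          # (B=1, C=2, H=2, W=2)
pooled = torch.mean(x, dim=(2, 3), keepdim=True)   # -> (1, 2, 1, 1)
print(pooled.shape, pooled.flatten().tolist())  # torch.Size([1, 2, 1, 1]) [1.5, 5.5]
```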
Usage Example
import torch
from diffusers import ControlNetModel
controlnet = ControlNetModel.from_pretrained(
"lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
).to("cuda")
# Simulate inputs
sample = torch.randn(1, 4, 64, 64, dtype=torch.float16, device="cuda")
timestep = torch.tensor([500], device="cuda")
encoder_hidden_states = torch.randn(1, 77, 768, dtype=torch.float16, device="cuda")
controlnet_cond = torch.randn(1, 3, 512, 512, dtype=torch.float16, device="cuda")
# Forward pass
output = controlnet(
sample=sample,
timestep=timestep,
encoder_hidden_states=encoder_hidden_states,
controlnet_cond=controlnet_cond,
conditioning_scale=1.0,
guess_mode=False,
return_dict=True,
)
print(f"Number of down block residuals: {len(output.down_block_res_samples)}")
# Output: Number of down block residuals: 12
print(f"Mid block residual shape: {output.mid_block_res_sample.shape}")
# Output: Mid block residual shape: torch.Size([1, 1280, 8, 8])
Notes
- The @apply_lora_scale decorator handles LoRA weight scaling automatically from cross_attention_kwargs.
- The forward method supports SDXL-style additional conditioning via added_cond_kwargs containing text_embeds and time_ids.
- In guess mode, the logarithmic scale ramp produces 13 scale values (0.1 to 1.0) via torch.logspace(-1, 0, 13), one per residual plus the mid block.
Related Pages
- Huggingface_Diffusers_ControlNet_Residual_Injection -- Principle: how residuals flow from ControlNet to UNet
- Huggingface_Diffusers_ControlNetModel_From_Pretrained -- Loading the model whose forward pass is documented here
- Huggingface_Diffusers_ControlNet_Pipeline_Call -- The pipeline that calls this forward method during denoising