

Heuristic:AUTOMATIC1111 Stable diffusion webui UNet Performance Patches

From Leeroopedia



Knowledge Sources
Domains: Optimization, Performance
Last Updated: 2026-02-08 08:00 GMT

Overview

Three monkey patches to the UNet inference path: removing unnecessary `.contiguous()` memory copies, creating timestep embeddings directly on the GPU, and allowing image dimensions that are multiples of 8 rather than 64.

Description

The `sd_hijack_unet.py` module applies three performance-critical patches to the UNet and related components in both the LDM and SGM codebases. These patches are applied unconditionally at import time using the `CondFunc` conditional monkey-patching framework. The patches target common performance bottlenecks: unnecessary memory copies, CPU-to-GPU tensor transfers, and resolution inflexibility.
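The idea behind conditional monkey patching fits in a few lines. This is a minimal sketch, not the webui's `CondFunc` implementation. `cond_patch` is a hypothetical helper. The sketch assumes two things: that the replacement receives the original function as its first argument, as the unused `_` parameter in the patched functions under Code Evidence suggests, and that a missing condition means "always use the replacement".

```python
import functools


def cond_patch(owner, name, sub_func, cond_func=None):
    """Hypothetical helper (not the webui API): replace owner.<name> with a
    wrapper that calls sub_func(orig_func, *args, **kwargs) when cond_func
    returns True, or always when no cond_func is given; otherwise it calls
    the original."""
    orig_func = getattr(owner, name)

    @functools.wraps(orig_func)
    def wrapper(*args, **kwargs):
        if cond_func is None or cond_func(orig_func, *args, **kwargs):
            return sub_func(orig_func, *args, **kwargs)
        return orig_func(*args, **kwargs)

    # A plain function stored on a class binds as a method, so the
    # instance arrives as the first element of *args
    setattr(owner, name, wrapper)
    return orig_func


# Stand-in classes playing the role of SpatialTransformer
class Block:
    def forward(self, x):
        return ("original", x)


class Gated:
    def forward(self, x):
        return ("original", x)


def patched_forward(_, self, x):
    # First parameter is the original function; unused, as in the webui patches
    return ("patched", x)


cond_patch(Block, "forward", patched_forward)  # unconditional
cond_patch(Gated, "forward", patched_forward, cond_func=lambda orig, self, x: x > 0)

print(Block().forward(3))   # ('patched', 3)
print(Gated().forward(-1))  # ('original', -1)
```

Because the wrapper is an ordinary function, a patched `forward` sees the module instance as the second argument of the replacement, right after the original function.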

Usage

These patches are applied automatically at startup and affect all image generation (txt2img, img2img, hires fix). Users benefit without any configuration. The resolution flexibility patch (`TorchHijackForUnet`) is particularly useful for generating images at non-standard resolutions.

The Insight (Rule of Thumb)

  • Patch 1 - Remove contiguous() calls: The original `SpatialTransformer.forward` called `.contiguous()` after reshape/permute operations. These calls trigger `aten::copy_` memory operations that are unnecessary because the subsequent operations can work with non-contiguous memory. Removing them eliminates redundant memory copies.
  • Patch 2 - On-device timestep embeddings: The original `timestep_embedding` function created sinusoidal embeddings on CPU and then transferred them to GPU. The patched version creates them directly on the target device using `device=timesteps.device`, avoiding a CPU-to-GPU PCIe transfer per denoising step.
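  • Patch 3 - Flexible resolution: The `TorchHijackForUnet` class stands in for `torch` inside the UNet and overrides `cat`. When exactly two tensors are concatenated and their last two (spatial) dimensions differ, it resizes the first tensor to the second's height and width with nearest-neighbour interpolation before concatenating. This allows image dimensions to be any multiple of 8 (instead of 64).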
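Why Patch 1 can drop the copies is easiest to see with strides. The following is a minimal, stdlib-only sketch of strided views. It is a simplified model that checks only the single dimension merge `reshape(b, h * w, c)` needs, and the helper names are hypothetical. It shows that `permute` followed by that `reshape` can be a zero-copy view of a contiguous NCHW tensor. The resulting view is not contiguous, so an explicit `.contiguous()` on it would allocate and copy.

```python
def contiguous_strides(shape):
    """Row-major (C-contiguous) strides, in elements, for a shape."""
    strides, step = [], 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def permute(shape, strides, order):
    """permute() only reorders shape and strides; no data moves."""
    return tuple(shape[i] for i in order), tuple(strides[i] for i in order)


def flatten_is_view(shape, strides, i):
    """Dims i and i+1 merge into one dim without copying when one step
    along dim i equals stepping across all of dim i+1."""
    return strides[i] == strides[i + 1] * shape[i + 1]


b, c, h, w = 2, 320, 64, 64
shape = (b, c, h, w)
strides = contiguous_strides(shape)
p_shape, p_strides = permute(shape, strides, (0, 2, 3, 1))  # (b, h, w, c)

# reshape(b, h * w, c) merges dims 1 and 2 (h and w) of the permuted view
merge_ok = flatten_is_view(p_shape, p_strides, 1)
tokens_shape = (b, h * w, c)
tokens_strides = (p_strides[0], p_strides[2], p_strides[3])

print(merge_ok)                                            # True
print(tokens_strides == contiguous_strides(tokens_shape))  # False
```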

Reasoning

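Contiguous removal: Each `.contiguous()` call on a non-contiguous tensor allocates a new tensor and copies the data into it. The original `SpatialTransformer.forward` made two such copies per call, one after flattening to tokens and one after restoring the spatial layout. The SD 1.x UNet, for example, contains 16 SpatialTransformer modules, so a single UNet evaluation performed about 32 avoidable copies, and this repeats at every denoising step. Removing them reduces memory-bandwidth usage and the number of CUDA kernel launches.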
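To put a rough size on those copies, the arithmetic can be written out. This is a back-of-the-envelope sketch under stated assumptions, not a measurement. It assumes the SD 1.x layout at 512x512 (64x64 latent, SpatialTransformers at the three highest-resolution levels plus the middle block), fp16 activations, and batch 1. It counts only the data the two removed calls copy, and it does not model whether later kernels make their own copies.

```python
# Assumed SD 1.x UNet layout at 512x512 (64x64 latent). Each entry is
# (latent side at that resolution, channels, SpatialTransformer modules there).
SD1_TRANSFORMER_LEVELS = [
    (64, 320, 2 + 3),   # 2 in the input blocks, 3 in the output blocks
    (32, 640, 2 + 3),
    (16, 1280, 2 + 3),
    (8, 1280, 1),       # middle block
]
COPIES_PER_FORWARD = 2  # .contiguous() after flattening and after unflattening


def removed_copy_cost(levels, batch=1, bytes_per_element=2):
    """Return (number of copies, bytes written) for one UNet evaluation."""
    copies = elements = 0
    for side, channels, modules in levels:
        n = modules * COPIES_PER_FORWARD
        copies += n
        elements += n * batch * side * side * channels
    return copies, elements * bytes_per_element


copies, nbytes = removed_copy_cost(SD1_TRANSFORMER_LEVELS)
print(copies)  # 32
print(nbytes)  # 46202880
```

Under these assumptions, each UNet call writes about 46 MB and reads about 46 MB purely for layout changes. The cost scales linearly with batch size and with the number of sampling steps.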

Timestep on-device: Timestep embeddings are computed on every UNet evaluation, so at least once per denoising step (typically 20 to 50 steps per image). The frequency table involved is tiny, so each CPU-to-GPU copy costs transfer latency and host-device synchronization rather than PCIe bandwidth. Creating the table directly on the GPU avoids that per-step overhead.
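For reference, the sinusoidal embedding that `timestep_embedding` computes can be written as follows. The notation is ours: $t$ is the timestep, $d$ is `dim`, $P$ is `max_period`, and $h$ is `half`. A zero column is appended when $d$ is odd. Only the $h$-element frequency vector $f$ comes from `torch.arange`, and that vector is what the original code built on the CPU. For example, with $d = 320$ (SD 1.x `model_channels`), it is 160 float32 values, or 640 bytes.

```latex
\begin{aligned}
h &= \lfloor d/2 \rfloor, \qquad f_i = \exp\!\left(-\frac{i \ln P}{h}\right), \quad i = 0, \dots, h-1,\\
\operatorname{emb}(t) &= \big[\cos(t f_0), \dots, \cos(t f_{h-1}),\ \sin(t f_0), \dots, \sin(t f_{h-1})\big] \in \mathbb{R}^{2h}.
\end{aligned}
```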

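Resolution flexibility: Original Stable Diffusion requires dimensions to be multiples of 64. The VAE downsamples by 8, and the UNet then halves the latent three more times. If the latent side is not divisible by 8, a stride-2 downsample rounds up (for example 65 to 33). The 2x upsampled decoder feature (66) then no longer matches the skip connection (65) it is concatenated with. Resizing at the `torch.cat` level lets these concatenations succeed, which relaxes the constraint to multiples of 8 (the VAE factor).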
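The size bookkeeping can be traced with a short stdlib sketch. It makes several assumptions: an SD-style UNet with three 3x3, stride-2, padding-1 downsampling convs; nearest 2x upsampling; and concatenations of the form `cat([decoder, skip])`. The helper names are hypothetical. The sketch lists every concatenation where the upsampled decoder size would differ from the skip size.

```python
def downsample(n):
    """Side length after a 3x3, stride-2, padding-1 conv:
    floor((n - 1) / 2) + 1, which equals ceil(n / 2)."""
    return (n + 1) // 2


def skip_mismatches(image_side, vae_factor=8, num_downsamples=3):
    """Return (latent sizes per level, list of (upsampled, skip) mismatches)."""
    if image_side % vae_factor:
        raise ValueError("image side must be a multiple of the VAE factor")
    sizes = [image_side // vae_factor]
    for _ in range(num_downsamples):
        sizes.append(downsample(sizes[-1]))
    mismatches = []
    h = sizes[-1]
    for skip in reversed(sizes[:-1]):
        upsampled = 2 * h
        if upsampled != skip:
            mismatches.append((upsampled, skip))
        h = skip  # the patched cat resizes the decoder tensor to the skip's size
    return sizes, mismatches


print(skip_mismatches(512))  # ([64, 32, 16, 8], [])
print(skip_mismatches(520))  # ([65, 33, 17, 9], [(18, 17), (34, 33), (66, 65)])
```

A 512-pixel side (a multiple of 64) never mismatches. A 520-pixel side (a multiple of 8 only) mismatches at every decoder level, and without the resize each of those concatenations would fail.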

Code Evidence

Contiguous removal from `modules/sd_hijack_unet.py:81-102`:

import torch


# Monkey patch to SpatialTransformer removing unnecessary contiguous calls.
# Prevents a lot of unnecessary aten::copy_ calls
# The first parameter receives the original forward from CondFunc and is unused.
def spatial_transformer_forward(_, self, x: torch.Tensor, context=None):
    if not isinstance(context, list):
        context = [context]
    b, c, h, w = x.shape
    x_in = x
    x = self.norm(x)
    if not self.use_linear:
        x = self.proj_in(x)
    # (b, c, h, w) -> (b, h*w, c) token layout, with no explicit .contiguous() copy
    x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
    if self.use_linear:
        x = self.proj_in(x)
    for i, block in enumerate(self.transformer_blocks):
        x = block(x, context=context[i])
    if self.use_linear:
        x = self.proj_out(x)
    # Back to (b, c, h, w); the permuted view is left non-contiguous
    x = x.view(b, h, w, c).permute(0, 3, 1, 2)
    if not self.use_linear:
        x = self.proj_out(x)
    return x + x_in

On-device timestep embedding from `modules/sd_hijack_unet.py:57-78`:

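import math

import torch
from einops import repeat


# Monkey patch to create timestep embed tensor on device, avoiding a block.
# The first parameter receives the original function from CondFunc and is unused.
def timestep_embedding(_, timesteps, dim, max_period=10000, repeat_only=False):
    if not repeat_only:
        half = dim // 2
        # Frequencies are created directly on the timesteps' device (no .to() copy)
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(
                start=0, end=half, dtype=torch.float32, device=timesteps.device
            ) / half
        )
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Pad odd dimensions with one zero column
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, 'b -> b d', d=dim)
    return embedding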

Resolution flexibility from `modules/sd_hijack_unet.py:10-33`:

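import torch


class TorchHijackForUnet:
    """
    This is torch, but with cat that resizes tensors to appropriate dimensions
    if they do not match; this makes it possible to create pictures with
    dimensions that are multiples of 8 rather than 64
    """
    def __getattr__(self, item):
        # Everything except cat is forwarded to the real torch module
        if item == 'cat':
            return self.cat
        if hasattr(torch, item):
            return getattr(torch, item)
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")

    def cat(self, tensors, *args, **kwargs):
        if len(tensors) == 2:
            a, b = tensors
            if a.shape[-2:] != b.shape[-2:]:
                # Resize the first tensor to the second's (H, W) before concatenating
                a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
            tensors = (a, b)
        return torch.cat(tensors, *args, **kwargs)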
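The effect of `mode="nearest"` in the mismatched case can be seen with a 1-D sketch of the row-index mapping. This sketch assumes PyTorch's legacy nearest rule, source index = floor(destination index × in_size / out_size), and the `nearest_indices` helper is hypothetical. Under that rule, shrinking a decoder feature of 18 rows to match a 17-row skip connection keeps rows 0 to 16 and drops the last row. No values are blended.

```python
def nearest_indices(in_size, out_size):
    """Source index for each output index under the assumed legacy rule
    src = floor(dst * in_size / out_size), computed in exact integer
    arithmetic and clamped to the last valid index."""
    return [min((dst * in_size) // out_size, in_size - 1) for dst in range(out_size)]


# Decoder tensor upsampled to 18 rows, skip connection has 17: last row dropped
print(nearest_indices(18, 17) == list(range(17)))           # True
# A 2x nearest upsample (9 -> 18) repeats each row twice
print(nearest_indices(9, 18) == [i // 2 for i in range(18)])  # True
```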
