Implementation:AUTOMATIC1111 Stable diffusion webui SdUnet Extension
| Knowledge Sources | |
|---|---|
| Domains | UNet, Extension System |
| Last Updated | 2025-05-15 00:00 GMT |
Overview
Provides an extension mechanism for replacing the built-in UNet diffusion model with alternative UNet implementations via a plugin-style registration system.
Description
This module defines the infrastructure for swapping the default UNet used during image generation with a custom one provided by extensions. It maintains a global registry of SdUnetOption objects discovered through callbacks, and manages the lifecycle of the currently active SdUnet instance.
- list_unets() populates the available options via script callbacks.
- get_unet_option() resolves the selected UNet by label, or automatically matches one by checkpoint model name.
- apply_unet() handles activation and deactivation of UNet replacements. While a custom UNet is active, it moves the original diffusion model to the CPU to save GPU memory.
- create_unet_forward() wraps the original UNet forward pass so that calls are dispatched to the custom UNet when one is active.
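The resolution rules described above can be sketched without the webui installed. This is a minimal, self-contained approximation, not the module's own code: `Option`, `resolve_option`, and the `selected` and `checkpoint_name` parameters are hypothetical stand-ins for SdUnetOption, get_unet_option(), the shared.opts.sd_unet setting, and the loaded checkpoint's model name. It assumes the setting can hold the special values "None" (no replacement) and "Automatic" (match by checkpoint name).

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Option:
    """Hypothetical stand-in for SdUnetOption."""
    model_name: Optional[str]
    label: str


def resolve_option(registry: List[Option], selected: str,
                   checkpoint_name: str) -> Optional[Option]:
    """Resolve a setting value to a registered option, or None."""
    if selected == "None":
        return None
    if selected == "Automatic":
        # Pick the first option whose model_name matches the checkpoint.
        matches = [o for o in registry if o.model_name == checkpoint_name]
        if not matches:
            return None
        selected = matches[0].label
    # Otherwise treat the value as a label and look it up.
    return next((o for o in registry if o.label == selected), None)


registry = [
    Option(model_name="v1-5-pruned", label="[TRT] v1-5-pruned"),
    Option(model_name="other-model", label="Other UNet"),
]

by_label = resolve_option(registry, "Other UNet", "v1-5-pruned")
automatic = resolve_option(registry, "Automatic", "v1-5-pruned")
no_match = resolve_option(registry, "Automatic", "unknown-checkpoint")
disabled = resolve_option(registry, "None", "v1-5-pruned")
```

Returning None for an unknown label or an unmatched checkpoint means the caller can treat "nothing selected" and "nothing found" the same way: keep using the built-in UNet.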
Usage
Use this module when developing extensions that provide alternative UNet implementations (e.g., TensorRT-optimized UNets) or when the system needs to select and apply UNet replacements during the generation pipeline.
Code Reference
Source Location
- Repository: AUTOMATIC1111_Stable_diffusion_webui
- File: modules/sd_unet.py
- Lines: 1-94
Signature
def list_unets() -> None
def get_unet_option(option=None) -> Optional[SdUnetOption]
def apply_unet(option=None) -> None
def create_unet_forward(original_forward) -> Callable

class SdUnetOption:
    model_name: str = None
    label: str = None
    def create_unet(self) -> SdUnet: ...

class SdUnet(torch.nn.Module):
    def forward(self, x, timesteps, context, *args, **kwargs) -> Tensor
    def activate(self) -> None
    def deactivate(self) -> None
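The wrapping performed by create_unet_forward() can be illustrated with plain Python classes. The sketch below is an approximation under stated assumptions: `ActiveSlot`, `make_dispatching_forward`, `BuiltinUnet`, and `ReplacementUnet` are hypothetical names, and tuples stand in for tensors. It shows only the dispatch idea: the wrapper checks for an active replacement on every call and otherwise falls through to the original forward.

```python
from typing import Callable


class ActiveSlot:
    """Hypothetical holder for the active replacement (None = use built-in)."""
    unet = None


def make_dispatching_forward(original_forward: Callable) -> Callable:
    """Wrap original_forward so calls go to the active replacement, if any."""
    def forward(self, x, timesteps=None, context=None, *args, **kwargs):
        if ActiveSlot.unet is not None:
            return ActiveSlot.unet.forward(x, timesteps, context, *args, **kwargs)
        return original_forward(self, x, timesteps, context, *args, **kwargs)
    return forward


class BuiltinUnet:
    """Stand-in for the built-in UNet model class."""
    def forward(self, x, timesteps=None, context=None):
        return ("builtin", x)


class ReplacementUnet:
    """Stand-in for an SdUnet subclass."""
    def forward(self, x, timesteps, context, *args, **kwargs):
        return ("replacement", x)


# Patch the class once; the wrapper consults ActiveSlot on each call.
BuiltinUnet.forward = make_dispatching_forward(BuiltinUnet.forward)

model = BuiltinUnet()
before = model.forward(1.0, timesteps=0, context=None)

ActiveSlot.unet = ReplacementUnet()
during = model.forward(1.0, timesteps=0, context=None)

ActiveSlot.unet = None
after = model.forward(1.0, timesteps=0, context=None)
```

Because the check happens at call time, the patch only needs to be installed once; switching UNets afterwards is a matter of changing the active slot.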
Import
from modules.sd_unet import SdUnet, SdUnetOption, apply_unet, list_unets
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
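| option | str or None | No | UNet label to select. If None, the value of shared.opts.sd_unet is used instead. The values "None" (no replacement) and "Automatic" (match by checkpoint model name) are treated specially. Used by get_unet_option and apply_unet. |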
| x | Tensor | Yes | Input latent tensor for SdUnet.forward. |
| timesteps | Tensor | Yes | Diffusion timestep tensor for SdUnet.forward. |
| context | Tensor | Yes | Conditioning context tensor for SdUnet.forward. |
| original_forward | Callable | Yes | The original UNetModel.forward method to wrap via create_unet_forward. |
Outputs
| Name | Type | Description |
|---|---|---|
| unet_option | SdUnetOption or None | The resolved UNet option from get_unet_option. |
| forward_result | Tensor | Output latent tensor from SdUnet.forward or the wrapped forward pass. |
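The activation and deactivation sequence attributed to apply_unet() in the Description can be modeled as a small state machine. This is a sketch, not the module's code: `RecordingUnet`, `RecordingOption`, and `UnetSwitcher` are hypothetical, and the device moves of the built-in diffusion model are represented as log entries rather than real tensor transfers.

```python
from typing import List, Optional


class RecordingUnet:
    """Stand-in for an SdUnet that records its lifecycle calls."""
    def __init__(self, label: str, log: List[str]):
        self.label = label
        self.log = log

    def activate(self) -> None:
        self.log.append(f"activate {self.label}")

    def deactivate(self) -> None:
        self.log.append(f"deactivate {self.label}")


class RecordingOption:
    """Stand-in for an SdUnetOption."""
    def __init__(self, label: str, log: List[str]):
        self.label = label
        self.log = log

    def create_unet(self) -> RecordingUnet:
        return RecordingUnet(self.label, self.log)


class UnetSwitcher:
    """Tracks the active option and unet, following the described lifecycle."""
    def __init__(self, log: List[str]):
        self.log = log
        self.current_option: Optional[RecordingOption] = None
        self.current_unet: Optional[RecordingUnet] = None

    def apply(self, new_option: Optional[RecordingOption]) -> None:
        if new_option is self.current_option:
            return  # selection unchanged, nothing to do
        if self.current_unet is not None:
            self.current_unet.deactivate()
        self.current_option = new_option
        if new_option is None:
            # Back to the built-in UNet: return it to the compute device.
            self.current_unet = None
            self.log.append("builtin unet -> device")
            return
        # A replacement is taking over: park the built-in UNet on the CPU.
        self.log.append("builtin unet -> cpu")
        self.current_unet = new_option.create_unet()
        self.current_unet.activate()


log: List[str] = []
option_a = RecordingOption("A", log)
option_b = RecordingOption("B", log)
switcher = UnetSwitcher(log)

switcher.apply(option_a)
switcher.apply(option_a)  # no-op: same option is already active
switcher.apply(option_b)
switcher.apply(None)
```

The ordering matters: the previous replacement is deactivated before the next one is created, so at most one custom UNet holds resources at a time.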
Usage Examples
from modules.sd_unet import SdUnet, SdUnetOption


class MyCustomUnetOption(SdUnetOption):
    def __init__(self):
        # Checkpoint name used for automatic matching
        self.model_name = "my_model"
        # Label used to select this UNet
        self.label = "My Custom UNet"

    def create_unet(self):
        return MyCustomUnet()


class MyCustomUnet(SdUnet):
    def forward(self, x, timesteps, context, *args, **kwargs):
        # Custom UNet inference logic goes here; it must return the output tensor.
        raise NotImplementedError("replace with custom inference")

    def activate(self):
        # Load custom model weights
        pass

    def deactivate(self):
        # Release resources
        pass
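The classes above still have to reach the registry. The Description says list_unets() populates it through script callbacks, and the sketch below models that flow with an in-file callback list. Everything here is a stand-in invented for illustration: `DemoOption`, `on_list_options`, `list_options`, and `registry` are hypothetical names. In the webui itself the hook is understood to be a callback registered through modules.script_callbacks (likely `on_list_unets`); confirm the exact name against the installed version.

```python
from typing import Callable, List


class DemoOption:
    """Stand-in for an SdUnetOption subclass provided by an extension."""
    def __init__(self, model_name: str, label: str):
        self.model_name = model_name
        self.label = label


# Hypothetical callback list; stands in for the webui's script callback system.
_callbacks: List[Callable[[list], None]] = []
registry: List[DemoOption] = []


def on_list_options(callback: Callable[[list], None]) -> None:
    """Register a callback that appends options to the list it receives."""
    _callbacks.append(callback)


def list_options() -> None:
    """Rebuild the registry from scratch by asking every callback for options."""
    collected: list = []
    for callback in _callbacks:
        callback(collected)
    registry.clear()
    registry.extend(collected)


def my_extension_callback(options: list) -> None:
    # An extension contributes its option objects here.
    options.append(DemoOption("my_model", "My Custom UNet"))


on_list_options(my_extension_callback)
list_options()
list_options()  # a second call rebuilds the registry instead of duplicating
labels = [o.label for o in registry]
```

Clearing and refilling the registry on each call keeps it in step with whatever extensions are currently registered, at the cost of re-running every callback.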