Implementation:AUTOMATIC1111 Stable diffusion webui SD3 Inferencer
| Knowledge Sources | |
|---|---|
| Domains | Stable_Diffusion_3, Inference |
| Last Updated | 2025-05-15 00:00 GMT |
Overview
Implements the Stable Diffusion 3 inference pipeline, wrapping the SD3 diffusion model, VAE, and text encoders into a unified interface compatible with the WebUI's model handling.
Description
The SD3 Inferencer module provides two classes, SD3Denoiser and SD3Inferencer.

SD3Denoiser extends k-diffusion's DiscreteSchedule and wraps the SD3 model's apply_model method so that the model can be driven by k-diffusion samplers.

SD3Inferencer is a torch.nn.Module that assembles the complete SD3 pipeline. It initializes the base diffusion model with a configurable shift parameter, sets up the SD3 VAE for encoding and decoding in a 16-channel latent space, and configures the SD3 text conditioning system (SD3Cond). It also computes alphas_cumprod from the model's sigma schedule for compatibility with the existing sampling infrastructure. Its methods cover:
- first-stage encoding and decoding, including latent format processing
- learned conditioning from text prompts
- denoiser creation
- noise addition, using a flow-matching-style linear interpolation between latent and noise
- dimension fixing, snapping width and height to multiples of 16
- listing the fields that medvram mode uses for memory optimization
- Diffusers weight mapping for the joint transformer blocks
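The role of the shift parameter and the alphas_cumprod conversion can be illustrated with a small sketch. It rests on two assumptions that this page does not state: that the shifted schedule has the form sigma(t) = shift * t / (1 + (shift - 1) * t) for a normalized timestep t, and that alphas_cumprod follows the usual k-diffusion relation sigma^2 = (1 - a) / a. The function names are hypothetical, and plain floats stand in for tensors.

```python
def shifted_sigma(t: float, shift: float = 3.0) -> float:
    """Map a normalized timestep t in [0, 1] to a shifted sigma (assumed form)."""
    return shift * t / (1.0 + (shift - 1.0) * t)


def alphas_cumprod_from_sigmas(sigmas: list[float]) -> list[float]:
    """Convert sigmas to alphas_cumprod by solving sigma^2 = (1 - a) / a for a."""
    return [1.0 / (s * s + 1.0) for s in sigmas]


# A tiny 4-step schedule for illustration; a real schedule has many more steps.
timesteps = [0.25, 0.5, 0.75, 1.0]
sigmas = [shifted_sigma(t) for t in timesteps]
alphas = alphas_cumprod_from_sigmas(sigmas)
```

With shift greater than 1, intermediate timesteps map to larger sigmas than they would unshifted, while t = 1 still maps to sigma = 1.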
Usage
Use this module when loading and running inference with Stable Diffusion 3 models. It provides the model interface expected by the WebUI's sampling and generation pipeline.
Code Reference
Source Location
- Repository: AUTOMATIC1111_Stable_diffusion_webui
- File: modules/models/sd3/sd3_model.py
- Lines: 1-96
Signature
class SD3Denoiser(k_diffusion.external.DiscreteSchedule):
    def __init__(self, inner_model, sigmas) -> None
    def forward(self, input, sigma, **kwargs) -> torch.Tensor

class SD3Inferencer(torch.nn.Module):
    def __init__(self, state_dict, shift=3, use_ema=False) -> None
    @property
    def cond_stage_model(self) -> SD3Cond
    def before_load_weights(self, state_dict) -> None
    def ema_scope(self) -> contextlib.nullcontext
    def get_learned_conditioning(self, batch: list[str]) -> object
    def apply_model(self, x, t, cond) -> torch.Tensor
    def decode_first_stage(self, latent) -> torch.Tensor
    def encode_first_stage(self, image) -> torch.Tensor
    def get_first_stage_encoding(self, x) -> torch.Tensor
    def create_denoiser(self) -> SD3Denoiser
    def medvram_fields(self) -> list[tuple]
    def add_noise_to_latent(self, x, noise, amount) -> torch.Tensor
    def fix_dimensions(self, width, height) -> tuple[int, int]
    def diffusers_weight_mapping(self) -> Generator[tuple[str, str]]
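Two of the smaller methods, add_noise_to_latent and fix_dimensions, are easy to picture as standalone functions. The sketch below is an illustration, not the module's code: the bodies are assumptions based on the description (a linear interpolation between latent and noise, and snapping to multiples of 16), and rounding down rather than to the nearest multiple is also an assumption. Floats stand in for tensors.

```python
def add_noise_to_latent(x, noise, amount):
    """Blend x toward noise: amount=0 keeps x, amount=1 yields pure noise."""
    return x * (1 - amount) + noise * amount


def fix_dimensions(width: int, height: int) -> tuple[int, int]:
    """Snap both dimensions down to a multiple of 16 (assumed rounding direction)."""
    return width // 16 * 16, height // 16 * 16


print(add_noise_to_latent(2.0, 10.0, 0.25))  # 4.0
print(fix_dimensions(1000, 777))             # (992, 768)
```

Because the blend only uses multiplication and addition, the same function body would work unchanged on torch tensors of matching shape.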
Import
from modules.models.sd3.sd3_model import SD3Inferencer
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| state_dict | dict | Yes | Model state dictionary containing pretrained weights |
| shift | int | No | Noise schedule shift parameter (default 3) |
| use_ema | bool | No | Whether to use exponential moving average weights (default False) |
| x | torch.Tensor | Yes | Input latent tensor for apply_model |
| t | torch.Tensor | Yes | Timestep or sigma values |
| cond | dict | Yes | Conditioning dictionary with "crossattn" and "vector" keys |
| batch | list[str] | Yes | List of text prompts for get_learned_conditioning |
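To make the shape of the cond argument concrete, the following sketch shows one way a conditioning dictionary with "crossattn" and "vector" keys could be unpacked before calling the diffusion model. StubDiffusionModel, the keyword names c_crossattn and y, and the placeholder strings are all hypothetical; in practice the values would be tensors produced by get_learned_conditioning.

```python
class StubDiffusionModel:
    """Hypothetical stand-in that reports the arguments it was called with."""

    def __call__(self, x, t, c_crossattn=None, y=None):
        return {"x": x, "t": t, "c_crossattn": c_crossattn, "y": y}


def apply_model(model, x, t, cond):
    """Unpack the conditioning dict into the model's keyword arguments."""
    return model(x, t, c_crossattn=cond["crossattn"], y=cond["vector"])


# Placeholder strings stand in for the real conditioning tensors.
cond = {"crossattn": "sequence conditioning", "vector": "vector conditioning"}
result = apply_model(StubDiffusionModel(), x="latent", t="sigma", cond=cond)
```

A missing key raises KeyError at the call site, which is why both keys are listed as required parts of cond.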
Outputs
| Name | Type | Description |
|---|---|---|
| prediction | torch.Tensor | Model prediction (noise or denoised output) from apply_model |
| decoded | torch.Tensor | Decoded image tensor from decode_first_stage |
| encoded | torch.Tensor | Encoded latent tensor from encode_first_stage |
| denoiser | SD3Denoiser | A denoiser instance compatible with k-diffusion sampling |
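The relationship between the denoiser output and apply_model can be sketched without k-diffusion installed. SketchDenoiser and StubInferencer below are hypothetical stand-ins: the real SD3Denoiser subclasses DiscreteSchedule, which this sketch omits, and the sigma values are made up for illustration.

```python
class SketchDenoiser:
    """Hypothetical stand-in for SD3Denoiser without the k-diffusion base class."""

    def __init__(self, inner_model, sigmas):
        self.inner_model = inner_model
        self.sigmas = sigmas

    def __call__(self, input, sigma, **kwargs):
        # Delegate directly to the wrapped model's apply_model.
        return self.inner_model.apply_model(input, sigma, **kwargs)


class StubInferencer:
    """Hypothetical model whose apply_model echoes its arguments."""

    def __init__(self):
        self.sigmas = [0.5, 0.75, 1.0]  # illustrative values only

    def apply_model(self, x, t, cond):
        return ("prediction", x, t, cond)

    def create_denoiser(self):
        # The denoiser holds a reference to the model and its sigma schedule.
        return SketchDenoiser(self, self.sigmas)


denoiser = StubInferencer().create_denoiser()
out = denoiser("latent", 0.75, cond={"crossattn": None, "vector": None})
```

The wrapper adds no computation of its own here; a sampler calls the denoiser, and the call lands in apply_model with the same latent, sigma, and conditioning.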
Usage Examples
from modules.models.sd3.sd3_model import SD3Inferencer
# `weights` is assumed to be an SD3 checkpoint state dict loaded elsewhere
# Initialize with pretrained weights
model = SD3Inferencer(state_dict=weights, shift=3)
# Get text conditioning
cond = model.get_learned_conditioning(["a photo of a cat"])
# Create denoiser for sampling
denoiser = model.create_denoiser()
# Encode an image to latent space
# (`image_tensor` is assumed to be an image batch tensor prepared elsewhere)
latent = model.encode_first_stage(image_tensor)
# Decode latent back to image
image = model.decode_first_stage(latent)