Implementation:AUTOMATIC1111 Stable diffusion webui Disable Initialization
| Knowledge Sources | |
|---|---|
| Domains | Model_Loading, Memory_Optimization |
| Last Updated | 2025-05-15 00:00 GMT |
Overview
Provides context managers that optimize model loading by disabling weight initialization, preventing unnecessary downloads, allocating parameters on the meta device, and loading state dicts with minimal memory overhead.
Description
The Disable Initialization module contains three context manager classes built on a shared ReplaceHelper base, which manages attribute replacement and restoration.
- DisableInitialization prevents PyTorch's layer initialization functions (kaiming_uniform_, _no_grad_normal_, _no_grad_uniform_) from running during model instantiation, since the weights will be loaded from a state dict anyway. It also optionally prevents CLIP and OpenCLIP from downloading model weights by patching their pretrained loading functions, and optimizes HuggingFace Transformers to prefer locally cached files over network requests.
- InitializeOnMeta forces all Linear, Conv2d, and MultiheadAttention layers to allocate their parameters on the meta device, where they take up no memory, and disables model.to() calls.
- LoadStateDictOnMeta is the companion context manager. It patches load_state_dict and _load_from_state_dict to move parameters from the meta device to the target device as they are loaded from the state dict, deleting entries from the source dict progressively to minimize peak memory usage. It also supports per-key dtype conversion via a dictionary that maps weight name prefixes to target dtypes.
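The replace-and-restore pattern behind the shared base can be sketched in plain Python. This is a minimal illustration that follows the signatures listed on this page, not the module's exact code; `slow_init`, `fake_init`, and `SkipInit` are hypothetical stand-ins for a real initialization function, the namespace that holds it, and a context manager such as DisableInitialization.

```python
import types


class ReplaceHelper:
    """Sketch of a patch-and-restore helper (assumed behavior, not the original source)."""

    def __init__(self):
        # Each entry records where a patch was applied and what it replaced.
        self.replaced = []

    def replace(self, obj, field, func):
        original = getattr(obj, field, None)
        if original is None:
            return None  # nothing to patch
        self.replaced.append((obj, field, original))
        setattr(obj, field, func)
        return original

    def restore(self):
        for obj, field, original in self.replaced:
            setattr(obj, field, original)
        self.replaced.clear()


def slow_init(weights):
    """Hypothetical stand-in for an expensive initializer such as kaiming_uniform_."""
    return [0.5 for _ in weights]


# Hypothetical namespace playing the role of a library module.
fake_init = types.SimpleNamespace(kaiming_uniform_=slow_init)


class SkipInit(ReplaceHelper):
    """Context manager that turns fake_init.kaiming_uniform_ into a no-op."""

    def __enter__(self):
        self.replace(fake_init, "kaiming_uniform_", lambda weights: weights)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.restore()  # also runs when the body raises


weights = [0.0, 0.0]
with SkipInit():
    inside = fake_init.kaiming_uniform_(weights)  # patched: weights returned untouched
outside = fake_init.kaiming_uniform_(weights)     # restored: initializer runs again
```

Because every patch is recorded before it is applied, `restore()` can undo all of them in `__exit__`, so the patched library returns to normal behavior once the `with` block ends.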
Usage
Use these context managers during model instantiation and weight loading to reduce memory usage and loading time. DisableInitialization skips redundant weight initialization when model instances are created, InitializeOnMeta allocates parameters on the meta device so that they consume no memory, and LoadStateDictOnMeta moves weights to the target device as they are loaded instead of keeping a second full copy in memory.
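What "allocated on the meta device" means can be seen with plain PyTorch, independently of this module. The sketch below assumes PyTorch is installed; the layer sizes are arbitrary, and materializing the parameter with `torch.zeros_like` is one possible approach shown for illustration, not a statement of how the module does it.

```python
import torch

# A layer created on the meta device has shapes and dtypes but no storage.
layer = torch.nn.Linear(8, 4, device="meta")
was_meta = layer.weight.is_meta

# Materialize the weight on a real device (CPU here) before values are copied in.
meta_param = layer.weight
layer.weight = torch.nn.Parameter(
    torch.zeros_like(meta_param, device="cpu"),
    requires_grad=meta_param.requires_grad,
)

# Copy "loaded" values into the now-real parameter.
loaded = torch.ones(4, 8)
with torch.no_grad():
    layer.weight.copy_(loaded)
```

Until a parameter is materialized this way it cannot hold values, which is why meta allocation needs a loading step that creates real tensors on the target device.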
Code Reference
Source Location
- Repository: AUTOMATIC1111_Stable_diffusion_webui
- File: modules/sd_disable_initialization.py
- Lines: 1-232
Signature
class ReplaceHelper:
    def __init__(self) -> None
    def replace(self, obj, field: str, func) -> callable | None
    def restore(self) -> None

class DisableInitialization(ReplaceHelper):
    def __init__(self, disable_clip: bool = True) -> None
    def __enter__(self) -> None
    def __exit__(self, exc_type, exc_val, exc_tb) -> None

class InitializeOnMeta(ReplaceHelper):
    def __enter__(self) -> None
    def __exit__(self, exc_type, exc_val, exc_tb) -> None

class LoadStateDictOnMeta(ReplaceHelper):
    def __init__(self, state_dict: dict, device, weight_dtype_conversion: dict = None) -> None
    def get_weight_dtype(self, key: str) -> torch.dtype | None
    def __enter__(self) -> None
    def __exit__(self, exc_type, exc_val, exc_tb) -> None
Import
from modules import sd_disable_initialization
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| disable_clip | bool | No | Whether to also disable CLIP/OpenCLIP weight downloading (default True) |
| state_dict | dict | Yes | Model state dictionary to load from (used by LoadStateDictOnMeta) |
| device | str or torch.device | Yes | Target device for weight placement (used by LoadStateDictOnMeta) |
| weight_dtype_conversion | dict | No | Mapping of weight name prefixes to target dtypes (used by LoadStateDictOnMeta) |
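A prefix-based dtype lookup of the kind described for weight_dtype_conversion might look like the sketch below. It assumes that the prefix is the first dot-separated term of the state dict key and that an empty-string entry acts as the default; neither detail is spelled out on this page. Strings stand in for torch dtypes to keep the example dependency-free, and the key names are illustrative.

```python
def get_weight_dtype(key, weight_dtype_conversion):
    """Return the target dtype for a state dict key, or None if no rule applies.

    Assumption: rules are keyed by the first dot-separated term of the key,
    and the "" entry (if present) is the fallback.
    """
    first_term = key.split(".", 1)[0]
    default = weight_dtype_conversion.get("")
    return weight_dtype_conversion.get(first_term, default)


# Illustrative mapping: keep one sub-model in float32, everything else in float16.
conversion = {"first_stage_model": "float32", "": "float16"}

vae_dtype = get_weight_dtype("first_stage_model.encoder.conv_in.weight", conversion)
unet_dtype = get_weight_dtype("model.diffusion_model.out.2.weight", conversion)
no_rule = get_weight_dtype("model.diffusion_model.out.2.weight", {})
```

Returning None when no rule matches leaves the caller free to keep the dtype that the weight already has in the state dict.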
Outputs
| Name | Type | Description |
|---|---|---|
| context | context manager | Each class acts as a context manager that applies and reverts patches on enter/exit |
Usage Examples
from modules import sd_disable_initialization

# Skip weight initialization during model creation.
# create_model_from_config and config stand in for the caller's own model factory.
with sd_disable_initialization.DisableInitialization():
    model = create_model_from_config(config)

# Allocate parameters on the meta device so that instantiation uses no weight memory.
with sd_disable_initialization.InitializeOnMeta():
    model = instantiate_from_config(sd_config.model)

# Load the state dict with minimal memory usage.
# state_dict is consumed as its entries are loaded (see Description).
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device="cuda"):
    model.load_state_dict(state_dict, strict=False)
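The memory benefit of the last example comes from draining the source dict while loading. A dependency-free sketch of that idea follows; `stream_into` is a hypothetical helper, `bytearray` objects stand in for tensors, and `bytes(...)` stands in for the copy made when a weight is moved to the target device. The module itself works through patched PyTorch loading methods rather than a function like this.

```python
def stream_into(target, source, convert=bytes):
    """Move entries from source into target one key at a time.

    Each entry is popped from source before its converted copy is stored,
    so at most one entry exists in both original and converted form at once.
    """
    for key in list(source):  # snapshot the keys, since source shrinks as we go
        target[key] = convert(source.pop(key))
    return target


# Stand-in "state dict": names mapped to raw buffers.
source = {"layer.weight": bytearray(1024), "layer.bias": bytearray(16)}
model_params = {}

stream_into(model_params, source)
```

After the call `source` is empty, so code that follows this pattern should not expect to reuse the original state dict once loading has finished.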