Principle:AUTOMATIC1111 Stable diffusion webui Weight merging execution
| Knowledge Sources | |
|---|---|
| Domains | Model Merging, Memory Management, Checkpoint Merging |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Weight merging execution is the orchestrated process of loading model state dictionaries, iterating over their keys, applying per-tensor interpolation functions, and managing memory throughout the multi-pass merge pipeline.
Description
Merging large neural network checkpoints (typically 2-7 GB each) requires careful orchestration to avoid exhausting system memory while ensuring correct application of interpolation functions across all weight tensors. Weight merging execution encompasses the full pipeline from loading models to producing the final merged state dictionary.
The execution follows a two-pass strategy:
Pass 1 (Difference computation): For the "Add difference" method, models B and C are loaded first. The pipeline iterates over all keys in B's state dictionary: for each key containing "model" in its name and present in C, the difference (B - C) is computed and stored in-place, overwriting B's tensor. C's state dictionary is then deleted to free memory. This pass is skipped entirely for "Weighted sum" since it only requires two models.
Pass 2 (Final interpolation): Model A is loaded, and the pipeline iterates over all keys in A's state dictionary. For each key that also exists in the (possibly modified) B dictionary and contains "model" in its name, the interpolation function is applied. Special handling exists for shape mismatches between inpainting/instruct-pix2pix models and standard models, where only the overlapping channels are merged.
Key filtering: Certain keys are always skipped during merging. An example is the text encoder's integer-typed position_ids buffer, which holds index values rather than learned weights and must not be interpolated. Keys that do not contain "model" in their name are preserved as-is from model A. So are keys of A that have no counterpart in B.
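The two passes and the key filter can be sketched with plain Python dictionaries standing in for state dictionaries and floats standing in for tensors. The interpolation formulas are assumed to be the conventional ones: (1 - alpha) * A + alpha * B for "Weighted sum" and A + alpha * (B - C) for "Add difference". All function names and the single skip-list entry are illustrative assumptions. Shape-mismatch handling and half-precision conversion are left out.

```python
from typing import AbstractSet, Callable, Dict, Optional

StateDict = Dict[str, float]  # floats stand in for tensors in this sketch

# Assumed skip-list entry for illustration: the text encoder's integer position indices.
SKIP_ON_MERGE: AbstractSet[str] = frozenset(
    {"cond_stage_model.transformer.text_model.embeddings.position_ids"})


def weighted_sum(a: float, b: float, alpha: float) -> float:
    return (1 - alpha) * a + alpha * b


def get_difference(b: float, c: float) -> float:
    return b - c


def add_difference(a: float, diff: float, alpha: float) -> float:
    return a + alpha * diff


def mergeable(key: str, other: StateDict) -> bool:
    """Key filter: skip-listed keys, non-"model" keys and unmatched keys are not merged."""
    return key not in SKIP_ON_MERGE and "model" in key and key in other


def merge_state_dicts(a: StateDict, b: StateDict, c: Optional[StateDict],
                      alpha: float, method: str) -> StateDict:
    interpolate: Callable[[float, float, float], float] = weighted_sum
    if method == "Add difference":
        if c is None:
            raise ValueError("'Add difference' needs a third model C")
        # Pass 1: overwrite B's entries with B - C, then drop this function's reference to C.
        for key in b:
            if mergeable(key, c):
                b[key] = get_difference(b[key], c[key])
        del c
        interpolate = add_difference
    # Pass 2: fold B into A key by key; A becomes the merged result.
    for key in a:
        if mergeable(key, b):
            a[key] = interpolate(a[key], b[key], alpha)
    return a


A = {"model.w": 1.0, "alphas_cumprod": 0.5}
B = {"model.w": 3.0}
C = {"model.w": 2.0}
merged = merge_state_dicts(A, B, C, alpha=0.5, method="Add difference")
print(merged)  # {'model.w': 1.5, 'alphas_cumprod': 0.5}
```

B already holds B - C when pass 2 starts. Pass 2 therefore needs only a two-argument update, and C never has to be resident alongside A.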
Usage
Use this execution pattern when:
- Merging full-size models: The sequential loading and in-place computation pattern keeps peak memory at approximately 2x the size of a single model. A naive three-model "Add difference" merge would need 3x.
- Handling heterogeneous model architectures: The per-key shape checking handles merges between standard, inpainting, and instruct-pix2pix model variants.
- Providing user feedback: The iteration pattern integrates with progress reporting (tqdm, shared.state) to display merge progress in the UI.
- Maintaining correctness: The key filtering and skip-on-merge list prevent corruption of non-weight tensors.
Theoretical Basis
Two-Pass Memory Optimization
Given models A, B, C each of size S bytes, a naive approach would require 3S memory simultaneously. The two-pass strategy reduces peak usage:
```text
Pass 1: Load B (S), Load C (S)   -> peak = 2S
        Compute B := B - C in-place
        Delete C                 -> memory = S (modified B holding differences)
Pass 2: Load A (S)               -> peak = 2S
        Compute A[k] := f(A[k], B[k], alpha) in-place per key
        Delete B                 -> memory = S (final merged A)
```
Peak memory is 2S rather than 3S, plus a transient allocation for the tensor currently being computed. The saving is one full model, roughly 2-7 GB for typical Stable Diffusion checkpoints.
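The peak-memory argument can be checked with a toy accounting model that records which state dictionaries are resident after each load or delete. It assumes every checkpoint occupies exactly S bytes and ignores transient per-tensor buffers. The names `peak_memory`, `NAIVE` and `TWO_PASS` are hypothetical.

```python
from typing import List, Set, Tuple

Step = Tuple[str, str]  # (action, model name), action is "load" or "delete"


def peak_memory(steps: List[Step], size: float) -> float:
    """Return the peak resident memory for a sequence of load/delete steps."""
    resident: Set[str] = set()
    peak = 0.0
    for action, name in steps:
        if action == "load":
            resident.add(name)
        elif action == "delete":
            resident.discard(name)
        else:
            raise ValueError(f"unknown action {action!r}")
        peak = max(peak, len(resident) * size)
    return peak


# Naive: all three models resident at once.
NAIVE = [("load", "A"), ("load", "B"), ("load", "C"),
         ("delete", "C"), ("delete", "B")]
# Two-pass: C is released before A is loaded.
TWO_PASS = [("load", "B"), ("load", "C"), ("delete", "C"),  # pass 1: B := B - C
            ("load", "A"), ("delete", "B")]                 # pass 2: A := f(A, B, alpha)

S = 4.0  # GB, an assumed checkpoint size inside the 2-7 GB range
print(peak_memory(NAIVE, S), peak_memory(TWO_PASS, S))  # 12.0 8.0
```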
Per-Key Iteration Pattern
The merge iterates over the state dictionary as an ordered mapping of string keys to tensors:
```python
# A, B: state dicts (str -> tensor); A is modified in place and becomes the result
for key in A:
    if key in skip_list:
        continue  # e.g. the integer-typed position_ids buffer
    if "model" not in key:
        continue  # preserve non-model keys as-is
    if key in B:
        if A[key].shape == B[key].shape:
            A[key] = interpolate(A[key], B[key], alpha)
        else:
            handle_shape_mismatch(A, B, key, alpha)
        A[key] = maybe_convert_to_half(A[key])
```
This ensures:
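- Completeness: Every key in A is visited, so the result has exactly A's keys. Keys that exist only in B are dropped.
- Safety: Non-model keys (e.g., noise-schedule buffers such as alphas_cumprod) and skip-listed keys are preserved.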
- Flexibility: Shape mismatches are handled gracefully for special model types.
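These guarantees can be exercised with a runnable version of the loop. In this sketch, flat lists of floats stand in for tensors, and rounding through `struct`'s half-precision format (`'e'`) stands in for `maybe_convert_to_half`. The weighted-sum formula is assumed as the interpolation. `merge_into_a`, `to_half` and the key names are hypothetical, and channel-mismatch handling is not modelled.

```python
import struct
from typing import AbstractSet, Dict, List

Tensor = List[float]  # a flat list of floats stands in for a tensor


def to_half(t: Tensor) -> Tensor:
    """Round each value to the nearest IEEE 754 half-precision value."""
    return [struct.unpack("<e", struct.pack("<e", x))[0] for x in t]


def merge_into_a(a: Dict[str, Tensor], b: Dict[str, Tensor], alpha: float,
                 skip: AbstractSet[str], save_as_half: bool) -> Dict[str, Tensor]:
    for key in a:  # every key of A is visited; keys found only in B never enter the result
        if key in skip or "model" not in key or key not in b:
            continue  # kept from A unchanged
        if len(a[key]) != len(b[key]):
            raise ValueError(f"size mismatch for {key}; channel handling is not modelled here")
        merged = [(1 - alpha) * x + alpha * y for x, y in zip(a[key], b[key])]
        a[key] = to_half(merged) if save_as_half else merged
    return a


A = {"model.w": [1.0, 2.0], "alphas_cumprod": [0.9], "model.only_in_a": [5.0]}
B = {"model.w": [3.0, 4.0], "alphas_cumprod": [0.1], "model.only_in_b": [7.0]}
print(merge_into_a(A, B, alpha=0.5, skip=frozenset(), save_as_half=True))
# {'model.w': [2.0, 3.0], 'alphas_cumprod': [0.9], 'model.only_in_a': [5.0]}
```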
Shape Mismatch Handling
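Inpainting models have 9 input channels (4 latent + 4 latent of the masked source image + 1 mask) versus the standard 4 channels. Instruct-pix2pix models have 8 channels (4 latent + 4 input-image latent). When merging these with standard models, only the first 4 input channels are interpolated: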
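```python
# A, B: the same conv weight from models A and B, shaped [out, in, kh, kw]
if (A.shape[1], B.shape[1]) in ((4, 9), (4, 8)):
    raise RuntimeError("A must be the inpainting / instruct-pix2pix model")
if A.shape[1] == 9 and B.shape[1] == 4:  # inpainting (A) + standard (B)
    A[:, 0:4, :, :] = interpolate(A[:, 0:4, :, :], B, alpha)
    # channels 4-8 (mask and masked-image latent) are kept from A
elif A.shape[1] == 8 and B.shape[1] == 4:  # instruct-pix2pix (A) + standard (B)
    A[:, 0:4, :, :] = interpolate(A[:, 0:4, :, :], B, alpha)
    # channels 4-7 (input-image latent) are kept from A
```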
The constraint is that model A must be the one with more channels; the merge will raise a RuntimeError if B has more channels than A.
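The channel-overlap rule can be sketched on nested lists shaped [out_channels][in_channels], with the two kernel dimensions of a real convolution weight dropped for brevity. In a Stable Diffusion UNet the affected tensor is typically the first input convolution. The weighted-sum formula is assumed as the interpolation, and `merge_overlapping_channels` is a hypothetical name.

```python
from typing import List

Weight = List[List[float]]  # [out_channels][in_channels]; kernel dims dropped


def merge_overlapping_channels(a: Weight, b: Weight, alpha: float) -> Weight:
    """Weighted-sum b into the first in-channels of a; keep a's extra channels."""
    if len(a) != len(b):
        raise ValueError("output channel counts must match")
    a_in, b_in = len(a[0]), len(b[0])
    if (a_in, b_in) in ((4, 9), (4, 8)):
        raise RuntimeError("A must be the model with more input channels")
    if a_in != b_in and (a_in, b_in) not in ((9, 4), (8, 4)):
        raise ValueError(f"unsupported channel counts: A={a_in}, B={b_in}")
    merged = []
    for row_a, row_b in zip(a, b):
        head = [(1 - alpha) * x + alpha * y for x, y in zip(row_a[:b_in], row_b)]
        merged.append(head + row_a[b_in:])  # channels b_in and above come from A
    return merged


inpaint = [[1.0] * 9]   # one output channel of a 9-channel inpainting input conv
standard = [[3.0] * 4]  # the same output channel of a 4-channel standard conv
print(merge_overlapping_channels(inpaint, standard, alpha=0.5))
# [[2.0, 2.0, 2.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0]]
```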
Progress Reporting
The iteration naturally integrates with progress tracking:
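```python
from tqdm import tqdm

# state: the UI progress object (shared.state in the webui)
state.sampling_steps = len(state_dict.keys())
for key in tqdm(state_dict.keys()):
    ...  # perform the merge for this key
    state.sampling_step += 1
```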
This enables both terminal-based (tqdm) and UI-based (shared.state) progress display.
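Outside the web UI, the same two counters can drive any progress display. The `ProgressState` class below is a hypothetical stand-in for `shared.state` that keeps only the two counter names used above. `merge_with_progress` is also hypothetical. tqdm is left out so the sketch depends on the standard library alone.

```python
from dataclasses import dataclass
from typing import Callable, Dict


@dataclass
class ProgressState:
    """Hypothetical stand-in for the UI state: total and completed step counts."""
    sampling_steps: int = 0
    sampling_step: int = 0

    @property
    def fraction(self) -> float:
        return self.sampling_step / self.sampling_steps if self.sampling_steps else 0.0


def merge_with_progress(state_dict: Dict[str, float],
                        merge_key: Callable[[str], None],
                        state: ProgressState) -> None:
    state.sampling_steps = len(state_dict)  # one step per key
    state.sampling_step = 0
    for key in state_dict:
        merge_key(key)  # per-key merge work; must not add or remove keys
        state.sampling_step += 1  # a UI can poll state.fraction at any time


state = ProgressState()
visited = []
merge_with_progress({"model.a": 1.0, "model.b": 2.0, "alphas_cumprod": 3.0}, visited.append, state)
print(state.sampling_step, state.sampling_steps, state.fraction)  # 3 3 1.0
```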