

Heuristic:AUTOMATIC1111 Stable diffusion webui Cross Attention Memory Slicing

From Leeroopedia




Knowledge Sources
Domains: Optimization, Memory_Management
Last Updated: 2026-02-08 08:00 GMT

Overview

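Memory-aware attention slicing that sizes chunks from the currently available VRAM to prevent out-of-memory errors during attention computation. The Doggettx implementation estimates the requirement with a precision-dependent multiplier (3x for 2-byte dtypes such as fp16, 2.5x for fp32) and splits the work into a power-of-two number of slices. The separate `einsum_op_cuda` path instead divides free memory by a 3.3x safety factor.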

Description

Attention is the most memory-intensive operation in Stable Diffusion inference. The naive implementation materializes a full Q*K^T score matrix of shape [batch*heads, seq_len_q, seq_len_k]. In the UNet's self-attention layers, which the WebUI patches through the same CrossAttention forward as the cross-attention layers, seq_len_k equals the number of latent tokens, so the matrix grows quadratically with the token count (that is, with the pixel count). The WebUI provides several sliced attention implementations that process the attention in chunks, trading a small amount of compute overhead for much lower peak memory. The Doggettx implementation and the CUDA branch of the InvokeAI-derived path compute chunk sizes from available VRAM, while others, such as the MPS branches, use fixed size thresholds.
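Because softmax normalizes each query row independently, splitting the query rows into chunks yields exactly the same output as full attention. The following is a minimal pure-Python sketch of that idea; the helper names (`attend`, `full_attention`, `sliced_attention`) are hypothetical, and the WebUI itself works on batched PyTorch tensors rather than lists.

```python
import math


def softmax(row):
    # Subtract the row maximum before exponentiating for numerical stability.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attend(q_rows, k, v):
    # Build the [len(q_rows), len(k)] score block, softmax each row, then weight V by it.
    scores = [[sum(a * b for a, b in zip(qi, kj)) for kj in k] for qi in q_rows]
    probs = [softmax(row) for row in scores]
    width = len(v[0])
    return [[sum(p * vj[d] for p, vj in zip(row, v)) for d in range(width)] for row in probs]


def full_attention(q, k, v):
    # Naive version: the whole len(q) x len(k) score matrix exists at once.
    return attend(q, k, v)


def sliced_attention(q, k, v, slice_size):
    # Sliced version: only a slice_size x len(k) block of scores exists at a time.
    out = []
    for start in range(0, len(q), slice_size):
        out.extend(attend(q[start:start + slice_size], k, v))
    return out
```

Peak score memory drops from len(q) x len(k) to slice_size x len(k), at the cost of a Python-level loop over slices, which mirrors the compute-for-memory trade described above.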

Usage

This heuristic is applied automatically when using the Doggettx split attention optimizer (selected by default on CUDA when xformers is unavailable). It matters most at high resolutions (768x768 and above) or with large batch sizes, where the estimated attention memory exceeds the free VRAM and slicing kicks in.

The Insight (Rule of Thumb)

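  • Action: The Doggettx estimator multiplies the score-tensor size (`q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()` bytes) by a modifier of 3 for 2-byte dtypes (fp16/bf16) and 2.5 otherwise (fp32) to account for temporary tensors created during the attention computation.
  • Value: The step count is the ratio of required to free memory, rounded up to a power of two. Separately, the `einsum_op_cuda` path divides free VRAM (device free memory plus PyTorch's reserved-but-inactive cache) by a safety factor of 3.3 to allow for copying and fragmentation.
  • Trade-off: More slices means lower peak memory but slower computation. If more than 64 slices would be needed, the Doggettx path raises a RuntimeError that suggests an approximate maximum resolution.
  • MPS-specific: `einsum_op_mps_v1` runs full attention when `q.shape[0] * q.shape[1]` (batch*heads times query tokens) is at most 2^16 (65536); otherwise it slices the query dimension with a slice size of floor(2^30 / (q.shape[0] * q.shape[1])), decremented by one whenever it is an exact multiple of 4096 to avoid a GPU memory-allocation problem.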
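The Doggettx rule in the first three bullets can be restated as a plain function on shapes and byte counts. This is a sketch, not the WebUI's code: `estimate_slice_steps` is a hypothetical name, and it takes integers instead of the live tensors and `get_available_vram()` result that the real function uses.

```python
import math

GIB = 1024 ** 3


def estimate_slice_steps(batch_heads, q_tokens, k_tokens, element_size, mem_free_bytes):
    """Hypothetical helper restating the Doggettx estimate with plain integers.

    batch_heads is q.shape[0] (batch * heads); element_size is bytes per element.
    Returns the number of query slices, always a power of two.
    """
    # Bytes for one full score matrix of shape [batch*heads, q_tokens, k_tokens].
    tensor_size = batch_heads * q_tokens * k_tokens * element_size
    # 2-byte dtypes (fp16/bf16) get a 3x multiplier, everything else 2.5x.
    modifier = 3 if element_size == 2 else 2.5
    mem_required = tensor_size * modifier
    steps = 1
    if mem_required > mem_free_bytes:
        # Round the required/free ratio up to the next power of two.
        steps = 2 ** math.ceil(math.log(mem_required / mem_free_bytes, 2))
    if steps > 64:
        # Same rough resolution estimate as the WebUI's error message.
        max_res = math.floor(math.sqrt(math.sqrt(mem_free_bytes / 2.5)) / 8) * 64
        raise RuntimeError(f"Not enough memory, use lower resolution (max approx. {max_res}x{max_res}); "
                           f"need ~{mem_required / GIB:0.1f}GB in total, have {mem_free_bytes / GIB:0.1f}GB")
    return steps
```

For example, with `batch_heads = 16` (say, a batch of 2 with 8 heads) at 512x512 in fp16, `estimate_slice_steps(16, 4096, 4096, 2, 512 * 2**20)` returns 4: the estimate is 1.5 GiB, three times the free memory, which rounds up to the next power of two.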

Reasoning

At the highest-resolution UNet level, a 512x512 image produces a 64x64 latent, so q.shape[1] = 4096. For a 1024x1024 image this becomes 16384, and a self-attention score matrix (where k.shape[1] equals q.shape[1]) grows to 16384^2 * element_size bytes: about 512 MiB per head per batch element at fp16. Doubling the image side therefore multiplies this matrix by 16. The modifier covers the several score-sized buffers that can be alive at once, such as the Q*K^T scores and the softmax output, plus temporary allocations; the S*V product itself is only [batch*heads, seq_len_q, head_dim]. The 3.3x divisor in `einsum_op_cuda` is described in its code comment as a factor of safety for copying and fragmentation; it appears to be an empirically chosen constant rather than a derived one.
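The arithmetic in the paragraph above can be checked with a short sketch. The helper names are hypothetical, and it assumes the usual 8x VAE downsampling and counts one head for one batch element in the self-attention case.

```python
def latent_tokens(width, height, downsample=8):
    # The VAE downsamples by 8, so a 512x512 image becomes a 64x64 latent of 4096 tokens.
    return (width // downsample) * (height // downsample)


def score_matrix_bytes(width, height, element_size=2):
    # Self-attention at the highest-resolution level: seq_len_q == seq_len_k == tokens.
    n = latent_tokens(width, height)
    return n * n * element_size  # one head, one batch element
```

`score_matrix_bytes(512, 512)` is 32 MiB and `score_matrix_bytes(1024, 1024)` is 512 MiB, a 16x jump for a 2x larger side.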

The MPS-specific avoidance of slice sizes that are exact multiples of 4096 appears to be a workaround for Apple Metal memory-allocator behavior at those alignment boundaries; the code excerpt applies the adjustment without stating the underlying cause.

Code Evidence

Dynamic memory calculation from `modules/sd_hijack_optimizations.py:245-261`:

mem_free_total = get_available_vram()

gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1

if mem_required > mem_free_total:
    steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))

if steps > 64:
    max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
    raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                       f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
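The excerpt stops once `steps` is known. A sketch of how a step count could translate into query ranges, assuming (as a hypothetical helper `query_slices`, not the WebUI's code) that the chunk width is `q_tokens // steps` with the last chunk clipped, and that each chunk only needs a `[batch_heads, chunk, k_tokens]` score block:

```python
def query_slices(q_tokens, steps):
    """Hypothetical helper: (start, end) query ranges for a given step count.

    Assumes chunk width q_tokens // steps with the last chunk clipped to q_tokens;
    the max(1, ...) guard is an addition for the case steps > q_tokens.
    """
    slice_size = max(1, q_tokens // steps)
    return [(start, min(start + slice_size, q_tokens)) for start in range(0, q_tokens, slice_size)]


def peak_score_bytes(batch_heads, q_tokens, k_tokens, element_size, steps):
    # Size of the largest score block [batch_heads, chunk, k_tokens] alive during one slice.
    widest = max(end - start for start, end in query_slices(q_tokens, steps))
    return batch_heads * widest * k_tokens * element_size
```

With 4 steps at 512x512 and `batch_heads = 16` in fp16, the largest score block drops from 512 MiB to 128 MiB. When `q_tokens` is not divisible by `steps`, this scheme can produce more than `steps` chunks, as `query_slices(10, 4)` shows.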

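CUDA safety factor from `einsum_op_cuda` in `modules/sd_hijack_optimizations.py:337-345` (part of the InvokeAI-derived `einsum_op` path, not the Doggettx function above):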

def einsum_op_cuda(q, k, v):
    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch
    # Divide factor of safety as there's copying and fragmentation
    return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
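The budget arithmetic can be replayed on plain numbers. This sketch takes the values that `torch.cuda.mem_get_info` and `torch.cuda.memory_stats` would report as arguments, so it runs without a GPU; both helper names are hypothetical, and the way `fits_without_slicing` compares the score tensor against the budget is an assumption about what `einsum_op_tensor_mem` does with it.

```python
MIB = 1 << 20


def cuda_attention_budget_mb(mem_free_cuda, reserved_bytes, active_bytes, safety_factor=3.3):
    """Hypothetical helper reproducing the arithmetic of einsum_op_cuda on plain numbers.

    mem_free_cuda: free device bytes (the first value of torch.cuda.mem_get_info)
    reserved_bytes, active_bytes: 'reserved_bytes.all.current' and 'active_bytes.all.current'
    """
    # Memory PyTorch's caching allocator holds but is not currently using can be reused.
    mem_free_torch = reserved_bytes - active_bytes
    mem_free_total = mem_free_cuda + mem_free_torch
    # Divide by the safety factor, then convert bytes to MiB.
    return mem_free_total / safety_factor / MIB


def fits_without_slicing(batch_heads, q_tokens, k_tokens, element_size, budget_mb):
    # Assumption: the full score tensor, in whole MiB, is compared directly against the budget.
    size_mb = batch_heads * q_tokens * k_tokens * element_size // MIB
    return size_mb <= budget_mb
```

With 6 GiB free on the device and 1 GiB reserved but inactive, the budget is about 2172 MiB, so a 512 MiB score tensor (512x512, `batch_heads = 16`, fp16) fits, while the 8 GiB tensor at 1024x1024 would need slicing.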

MPS 4096-alignment workaround from `modules/sd_hijack_optimizations.py:310-317`:

def einsum_op_mps_v1(q, k, v):
    if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
        return einsum_op_compvis(q, k, v)
    else:
        slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
        if slice_size % 4096 == 0:
            slice_size -= 1
        return einsum_op_slice_1(q, k, v, slice_size)
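The slice-size choice above depends only on the product `q.shape[0] * q.shape[1]`, so it can be restated without tensors. `mps_v1_slice_size` is a hypothetical name; it returns `None` where the excerpt would call `einsum_op_compvis` for full attention.

```python
import math


def mps_v1_slice_size(batch_heads, q_tokens):
    """Hypothetical helper mirroring the slice-size logic of einsum_op_mps_v1.

    Returns None when full (unsliced) attention would be used, else the query slice size.
    """
    rows = batch_heads * q_tokens
    if rows <= 2**16:
        return None  # small enough: full attention via einsum_op_compvis
    slice_size = math.floor(2**30 / rows)
    if slice_size % 4096 == 0:
        slice_size -= 1  # step off an exact multiple of 4096
    return slice_size
```

For `batch_heads = 16` at 1024x1024 (16384 query tokens), 2^30 / 2^18 is exactly 4096, so the slice becomes 4095. In the self-attention case, where the key length equals the query length, each slice then holds at most about 2^30 score elements.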
