Heuristic: AUTOMATIC1111 Stable Diffusion WebUI Cross-Attention Memory Slicing
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Memory_Management |
| Last Updated | 2026-02-08 08:00 GMT |
Overview
Dynamic memory-aware attention slicing that calculates optimal chunk sizes based on available VRAM, using precision-dependent multipliers (3x for fp16, 2.5x for fp32) and a 3.3x safety factor to prevent out-of-memory errors during cross-attention computation.
Description
Cross-attention is the most memory-intensive operation in Stable Diffusion inference. The naive implementation requires materializing a full Q*K^T attention matrix of size [batch*heads, seq_len_q, seq_len_k], which scales quadratically with image resolution. The WebUI provides multiple sliced attention implementations that process the attention in chunks, trading a small amount of compute overhead for dramatically reduced peak memory usage. The Doggettx implementation dynamically calculates chunk sizes based on available VRAM, while other implementations use fixed heuristics.
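To make the quadratic scaling concrete, here is a minimal sketch; the helper name and the per-head framing are illustrative, not from the WebUI codebase:

```python
def naive_attn_matrix_bytes(batch_heads: int, seq_len: int, element_size: int) -> int:
    """Bytes needed to materialize the full [batch*heads, seq_len, seq_len]
    Q*K^T score matrix in the self-attention case (seq_len_q == seq_len_k)."""
    return batch_heads * seq_len * seq_len * element_size

# The latent is 1/8 of the image resolution, so the highest-resolution UNet
# layer sees 64*64 = 4096 tokens for a 512x512 image and 16384 for 1024x1024.
print(naive_attn_matrix_bytes(1, 64 * 64, 2) / 2**20)    # 32.0 MiB per head (fp16)
print(naive_attn_matrix_bytes(1, 128 * 128, 2) / 2**20)  # 512.0 MiB per head (fp16)
```

Quadrupling the resolution multiplies the token count by 4 and the score matrix by 16, which is why slicing becomes unavoidable at high resolutions.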
Usage
This heuristic is applied automatically when using the Doggettx split attention optimizer (selected by default on CUDA when xformers is unavailable). It is most relevant when generating high-resolution images (768x768+) or using large batch sizes where the attention matrix exceeds available VRAM.
The Insight (Rule of Thumb)
- Action: The attention memory estimator uses `tensor_size * modifier` where modifier is 3 for fp16 and 2.5 for fp32 to account for temporary tensors created during attention computation.
- Value: The available VRAM is divided by a 3.3x safety factor to absorb memory fragmentation and concurrent allocations, and slice counts are always powers of 2.
- Trade-off: More slices = less memory but slower computation. If more than 64 slices would be needed, the system raises an error with a calculated maximum resolution.
- MPS-specific: Uses a threshold of 2^16 (65536) elements to decide between full attention and sliced attention, and avoids slice sizes that are exact multiples of 4096 to prevent GPU memory allocation issues.
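The slice-count rule in the bullets above can be sketched as a standalone function (the name `attention_steps` is mine; the surrounding error string is paraphrased):

```python
import math

def attention_steps(mem_required: float, mem_free_total: float) -> int:
    """Illustrative sketch: smallest power-of-2 slice count whose per-chunk
    memory estimate fits within the free budget; error past 64 slices."""
    steps = 1
    if mem_required > mem_free_total:
        steps = 2 ** math.ceil(math.log2(mem_required / mem_free_total))
    if steps > 64:
        raise RuntimeError('Not enough memory, use lower resolution')
    return steps

gb = 1024 ** 3
print(attention_steps(3 * gb, 4 * gb))   # 1  (fits in one pass)
print(attention_steps(10 * gb, 4 * gb))  # 4  (ratio 2.5 rounds up to the next power of 2)
```

Rounding up to a power of 2 keeps every chunk the same size, at the cost of occasionally slicing more finely than strictly necessary.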
Reasoning
The attention matrix for a 512x512 image at the highest-resolution UNet layer has q.shape[1] = 4096 (64x64 spatial). For a 1024x1024 image, this becomes 16384, and the attention matrix grows to 16384^2 * element_size bytes. At fp16, this is ~512MB per head per batch element. The multiplier accounts for: (1) the Q*K^T intermediate, (2) the softmax result, (3) the S*V result, plus temporary allocations. The 3.3x safety factor was empirically determined to prevent fragmentation-induced OOM on a wide range of GPU configurations.
The MPS-specific avoidance of 4096-multiple slice sizes is a workaround for an Apple Metal memory allocator behavior that causes performance degradation or errors at these alignment boundaries.
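A worked instance of the estimator's arithmetic, with assumed shapes; the batch size and head count are illustrative, only the 4096 sequence length and the 3x fp16 modifier come from the text above:

```python
# Assumed: 512x512 image, batch 1, 8 attention heads, fp16 (element_size == 2).
# After head folding, q.shape[0] == 8 and q.shape[1] == k.shape[1] == 4096.
tensor_size = 8 * 4096 * 4096 * 2  # 256 MiB for the raw score matrix
mem_required = tensor_size * 3     # fp16 modifier: Q*K^T, softmax, S*V + temporaries
print(mem_required / 2**30)        # 0.75 GiB estimated peak for this layer
```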
Code Evidence
Dynamic memory calculation from `modules/sd_hijack_optimizations.py:245-261`:
```python
mem_free_total = get_available_vram()

gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier

steps = 1
if mem_required > mem_free_total:
    steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))

if steps > 64:
    max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
    raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                       f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
```
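To see what the fallback message reports, here is a worked example of the `max_res` formula; the 2 GiB free-memory figure is assumed for illustration:

```python
import math

mem_free_total = 2 * 1024**3  # assumed: 2 GiB usable VRAM
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
print(max_res)  # 1344 -> "use lower resolution (max approx. 1344x1344)"
```

The nested square roots invert the fourth-power relationship between edge resolution and attention-matrix memory, and the `/ 8 ... * 64` step snaps the result to the 64-pixel grid the UNet requires.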
CUDA safety factor from `modules/sd_hijack_optimizations.py:337-345`:
```python
def einsum_op_cuda(q, k, v):
    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch
    # Divide factor of safety as there's copying and fragmentation
    return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
```
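Plugging in an assumed 8 GiB of reported free memory shows the budget actually handed to `einsum_op_tensor_mem` (the argument is in MB, hence the `1 << 20` divisor):

```python
mem_free_total = 8 * 1024**3                  # assumed: 8 GiB reported free
budget_mb = mem_free_total / 3.3 / (1 << 20)  # apply safety factor, then bytes -> MB
print(round(budget_mb))                       # 2482 MB usable for attention chunks
```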
MPS 4096-alignment workaround from `modules/sd_hijack_optimizations.py:310-317`:
```python
def einsum_op_mps_v1(q, k, v):
    if q.shape[0] * q.shape[1] <= 2**16:  # (512x512) max q.shape[1]: 4096
        return einsum_op_compvis(q, k, v)
    else:
        slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
        if slice_size % 4096 == 0:
            slice_size -= 1
        return einsum_op_slice_1(q, k, v, slice_size)
```
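`einsum_op_slice_1` itself is not shown above; the following is a hypothetical NumPy reconstruction of such a dim-1 slicing loop, assuming q has already been scaled upstream as in the WebUI's einsum path:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def einsum_op_slice_1_sketch(q, k, v, slice_size):
    """Compute attention slice_size query rows at a time, so only a
    [batch*heads, slice_size, k_len] score matrix is live at once."""
    r = np.zeros((q.shape[0], q.shape[1], v.shape[2]), dtype=q.dtype)
    for i in range(0, q.shape[1], slice_size):
        s = softmax(np.einsum('bid,bjd->bij', q[:, i:i + slice_size], k))
        r[:, i:i + slice_size] = np.einsum('bij,bjd->bid', s, v)
    return r
```

Because softmax is taken over the key axis, each query-row slice is independent and the sliced result matches the full computation exactly; each iteration only materializes a score matrix slice_size/seq_len the size of the full one, which is the memory-for-compute trade described under The Insight.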