Implementation:AUTOMATIC1111 Stable diffusion webui Sub Quadratic Attention
| Knowledge Sources | |
|---|---|
| Domains | Attention Mechanism, Memory Optimization |
| Last Updated | 2025-05-15 00:00 GMT |
Overview
Implements a memory-efficient, sub-quadratic attention algorithm based on the paper "Self-attention Does Not Need O(n^2) Memory". By chunking the queries and the key/value pairs, it reduces the memory needed for the attention computation from O(n^2) to O(sqrt(n)) in the number of tokens n.
Description
This module provides an alternative to standard scaled dot-product attention that avoids materializing the full attention matrix in memory. Based on the algorithm from https://arxiv.org/abs/2112.05682v2, it chunks the query, key, and value tensors into smaller pieces and processes them iteratively to compute attention with significantly reduced peak memory consumption.
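The following is a minimal sketch in plain PyTorch, independent of the webui module, of why chunking the queries alone is already exact: each query row's softmax depends only on that row's scores, so query slices can be processed one at a time. The helper names full_attention and query_chunked_attention are hypothetical. This covers only the query dimension; each slice still builds a scores block spanning all key tokens, which is what the key/value chunking addresses.

```python
import torch


def full_attention(q, k, v):
    # Materializes the full [batch_heads, q_tokens, k_tokens] score matrix.
    scale = q.shape[-1] ** -0.5
    scores = torch.bmm(q, k.transpose(1, 2)) * scale
    return torch.bmm(scores.softmax(dim=-1), v)


def query_chunked_attention(q, k, v, query_chunk_size):
    # Only a [batch_heads, query_chunk_size, k_tokens] block of scores exists at a time.
    outputs = []
    for start in range(0, q.shape[1], query_chunk_size):
        q_chunk = q[:, start:start + query_chunk_size]  # slicing truncates the last chunk
        outputs.append(full_attention(q_chunk, k, v))
    return torch.cat(outputs, dim=1)


torch.manual_seed(0)
q = torch.randn(2, 300, 16)
k = torch.randn(2, 200, 16)
v = torch.randn(2, 200, 16)
out = query_chunked_attention(q, k, v, query_chunk_size=128)
print(out.shape)                                          # torch.Size([2, 300, 16])
print(torch.allclose(out, full_attention(q, k, v), atol=1e-5))  # True
```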
The core function efficient_dot_product_attention() accepts query, key, and value tensors shaped [batch * num_heads, tokens, channels_per_head], plus optional chunk-size parameters. It splits the work into query chunks and, within each query chunk, key/value chunks; _summarize_chunk() computes a partial attention result for each key/value chunk.
_query_chunk_attention() combines these partial results with the numerically stable log-sum-exp trick: each chunk records its maximum score, its exponentiated weights are computed relative to that maximum, and before the chunks are summed each one is rescaled to the global maximum across all chunks.
When the key/value chunk size covers all key tokens, so that only one key/value chunk exists, the module skips this combination step and falls back to the simpler _get_attention_scores_no_kv_chunking() path. The narrow_trunc() helper truncates each slice so that the last chunk does not run past the end of the tensor. Gradient checkpointing can optionally be applied to each chunk summarization, trading recomputation in the backward pass for lower memory use during training.
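A minimal sketch of the key/value chunking and the log-sum-exp combination described above, written in plain PyTorch. It is not the module's code: Chunk is a hypothetical stand-in for AttnChunk, summarize_chunk and kv_chunked_attention are hypothetical names, ordinary slicing replaces narrow_trunc(), and the optional checkpointing uses torch.utils.checkpoint.checkpoint with use_reentrant=False as an assumed choice. The result is checked against an unchunked softmax attention.

```python
from typing import NamedTuple

import torch
from torch.utils.checkpoint import checkpoint


class Chunk(NamedTuple):
    # Hypothetical stand-in for AttnChunk (m is the per-row max score in this chunk).
    exp_values: torch.Tensor       # [bh, q, c_v]: sum_j exp(s_j - m) * v_j
    exp_weights_sum: torch.Tensor  # [bh, q]:      sum_j exp(s_j - m)
    max_score: torch.Tensor        # [bh, q]:      m


def summarize_chunk(q, k, v, scale):
    scores = torch.bmm(q, k.transpose(1, 2)) * scale
    # Shifting by the row max keeps exp() from overflowing; the shift cancels
    # in the final ratio, so it is detached from the autograd graph.
    max_score = scores.amax(dim=-1, keepdim=True).detach()
    exp_weights = torch.exp(scores - max_score)
    return Chunk(torch.bmm(exp_weights, v), exp_weights.sum(dim=-1), max_score.squeeze(-1))


def kv_chunked_attention(q, k, v, kv_chunk_size, use_checkpoint=False):
    scale = q.shape[-1] ** -0.5
    chunks = []
    for start in range(0, k.shape[1], kv_chunk_size):
        k_chunk = k[:, start:start + kv_chunk_size]  # slicing truncates the last chunk
        v_chunk = v[:, start:start + kv_chunk_size]
        if use_checkpoint:
            # Recompute this chunk during backward instead of storing its activations.
            chunks.append(checkpoint(summarize_chunk, q, k_chunk, v_chunk, scale, use_reentrant=False))
        else:
            chunks.append(summarize_chunk(q, k_chunk, v_chunk, scale))

    values = torch.stack([c.exp_values for c in chunks])        # [n_chunks, bh, q, c_v]
    weights = torch.stack([c.exp_weights_sum for c in chunks])  # [n_chunks, bh, q]
    maxes = torch.stack([c.max_score for c in chunks])          # [n_chunks, bh, q]

    # Rescale every chunk from its own max m_i to the global max M; exp(m_i - M) <= 1.
    global_max = maxes.amax(dim=0, keepdim=True)
    rescale = torch.exp(maxes - global_max)
    values = values * rescale.unsqueeze(-1)
    weights = weights * rescale
    return values.sum(dim=0) / weights.sum(dim=0).unsqueeze(-1)


torch.manual_seed(0)
q = torch.randn(2, 50, 16)
k = torch.randn(2, 200, 16)
v = torch.randn(2, 200, 8)
reference = torch.bmm((torch.bmm(q, k.transpose(1, 2)) * 16 ** -0.5).softmax(dim=-1), v)
out = kv_chunked_attention(q, k, v, kv_chunk_size=64)  # chunks of 64, 64, 64 and 8 keys
print(out.shape)                                        # torch.Size([2, 50, 8])
print(torch.allclose(out, reference, atol=1e-5))        # True
```

Because the rescaling factors are at most 1, no intermediate exponential can overflow, and the final division makes the choice of shift irrelevant to the result.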
Usage
Use this implementation in place of standard scaled dot-product attention when GPU memory is limited, particularly for high-resolution image generation, where the number of tokens makes the full tokens-by-tokens attention matrix prohibitively large.
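A rough back-of-the-envelope comparison of score-matrix sizes. The numbers rest on assumptions that are not stated in this page: a 1024x1024 image with an 8x latent downsampling (128 * 128 = 16384 tokens), batch 1 with 8 heads, and 2-byte (fp16) elements. The helper score_bytes is hypothetical, and the figure only counts the single score block that exists at one step; it ignores the per-chunk partial outputs that are also kept.

```python
import math


def score_bytes(q_rows, k_cols, batch_heads, bytes_per_elem=2):
    # Size of one attention-score block of shape [batch_heads, q_rows, k_cols].
    return batch_heads * q_rows * k_cols * bytes_per_elem


tokens = (1024 // 8) ** 2          # assumed: 128 * 128 latent positions
batch_heads = 8                    # assumed: batch 1 x 8 heads
full = score_bytes(tokens, tokens, batch_heads)

kv_chunk = int(math.sqrt(tokens))  # 128, the documented sqrt(key_tokens) default
chunk = score_bytes(1024, kv_chunk, batch_heads)  # query_chunk_size = 1024

print(full / 2**30, "GiB")   # 4.0 GiB
print(chunk / 2**20, "MiB")  # 2.0 MiB
```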
Code Reference
Source Location
- Repository: AUTOMATIC1111_Stable_diffusion_webui
- File: modules/sub_quadratic_attention.py
- Lines: 1-215
Signature
```python
from typing import NamedTuple, Optional

from torch import Tensor


def narrow_trunc(input: Tensor, dim: int, start: int, length: int) -> Tensor: ...


class AttnChunk(NamedTuple):
    exp_values: Tensor
    exp_weights_sum: Tensor
    max_score: Tensor


def efficient_dot_product_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    query_chunk_size: int = 1024,
    kv_chunk_size: Optional[int] = None,
    kv_chunk_size_min: Optional[int] = None,
    use_checkpoint: bool = True,
) -> Tensor: ...
```
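A minimal sketch of the truncating slice that narrow_trunc() is described as performing, inferred from its name and signature rather than copied from the source; narrow_trunc_sketch is a hypothetical name. torch.narrow raises an error when start + length runs past the end of the dimension, so the length is clamped to what remains.

```python
import torch
from torch import Tensor


def narrow_trunc_sketch(input: Tensor, dim: int, start: int, length: int) -> Tensor:
    # Clamp the length so the final chunk stops at the tensor boundary.
    return torch.narrow(input, dim, start, min(length, input.shape[dim] - start))


x = torch.arange(10).reshape(1, 10, 1)
chunks = [narrow_trunc_sketch(x, 1, s, 4) for s in range(0, 10, 4)]
print([c.shape[1] for c in chunks])  # [4, 4, 2]
```

torch.narrow returns a view, so slicing this way does not copy the underlying data.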
Import
```python
from modules.sub_quadratic_attention import efficient_dot_product_attention
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| query | Tensor | Yes | Query tensor of shape [batch * num_heads, query_tokens, channels_per_head]. |
| key | Tensor | Yes | Key tensor of shape [batch * num_heads, key_tokens, channels_per_head]. |
| value | Tensor | Yes | Value tensor of shape [batch * num_heads, key_tokens, channels_per_head]; its token count must match that of key. |
| query_chunk_size | int | No | Size of query chunks; defaults to 1024. |
| kv_chunk_size | int or None | No | Size of key/value chunks; defaults to sqrt(key_tokens) if None. |
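| kv_chunk_size_min | int or None | No | Lower bound on the key/value chunk size, which keeps chunks from becoming too small (more, smaller chunks leave less work to run in parallel per step). |
| use_checkpoint | bool | No | Whether to apply gradient checkpointing to each chunk summarization; defaults to True. It only saves memory when gradients are computed, so it can be disabled for inference. |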
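A hypothetical helper, resolve_chunking, sketching how the chunking parameters above could combine. It assumes the default kv_chunk_size is the integer square root of key_tokens capped at key_tokens, that kv_chunk_size_min is applied as a lower bound afterwards, and that key/value chunking is skipped when one chunk covers every key. Check the source for the exact rounding and ordering.

```python
import math
from typing import Optional, Tuple


def resolve_chunking(
    q_tokens: int,
    k_tokens: int,
    query_chunk_size: int = 1024,
    kv_chunk_size: Optional[int] = None,
    kv_chunk_size_min: Optional[int] = None,
) -> Tuple[int, int, bool]:
    """Hypothetical: (number of query chunks, kv chunk size, whether keys are chunked)."""
    # Assumed default: integer sqrt of the key count, capped at the key count.
    size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
    if kv_chunk_size_min is not None:
        size = max(size, kv_chunk_size_min)  # assumed lower bound from kv_chunk_size_min
    n_query_chunks = math.ceil(q_tokens / query_chunk_size)
    # A single kv chunk covering every key needs no log-sum-exp combination.
    return n_query_chunks, size, k_tokens > size


print(resolve_chunking(4096, 4096))                       # (4, 64, True)
print(resolve_chunking(4096, 77))                         # (4, 8, True)
print(resolve_chunking(4096, 77, kv_chunk_size_min=512))  # (4, 512, False)
```

In this sketch, a minimum chunk size larger than a short key sequence (77 tokens here) turns off key/value chunking for that call entirely.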
Outputs
| Name | Type | Description |
|---|---|---|
| result | Tensor | Output tensor of shape [batch * num_heads, query_tokens, channels_per_head]. |
Usage Examples
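```python
import torch
from modules.sub_quadratic_attention import efficient_dot_product_attention

# Typical attention computation with reduced memory.
# Requires a CUDA device and the webui repository root on the Python path.
batch_heads = 8  # batch * num_heads
tokens = 4096
channels = 64    # channels_per_head
query = torch.randn(batch_heads, tokens, channels, device="cuda")
key = torch.randn(batch_heads, tokens, channels, device="cuda")
value = torch.randn(batch_heads, tokens, channels, device="cuda")
output = efficient_dot_product_attention(
    query, key, value,
    query_chunk_size=1024,
    kv_chunk_size=512,
    use_checkpoint=False,  # no backward pass here, so checkpointing would not save memory
)
# output has shape [batch_heads, tokens, channels], i.e. (8, 4096, 64)
```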