Implementation:AUTOMATIC1111 Stable diffusion webui Sub Quadratic Attention

From Leeroopedia


Knowledge Sources
Domains: Attention Mechanism, Memory Optimization
Last Updated: 2025-05-15 00:00 GMT

Overview

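Implements a memory-efficient, sub-quadratic attention algorithm based on the paper "Self-attention Does Not Need O(n^2) Memory". By processing queries and key/value pairs in chunks, it reduces the memory required for attention from O(n^2), the size of the full attention matrix, to O(sqrt(n)) in the sequence length n.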
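To make the scaling concrete, the back-of-the-envelope arithmetic below compares the fp32 attention-score buffer that standard attention would materialize with the score buffer of a single chunk. It assumes the batch and token counts used in Usage Examples, the default query_chunk_size of 1024, and the default kv_chunk_size of sqrt(key_tokens). It counts only score buffers (not inputs, per-chunk partial outputs, or framework overhead), so it illustrates the scaling rather than measuring the module's real peak memory; score_buffer_bytes is a hypothetical helper.

```python
import math

def score_buffer_bytes(batch_heads: int, rows: int, cols: int, bytes_per_elem: int = 4) -> int:
    """Bytes for one [batch_heads, rows, cols] attention-score tensor (fp32 by default)."""
    return batch_heads * rows * cols * bytes_per_elem

batch_heads, tokens = 8, 4096
query_chunk_size = 1024                  # default query chunk size
kv_chunk_size = int(math.sqrt(tokens))   # default key/value chunk size: sqrt(4096) = 64

full = score_buffer_bytes(batch_heads, tokens, tokens)                    # whole attention matrix
chunk = score_buffer_bytes(batch_heads, query_chunk_size, kv_chunk_size)  # one chunk's scores

print(f"full attention matrix: {full / 2**20:.0f} MiB")  # full attention matrix: 512 MiB
print(f"one chunk's scores: {chunk / 2**20:.0f} MiB")    # one chunk's scores: 2 MiB
```

Quadrupling the token count to 16384 multiplies the full matrix by 16, but the default chunk only by 2, because sqrt(16384) = 128.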

Description

This module provides an alternative to standard scaled dot-product attention that avoids materializing the full attention matrix in memory. Based on the algorithm from https://arxiv.org/abs/2112.05682v2, it chunks the query, key, and value tensors into smaller pieces and processes them iteratively to compute attention with significantly reduced peak memory consumption.
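Chunking the keys and values does not change the result, because a softmax can be assembled exactly from per-chunk statistics. The derivation below uses our own notation, not the module's: for query i, key/value chunk C_j, key index t, and d = channels_per_head (assuming the usual 1/sqrt(d) scaling), m is the per-chunk maximum score, w the summed exponential weights, and u the summed weighted values. These appear to line up with the max_score, exp_weights_sum, and exp_values fields of AttnChunk listed under Signature.

```latex
s_{it} = \frac{q_i \cdot k_t}{\sqrt{d}}, \qquad
m_{ij} = \max_{t \in C_j} s_{it}, \qquad
w_{ij} = \sum_{t \in C_j} e^{\,s_{it} - m_{ij}}, \qquad
u_{ij} = \sum_{t \in C_j} e^{\,s_{it} - m_{ij}}\, v_t

M_i = \max_j m_{ij}, \qquad
\mathrm{out}_i
  = \frac{\sum_j e^{\,m_{ij} - M_i}\, u_{ij}}{\sum_j e^{\,m_{ij} - M_i}\, w_{ij}}
  = \frac{\sum_t e^{\,s_{it} - M_i}\, v_t}{\sum_t e^{\,s_{it} - M_i}}
  = \sum_t \operatorname{softmax}_t(s_{i\cdot})\, v_t
```

Every exponent above is at most zero, so no term can overflow; that is why the maxima are tracked instead of exponentiating raw scores.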

The core function efficient_dot_product_attention() accepts query, key, and value tensors shaped [batch * num_heads, tokens, channels_per_head], plus optional chunk-size parameters. It splits the queries into chunks of query_chunk_size tokens and the keys and values into chunks of kv_chunk_size tokens, and computes a partial attention result for each key/value chunk with _summarize_chunk(). _query_chunk_attention() combines these partial results with a numerically stable log-sum-exp scheme: each chunk records its maximum score, and the exponential weights and weighted values of every chunk are rescaled relative to the overall maximum before they are summed and normalized.

When kv_chunk_size is at least the number of key tokens, so that a single key/value chunk covers all keys, the module skips this merge and uses the simpler _get_attention_scores_no_kv_chunking() path. The narrow_trunc() helper shortens the final chunk so that slices never run past a tensor boundary. Gradient checkpointing can optionally be applied to each chunk summarization for further memory savings during training.
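The log-sum-exp merge described above can be sketched in NumPy for a single query chunk. This is an illustration that follows the description (per-chunk maximum, rescaling to the overall maximum, then normalization), not the module's code: Chunk is a hypothetical stand-in that reuses AttnChunk's field names, and query chunking, checkpointing, and dtype handling are left out.

```python
import numpy as np
from typing import List, NamedTuple

class Chunk(NamedTuple):
    """Per-chunk partial results; field names copied from AttnChunk."""
    exp_values: np.ndarray       # [b, q, c]: sum over the chunk of exp(score - max) * value
    exp_weights_sum: np.ndarray  # [b, q]:    sum over the chunk of exp(score - max)
    max_score: np.ndarray        # [b, q]:    max score in the chunk, per query

def summarize_chunk(q, k_chunk, v_chunk, scale) -> Chunk:
    scores = np.einsum("bqc,bkc->bqk", q, k_chunk) * scale
    max_score = scores.max(axis=-1, keepdims=True)
    exp_weights = np.exp(scores - max_score)  # every entry <= 1
    exp_values = np.einsum("bqk,bkc->bqc", exp_weights, v_chunk)
    return Chunk(exp_values, exp_weights.sum(axis=-1), max_score[..., 0])

def kv_chunked_attention(q, k, v, kv_chunk_size):
    scale = q.shape[-1] ** -0.5
    chunks: List[Chunk] = [
        summarize_chunk(q, k[:, s:s + kv_chunk_size], v[:, s:s + kv_chunk_size], scale)
        for s in range(0, k.shape[1], kv_chunk_size)
    ]
    values = np.stack([c.exp_values for c in chunks])        # [n_chunks, b, q, c]
    weights = np.stack([c.exp_weights_sum for c in chunks])  # [n_chunks, b, q]
    maxes = np.stack([c.max_score for c in chunks])          # [n_chunks, b, q]
    # Rescale every chunk to the overall max so all chunks share one reference point.
    rescale = np.exp(maxes - maxes.max(axis=0, keepdims=True))  # every entry <= 1
    values = (values * rescale[..., None]).sum(axis=0)
    weights = (weights * rescale).sum(axis=0)
    return values / weights[..., None]

# Compare against plain softmax attention on small random inputs.
rng = np.random.default_rng(0)
q = rng.standard_normal((2, 3, 4))
k = rng.standard_normal((2, 7, 4))
v = rng.standard_normal((2, 7, 4))
scores = np.einsum("bqc,bkc->bqk", q, k) * 4 ** -0.5
reference = np.einsum("bqk,bkc->bqc", np.exp(scores) / np.exp(scores).sum(-1, keepdims=True), v)
assert np.allclose(kv_chunked_attention(q, k, v, kv_chunk_size=3), reference)
```

With 7 keys and kv_chunk_size=3 the chunks hold 3, 3, and 1 keys, so the uneven final chunk is exercised as well.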

Usage

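Use this attention implementation in place of standard attention when GPU memory is limited, particularly for high-resolution image generation, where the large number of tokens makes the full attention matrix prohibitively large. Because it expects tensors whose heads are already folded into the batch dimension ([batch * num_heads, tokens, channels_per_head]), callers must reshape their inputs to that layout before calling it.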
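Because the module expects heads folded into the batch dimension, a caller holding activations in the common [batch, tokens, num_heads * head_dim] layout has to rearrange them first and undo the rearrangement afterwards. The NumPy sketch below assumes that layout; to_batch_heads and from_batch_heads are hypothetical helper names, not part of the module.

```python
import numpy as np

def to_batch_heads(x: np.ndarray, num_heads: int) -> np.ndarray:
    """[batch, tokens, num_heads * dim] -> [batch * num_heads, tokens, dim]."""
    b, t, hd = x.shape
    d = hd // num_heads
    # split channels into heads, move heads next to batch, then merge batch and heads
    return x.reshape(b, t, num_heads, d).transpose(0, 2, 1, 3).reshape(b * num_heads, t, d)

def from_batch_heads(x: np.ndarray, num_heads: int) -> np.ndarray:
    """Inverse of to_batch_heads: [batch * num_heads, tokens, dim] -> [batch, tokens, num_heads * dim]."""
    bh, t, d = x.shape
    b = bh // num_heads
    return x.reshape(b, num_heads, t, d).transpose(0, 2, 1, 3).reshape(b, t, num_heads * d)

x = np.arange(2 * 5 * 12, dtype=np.float32).reshape(2, 5, 12)
heads = to_batch_heads(x, num_heads=3)
print(heads.shape)  # (6, 5, 4)
assert np.array_equal(from_batch_heads(heads, num_heads=3), x)
```

In PyTorch the same rearrangement can be written with reshape and transpose(1, 2).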

Code Reference

Source Location

modules/sub_quadratic_attention.py

Signature

def narrow_trunc(input: Tensor, dim: int, start: int, length: int) -> Tensor

class AttnChunk(NamedTuple):
    exp_values: Tensor
    exp_weights_sum: Tensor
    max_score: Tensor

def efficient_dot_product_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    query_chunk_size: int = 1024,
    kv_chunk_size: Optional[int] = None,
    kv_chunk_size_min: Optional[int] = None,
    use_checkpoint: bool = True,
) -> Tensor
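narrow_trunc takes the same arguments as torch.narrow(input, dim, start, length), which raises an error when start + length exceeds the size of dim. The helper therefore clamps length for the last, shorter chunk. NumPy slicing already clamps silently, so the sketch below makes the clamp explicit; narrow_trunc_np is a hypothetical name used only for illustration.

```python
import numpy as np

def narrow_trunc_np(x: np.ndarray, dim: int, start: int, length: int) -> np.ndarray:
    """Take `length` items from `start` along `dim`, shortening the slice
    if it would run past the end of that axis."""
    size = x.shape[dim]
    length = length if size >= start + length else size - start  # clamp the final chunk
    index = [slice(None)] * x.ndim
    index[dim] = slice(start, start + length)
    return x[tuple(index)]

keys = np.zeros((2, 10, 4))  # [batch * num_heads, key_tokens, channels_per_head]
lengths = [narrow_trunc_np(keys, 1, start, 4).shape[1] for start in range(0, 10, 4)]
print(lengths)  # [4, 4, 2]
```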

Import

from modules.sub_quadratic_attention import efficient_dot_product_attention

I/O Contract

Inputs

Name | Type | Required | Description
--- | --- | --- | ---
query | Tensor | Yes | Query tensor of shape [batch * num_heads, query_tokens, channels_per_head].
key | Tensor | Yes | Key tensor of shape [batch * num_heads, key_tokens, channels_per_head].
value | Tensor | Yes | Value tensor of shape [batch * num_heads, key_tokens, channels_per_head].
query_chunk_size | int | No | Number of query tokens per chunk; defaults to 1024.
kv_chunk_size | int or None | No | Number of key/value tokens per chunk; if None, defaults to sqrt(key_tokens).
kv_chunk_size_min | int or None | No | Lower bound on the key/value chunk size, which keeps the sqrt(key_tokens) default from producing many tiny chunks.
use_checkpoint | bool | No | Whether to apply gradient checkpointing to each chunk summarization; defaults to True. Mainly useful for training; inference can disable it.
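The chunk-size defaults in the table can be summarized by a small helper. This sketch rests on two assumptions the table does not spell out: the sqrt default is rounded down to an integer, and a chunk is never larger than the number of key tokens. resolve_kv_chunk_size is a hypothetical name, not a function in the module.

```python
import math
from typing import Optional

def resolve_kv_chunk_size(key_tokens: int,
                          kv_chunk_size: Optional[int] = None,
                          kv_chunk_size_min: Optional[int] = None) -> int:
    """Effective key/value chunk size under the documented defaults."""
    size = kv_chunk_size or int(math.sqrt(key_tokens))  # None -> floor(sqrt(key_tokens))
    size = min(size, key_tokens)                        # assumption: never larger than the keys
    if kv_chunk_size_min is not None:
        size = max(size, kv_chunk_size_min)             # enforce the lower bound
    return size

print(resolve_kv_chunk_size(4096))                         # 64
print(resolve_kv_chunk_size(4096, kv_chunk_size_min=512))  # 512
```

With 4096 key tokens, raising the minimum to 512 means 8 chunks are processed instead of 64, trading some memory for fewer, larger matrix multiplications.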

Outputs

Name | Type | Description
--- | --- | ---
result | Tensor | Output tensor of shape [batch * num_heads, query_tokens, channels_per_head].
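The output has one row per query token, and each query's attention is independent of the others, so results computed for separate query chunks can be written into a preallocated output at their offsets. The NumPy sketch below shows that assembly step, using plain softmax attention for each query chunk (the case the Description assigns to the no-key/value-chunking path). The function names are hypothetical, and the sketch sizes the output's channel axis from the value tensor.

```python
import numpy as np

def reference_attention(q, k, v):
    """Plain softmax attention for [batch_heads, tokens, channels] arrays."""
    scale = q.shape[-1] ** -0.5
    scores = np.einsum("bqc,bkc->bqk", q, k) * scale
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))  # stable softmax
    weights /= weights.sum(axis=-1, keepdims=True)
    return np.einsum("bqk,bkc->bqc", weights, v)

def attention_by_query_chunks(q, k, v, query_chunk_size):
    batch_heads, q_tokens, _ = q.shape
    out = np.zeros((batch_heads, q_tokens, v.shape[-1]), dtype=np.result_type(q, v))
    for start in range(0, q_tokens, query_chunk_size):
        chunk_out = reference_attention(q[:, start:start + query_chunk_size], k, v)
        out[:, start:start + chunk_out.shape[1]] = chunk_out  # last chunk may be shorter
    return out

rng = np.random.default_rng(0)
q = rng.standard_normal((2, 10, 4))
k = rng.standard_normal((2, 6, 4))
v = rng.standard_normal((2, 6, 4))
out = attention_by_query_chunks(q, k, v, query_chunk_size=4)
print(out.shape)  # (2, 10, 4)
assert np.allclose(out, reference_attention(q, k, v))
```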

Usage Examples

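import torch
from modules.sub_quadratic_attention import efficient_dot_product_attention  # run from the webui repository root

# Typical attention computation with reduced memory
batch_heads = 8   # batch * num_heads, already folded into one dimension
tokens = 4096     # e.g. a 64x64 latent
channels = 64     # channels per head

device = "cuda" if torch.cuda.is_available() else "cpu"
query = torch.randn(batch_heads, tokens, channels, device=device)
key = torch.randn(batch_heads, tokens, channels, device=device)
value = torch.randn(batch_heads, tokens, channels, device=device)

with torch.no_grad():  # inference only, so gradient checkpointing is disabled below
    output = efficient_dot_product_attention(
        query, key, value,
        query_chunk_size=1024,
        kv_chunk_size=512,
        use_checkpoint=False,
    )

print(output.shape)  # torch.Size([8, 4096, 64])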
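For the call above, the chunk counts follow from the Description: 4096 query tokens in chunks of 1024 and 4096 key tokens in chunks of 512 give 4 query chunks, each merged over 8 key/value chunks. The helper below, chunk_plan, is hypothetical; it only reproduces that bookkeeping and does not call the module.

```python
import math

def chunk_plan(q_tokens: int, k_tokens: int, query_chunk_size: int, kv_chunk_size: int) -> dict:
    """How many chunks the described algorithm would process for the given sizes."""
    n_query_chunks = math.ceil(q_tokens / query_chunk_size)
    if k_tokens <= kv_chunk_size:
        # one key/value chunk covers all keys: plain attention per query chunk, no merge
        return {"query_chunks": n_query_chunks, "kv_chunks": 1, "merge": False}
    n_kv_chunks = math.ceil(k_tokens / kv_chunk_size)
    return {"query_chunks": n_query_chunks, "kv_chunks": n_kv_chunks, "merge": True}

plan = chunk_plan(q_tokens=4096, k_tokens=4096, query_chunk_size=1024, kv_chunk_size=512)
print(plan)  # {'query_chunks': 4, 'kv_chunks': 8, 'merge': True}
```

Each step of that call materializes a score block of 8 x 1024 x 512 fp32 values (16 MiB) rather than the 512 MiB needed for the full 4096 x 4096 matrix across all 8 batch-heads.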
