Principle:Mit han lab Llm awq Fused Attention Optimization
Overview
A kernel fusion technique that combines the Q, K, and V projections into a single quantized GEMM and uses specialized attention kernels for the prefilling and decoding phases.
Description
Standard transformer attention involves separate Q, K, V linear projections followed by scaled dot-product attention. Fused attention optimization applies three key techniques:
- QKV Fusion: Merges three separate WQLinear layers (q_proj, k_proj, v_proj) into a single fused WQLinear that performs one GEMM instead of three
- FlashAttention for Prefilling: Uses FlashAttention for processing the full context during the prefilling phase
- FasterTransformer Decoding: Uses FasterTransformer-style CUDA kernels for single-token decoding with a pre-allocated KV cache
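Together, these reduce kernel launch overhead (one projection launch instead of three) and memory traffic (the input activations are read once for all three projections, and the prefill attention does not write the full attention score matrix to GPU memory).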
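The split between a prefilling phase and a single-token decoding phase over a pre-allocated KV cache can be sketched in plain Python. This is only an illustrative model of the data flow, not the CUDA kernels used by TinyChat; the names `KVCache`, `prefill`, and `decode_step` are hypothetical, and the toy vectors stand in for real projected activations.

```python
import math

def attend(q, keys, values):
    """Scaled dot-product attention of one query over the given keys/values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    dim = len(values[0])
    return [sum(e * v[j] for e, v in zip(exps, values)) / total for j in range(dim)]

class KVCache:
    """Hypothetical fixed-size cache, allocated once; `length` counts filled slots."""
    def __init__(self, max_len, dim):
        self.k = [[0.0] * dim for _ in range(max_len)]
        self.v = [[0.0] * dim for _ in range(max_len)]
        self.length = 0

    def append(self, k, v):
        # Write into the next free slot; no reallocation happens here.
        self.k[self.length] = list(k)
        self.v[self.length] = list(v)
        self.length += 1

def prefill(cache, qs, ks, vs):
    """Process the whole prompt at once (assumes an empty cache), causally masked."""
    for k, v in zip(ks, vs):
        cache.append(k, v)
    # Query i may only attend to positions 0..i.
    return [attend(q, cache.k[: i + 1], cache.v[: i + 1]) for i, q in enumerate(qs)]

def decode_step(cache, q, k, v):
    """Process one new token: store its K/V, then attend over the filled slots."""
    cache.append(k, v)
    return attend(q, cache.k[: cache.length], cache.v[: cache.length])

# Toy vectors used as q, k and v alike.
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

cache_a = KVCache(max_len=8, dim=2)
full = prefill(cache_a, tokens, tokens, tokens)

cache_b = KVCache(max_len=8, dim=2)
step = [decode_step(cache_b, t, t, t) for t in tokens]

# Both paths should produce the same per-token outputs.
consistent = all(
    math.isclose(a, b) for ra, rb in zip(full, step) for a, b in zip(ra, rb)
)
```

In this sketch the two paths differ only in how much work is issued per call: prefilling handles every prompt position in one pass, while each decode step touches a single new token and reuses the cached entries.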
Usage
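Apply this optimization to a TinyChat model before running inference to maximize throughput. Because the fusion merges existing WQLinear layers, the model must already contain the quantized q_proj, k_proj, and v_proj projections.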
Theoretical Basis
QKV fusion:
[Q;K;V] = W_qkv @ x
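Here W_qkv = [W_q; W_k; W_v] stacks the three projection weight matrices along the output dimension, so one GEMM replaces three separate projections and the result is split back into Q, K, and V.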
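A minimal sketch of why the fused projection is equivalent to the three separate ones, assuming ordinary floating-point weights held in Python lists. It does not model the packed quantized tensors of WQLinear; the sizes and helper names (`matvec`, `rand_matrix`) are hypothetical.

```python
import random

def matvec(W, x):
    """Multiply a row-major matrix W (list of rows) by a vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def rand_matrix(rows, cols):
    return [[random.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]

random.seed(0)
d_model, d_head = 8, 4  # hypothetical sizes, chosen only for illustration

W_q = rand_matrix(d_head, d_model)
W_k = rand_matrix(d_head, d_model)
W_v = rand_matrix(d_head, d_model)
x = [random.uniform(-1, 1) for _ in range(d_model)]

# Unfused: three separate matrix products, each reading x again.
q, k, v = matvec(W_q, x), matvec(W_k, x), matvec(W_v, x)

# Fused: stack the weights along the output dimension, multiply once, split.
W_qkv = W_q + W_k + W_v  # row-wise concatenation, shape (3 * d_head, d_model)
qkv = matvec(W_qkv, x)
q_f = qkv[:d_head]
k_f = qkv[d_head:2 * d_head]
v_f = qkv[2 * d_head:]

# Each output row is computed by the same operations in both paths.
fused_matches = (q_f == q and k_f == k and v_f == v)
```

The arithmetic is unchanged by the fusion; the benefit on a GPU comes from issuing one larger kernel and reading the input once.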
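Attention kernels:
- FlashAttention computes exact attention during prefilling with memory that grows as O(N) in the sequence length N, because the N x N score matrix is never materialized
- FasterTransformer-style masked multi-head attention handles decoding one token per step: K and V for earlier tokens are read from the pre-allocated KV cache rather than recomputed, so the projection work per step is O(1) in N while the attention itself still reads the N cached entries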
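The linear-memory property of prefilling rests on computing softmax attention block by block with a running maximum and running normalizer, as described in the FlashAttention paper. The sketch below shows that idea for a single query in plain Python; it is not the FlashAttention kernel itself, and the function names and block size are assumptions made for illustration.

```python
import math

def attention_reference(q, keys, values):
    """Standard attention: computes all scores for the query before normalizing."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    dim = len(values[0])
    return [sum(e * v[j] for e, v in zip(exps, values)) / total for j in range(dim)]

def attention_streaming(q, keys, values, block=2):
    """Same result, but keys/values are consumed in blocks with O(1) extra state."""
    scale = 1.0 / math.sqrt(len(q))
    running_max = -math.inf
    denom = 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), block):
        kb = keys[start:start + block]
        vb = values[start:start + block]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in kb]
        new_max = max(running_max, max(scores))
        # Rescale what was accumulated under the old maximum
        # (math.exp(-inf) is 0.0, so the first block starts from zero).
        correction = math.exp(running_max - new_max)
        denom *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, vb):
            w = math.exp(s - new_max)
            denom += w
            acc = [a + w * vj for a, vj in zip(acc, v)]
        running_max = new_max
    return [a / denom for a in acc]

q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5], [2.0, -0.5]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]

ref = attention_reference(q, keys, values)
out = attention_streaming(q, keys, values, block=2)
streaming_matches = all(math.isclose(a, b, rel_tol=1e-9) for a, b in zip(ref, out))
```

Because only the running maximum, normalizer, and output accumulator are kept per query, the full row of scores never has to be stored, which is the source of the O(N) total memory for N queries.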
Related Pages
- Implementation:Mit_han_lab_Llm_awq_Make_quant_attn
- Heuristic:Mit_han_lab_Llm_awq_Kernel_Selection_Thresholds
Knowledge Sources
- Repo|llm-awq|https://github.com/mit-han-lab/llm-awq
- Paper|FlashAttention|https://arxiv.org/abs/2205.14135
Domains
- Inference
- Optimization