Principle:OpenGVLab InternVL Flash Attention Patching

From Leeroopedia


Knowledge Sources
Domains: Flash Attention, Performance Optimization, Monkey Patching, LLaMA
Last Updated: 2026-02-07 14:00 GMT

Overview

Flash Attention Patching replaces the standard attention implementation in LLaMA-family language models with optimized Flash Attention or PyTorch SDPA kernels via monkey-patching. The goal is lower attention memory usage and faster training and inference, without changing the model's outputs.

Description

Standard self-attention computes and materializes the full N x N attention matrix, requiring O(N^2) memory. For long sequences used in multimodal models (where image tokens significantly extend context length), this becomes a critical bottleneck.
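To make the quadratic growth concrete, the following sketch computes the size of the materialized attention scores. The head count of 32 and the 2-byte (fp16) element size are illustrative assumptions, not values taken from InternVL.

```python
def attention_matrix_bytes(seq_len, num_heads=32, batch_size=1, bytes_per_element=2):
    """Bytes needed to hold the full N x N attention scores for one layer.

    num_heads=32 and bytes_per_element=2 (fp16) are illustrative defaults.
    """
    return batch_size * num_heads * seq_len * seq_len * bytes_per_element


if __name__ == "__main__":
    for n in (2048, 4096, 8192):
        gib = attention_matrix_bytes(n) / 2**30
        print(f"seq_len={n}: {gib:.2f} GiB of attention scores per layer")
```

Under these assumptions, a 4,096-token sequence needs 1 GiB for the scores alone, and doubling the length to 8,192 quadruples that to 4 GiB.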

Flash Attention (Dao et al., 2022) uses a tiling strategy to compute attention without materializing the full attention matrix, reducing memory complexity to O(N) while maintaining exact computation. The InternVL project applies Flash Attention to LLaMA models via monkey-patching:

  • Attention mask handling -- The standard HuggingFace causal float attention mask is replaced with a simple boolean key-padding mask, as Flash Attention implements causal masking internally.
  • Variable-length sequences -- For batches with padding, the unpad/pad pattern removes padding before attention and restores it afterward, avoiding wasted computation on padding tokens.
  • Packed QKV -- Q, K, V tensors are packed together for efficient memory access in the flash attention kernel.
  • Version compatibility -- Separate implementations handle flash_attn v1 and v2 API differences, and a fallback path uses PyTorch 2.0+ scaled_dot_product_attention when Flash Attention is unavailable.
  • GPU capability check -- Warns if the GPU lacks the compute capability that Flash Attention requires (Ampere- or Hopper-class, which corresponds to CUDA compute capability 8.x or 9.0).
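The unpad/pad pattern from the list above can be illustrated without any GPU library. This is a minimal pure-Python sketch: the function names unpad and pad are hypothetical, and the real patch operates on tensors through flash_attn's own padding helpers rather than on Python lists. It shows the bookkeeping involved: a boolean key-padding mask selects the real tokens, their flat indices are remembered, and cumulative sequence lengths mark where each sequence starts in the flattened buffer.

```python
def unpad(batch, key_padding_mask):
    """Flatten a padded batch, keeping only positions where the mask is True.

    Returns the kept tokens, their flat indices in the padded layout,
    cumulative sequence lengths, and the longest unpadded length.
    """
    seq_len = len(batch[0])
    tokens, indices, cu_seqlens = [], [], [0]
    for b, (row, mask_row) in enumerate(zip(batch, key_padding_mask)):
        for t, (tok, keep) in enumerate(zip(row, mask_row)):
            if keep:
                tokens.append(tok)
                indices.append(b * seq_len + t)
        cu_seqlens.append(len(tokens))
    max_seqlen = max(end - start for start, end in zip(cu_seqlens, cu_seqlens[1:]))
    return tokens, indices, cu_seqlens, max_seqlen


def pad(tokens, indices, batch_size, seq_len, fill=0):
    """Scatter flattened tokens back into a padded (batch_size x seq_len) layout."""
    flat = [fill] * (batch_size * seq_len)
    for tok, idx in zip(tokens, indices):
        flat[idx] = tok
    return [flat[b * seq_len:(b + 1) * seq_len] for b in range(batch_size)]


if __name__ == "__main__":
    batch = [[1, 2, 3, 0], [4, 5, 0, 0]]  # 0 marks padding
    mask = [[True, True, True, False], [True, True, False, False]]
    tokens, indices, cu_seqlens, max_seqlen = unpad(batch, mask)
    print(tokens, indices, cu_seqlens, max_seqlen)
    print(pad(tokens, indices, batch_size=2, seq_len=4))
```

Attention would run on the flattened tokens between the two calls, so no work is spent on the padded positions.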

Two monkey-patch modules are provided:

  • LLaMA 2 patch -- Supports GQA via KV-packed flash attention with past_key_value caching for incremental decoding.
  • LLaMA v1/v3 patch -- Uses QKV-packed flash attention for training with an SDPA fallback for inference.
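The version handling and SDPA fallback can be sketched as a chain of guarded imports. This is an illustration, not InternVL's actual code: the function name resolve_attention_backend and the returned labels are hypothetical. It assumes that flash_attn v2 exposes flash_attn_varlen_qkvpacked_func and that v1 exposes flash_attn_unpadded_qkvpacked_func, both in flash_attn.flash_attn_interface.

```python
def resolve_attention_backend():
    """Return a label for the fastest attention backend that can be imported.

    The labels are hypothetical. Real patch code would bind the imported
    kernel function instead of returning a string.
    """
    try:
        # Assumed flash_attn v2 name for the variable-length packed-QKV kernel.
        from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func  # noqa: F401
        return "flash_attn_v2"
    except Exception:
        pass
    try:
        # Assumed flash_attn v1 name for the same kernel.
        from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func  # noqa: F401
        return "flash_attn_v1"
    except Exception:
        pass
    try:
        # PyTorch 2.0+ ships scaled_dot_product_attention.
        import torch.nn.functional as F
        if hasattr(F, "scaled_dot_product_attention"):
            return "sdpa"
    except Exception:
        pass
    return "eager"


if __name__ == "__main__":
    print(resolve_attention_backend())
```

Catching broad exceptions is a deliberate choice in this sketch, since a flash_attn install built against a different CUDA version can fail at import time with errors other than ImportError.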

Usage

Call the appropriate replace_llama*_attn_with_flash_attn() function before loading any LLaMA model. This is typically done at the top of training or inference scripts.
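The mechanism behind such a replace function can be shown with a toy class. Everything named here (ToyAttention, fast_forward, replace_toy_attn_with_fast_attn) is hypothetical. The real functions patch the LLaMA attention class inside HuggingFace Transformers, but the idea is the same: reassign a method on the class so that models built afterward use the new implementation.

```python
class ToyAttention:
    """Stand-in for a library attention module (hypothetical)."""

    def forward(self, x):
        return ("eager", x)


def fast_forward(self, x):
    """Stand-in for a Flash Attention forward pass (hypothetical)."""
    return ("fast", x)


def replace_toy_attn_with_fast_attn():
    # Monkey-patch: rebind the method on the class, not on an instance.
    ToyAttention.forward = fast_forward


if __name__ == "__main__":
    replace_toy_attn_with_fast_attn()  # patch first
    layer = ToyAttention()             # then build the "model"
    print(layer.forward(1))
```

Calling the patch at the top of a script, before any model is constructed, keeps the ordering easy to reason about.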

Theoretical Basis

Flash Attention (Dao et al., 2022) exploits the GPU memory hierarchy by computing attention in tiles that fit in on-chip SRAM, avoiding the O(N^2) HBM reads and writes of standard attention. The result is exact rather than approximate: it matches standard attention up to floating-point rounding. Reported gains are roughly 2-4x wall-clock speedup for the attention computation, along with memory use that grows linearly rather than quadratically with sequence length. This enables training with longer context windows, which is essential for multimodal models where image patches consume a large portion of the sequence length.
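The reason tiling can stay exact is the running ("online") softmax: each block of keys updates a running maximum, a running normalizer, and a running weighted sum, so no full row of scores is ever kept. The following pure-Python sketch illustrates that idea and is not the CUDA kernel. It is non-causal, tiles only over keys, and uses the hypothetical names naive_attention and tiled_attention.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def naive_attention(Q, K, V):
    """Reference: builds one full row of the N x N score matrix per query."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        scores = [dot(q, k) * scale for k in K]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        denom = sum(weights)
        out.append([sum(w * v[d] for w, v in zip(weights, V)) / denom
                    for d in range(len(V[0]))])
    return out


def tiled_attention(Q, K, V, block_size=2):
    """Same result, but keys and values are consumed block by block."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        running_max = float("-inf")
        running_sum = 0.0
        acc = [0.0] * len(V[0])
        for start in range(0, len(K), block_size):
            k_blk = K[start:start + block_size]
            v_blk = V[start:start + block_size]
            scores = [dot(q, k) * scale for k in k_blk]
            new_max = max(running_max, max(scores))
            # Rescale earlier partial results to the new maximum.
            correction = math.exp(running_max - new_max)
            weights = [math.exp(s - new_max) for s in scores]
            running_sum = running_sum * correction + sum(weights)
            acc = [a * correction + sum(w * v[d] for w, v in zip(weights, v_blk))
                   for d, a in enumerate(acc)]
            running_max = new_max
        out.append([a / running_sum for a in acc])
    return out


if __name__ == "__main__":
    Q = [[0.1 * (i + j) for j in range(4)] for i in range(3)]
    K = [[0.05 * (i * j + 1) for j in range(4)] for i in range(5)]
    V = [[float(i - j) for j in range(4)] for i in range(5)]
    ref, tiled = naive_attention(Q, K, V), tiled_attention(Q, K, V)
    print(max(abs(a - b) for r, t in zip(ref, tiled) for a, b in zip(r, t)))
```

The two functions agree up to floating-point rounding for any block size, which is the property that lets the tiled kernel replace standard attention without changing model outputs.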

The monkey-patching approach is used instead of modifying the HuggingFace Transformers library directly, allowing InternVL to benefit from Flash Attention while maintaining compatibility with upstream library updates.
