Implementation:OpenGVLab InternVL LLaMA Flash Attention Patch

From Leeroopedia


Knowledge Sources
Domains: Flash Attention, Monkey Patching, LLaMA, Performance
Last Updated: 2026-02-07 14:00 GMT

Overview

Monkey-patches LLaMA (v1/v3) attention with two alternative optimized implementations: a Flash Attention variant using packed QKV and a PyTorch native scaled_dot_product_attention variant.

Description

This module provides two forward function implementations and a dispatcher:

forward (Flash Attention path):

  • Projects Q, K, V and transposes to (bsz, num_heads, seq_len, head_dim).
  • Applies rotary position embeddings using the standard HuggingFace apply_rotary_pos_emb.
  • Stacks Q, K, V into a packed QKV tensor of shape (bsz, seq_len, 3, num_heads, head_dim).
  • For unmasked inputs: rearranges to contiguous format and calls flash_attn_unpadded_qkvpacked_func with cumulative sequence lengths.
  • For masked inputs: uses unpad_input to remove padding, then calls flash attention on the unpadded tensor, and restores padding with pad_input.
  • Compatible with both flash_attn v1 (flash_attn_unpadded_qkvpacked_func) and v2 (flash_attn_varlen_qkvpacked_func).
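
The masked branch above relies on two pieces of bookkeeping: cumulative sequence lengths and an unpad/re-pad round trip. The sketch below illustrates both in plain Python on nested lists. It is not the flash_attn implementation; the helper names cumulative_seqlens, unpad and repad are hypothetical, and the real unpad_input and pad_input operate on tensors.

```python
from itertools import accumulate


def cumulative_seqlens(mask):
    """mask: rows of 1 (real token) / 0 (padding). Returns [0, n0, n0+n1, ...]."""
    lengths = [sum(row) for row in mask]
    return [0] + list(accumulate(lengths))


def unpad(batch, mask):
    """Keep only real tokens in one flat list and remember their flat positions."""
    seq_len = len(mask[0])
    indices = [b * seq_len + t
               for b, row in enumerate(mask)
               for t, keep in enumerate(row) if keep]
    flat = [batch[i // seq_len][i % seq_len] for i in indices]
    return flat, indices


def repad(flat, indices, bsz, seq_len, fill):
    """Scatter the flat tokens back into a (bsz, seq_len) grid, filling padded slots."""
    out = [[fill] * seq_len for _ in range(bsz)]
    for value, i in zip(flat, indices):
        out[i // seq_len][i % seq_len] = value
    return out


# Illustrative batch: two sequences of length 3 and 2, padded to 4.
mask = [[1, 1, 1, 0],
        [1, 1, 0, 0]]
batch = [["a0", "a1", "a2", "<pad>"],
         ["b0", "b1", "<pad>", "<pad>"]]

cu_seqlens = cumulative_seqlens(mask)
flat, indices = unpad(batch, mask)
restored = repad(flat, indices, bsz=2, seq_len=4, fill="<pad>")
```

Sequence i occupies flat[cu_seqlens[i]:cu_seqlens[i + 1]], which is how a variable-length kernel can locate sequence boundaries without any padding tokens in the buffer.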

forward_2 (SDPA path):

  • Uses PyTorch's built-in F.scaled_dot_product_attention during training with is_causal=True and dropout_p=0.0.
  • During evaluation, falls back to manual matmul-based attention with explicit masking and fp32 softmax upcasting.
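
As a rough illustration of the manual fallback, the sketch below computes single-head causal attention in plain Python: scaled dot-product scores, an explicit mask, a softmax, and a weighted sum over the values. It is a simplified stand-in under stated assumptions, not the module's code: the function names are hypothetical, the mask is built from positions instead of being read from attention_mask, and Python floats take the place of the fp32 upcast.

```python
import math


def softmax(row):
    """Numerically stable softmax; masked (-inf) entries come out as 0.0."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def causal_attention(q, k, v):
    """Single-head attention over lists of vectors with an explicit causal mask."""
    scale = 1.0 / math.sqrt(len(q[0]))
    outputs, weights = [], []
    for i, qi in enumerate(q):
        scores = []
        for j, kj in enumerate(k):
            score = scale * sum(a * b for a, b in zip(qi, kj))
            # Position i may only attend to positions j <= i.
            scores.append(score if j <= i else float("-inf"))
        w = softmax(scores)
        weights.append(w)
        outputs.append([sum(w[j] * v[j][d] for j in range(len(v)))
                        for d in range(len(v[0]))])
    return outputs, weights


# Illustrative inputs: 3 positions, head_dim = 2.
q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

attn_output, attn_weights = causal_attention(q, k, v)
```

Because this path materializes the full weight matrix, it can return attn_weights, whereas a fused kernel call does not expose them.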

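_prepare_decoder_attention_mask passes through the raw attention mask without transformation, so the Flash Attention path receives the original key-padding mask instead of an expanded causal mask.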

replace_llama_attn_with_flash_attn selects the implementation:

  • Uses forward_2 (SDPA) if F.scaled_dot_product_attention is available (PyTorch 2.0+).
  • Otherwise uses forward (Flash Attention) and replaces the mask preparation.
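
The selection logic described above can be sketched with stand-in objects, so it runs without PyTorch or transformers installed. Everything here is an assumption made for illustration: patch_attention, make_stand_ins and the dummy classes are hypothetical, and the real function patches classes in transformers.models.llama.modeling_llama and checks torch.nn.functional.

```python
from types import SimpleNamespace


def forward(self):
    return "flash"          # stand-in for the Flash Attention forward


def forward_2(self):
    return "sdpa"           # stand-in for the SDPA forward


def _prepare_decoder_attention_mask(self, attention_mask, *args):
    return attention_mask   # pass the raw mask through unchanged


def patch_attention(functional, attention_cls, model_cls):
    """Pick a forward implementation based on what `functional` provides."""
    if hasattr(functional, "scaled_dot_product_attention"):
        attention_cls.forward = forward_2
    else:
        model_cls._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
        attention_cls.forward = forward


def make_stand_ins():
    """Fresh dummy classes playing the roles of the attention and model classes."""
    class Attention:
        def forward(self):
            return "original"

    class Model:
        def _prepare_decoder_attention_mask(self, attention_mask, *args):
            return "expanded causal mask"

    return Attention, Model


# Case 1: the functional namespace exposes scaled_dot_product_attention.
new_f = SimpleNamespace(scaled_dot_product_attention=lambda *a, **kw: None)
AttnNew, ModelNew = make_stand_ins()
patch_attention(new_f, AttnNew, ModelNew)

# Case 2: it does not, so the Flash Attention path and mask passthrough are used.
old_f = SimpleNamespace()
AttnOld, ModelOld = make_stand_ins()
patch_attention(old_f, AttnOld, ModelOld)
```

Assigning to the class attribute, not to an instance, is what lets a patch of this kind affect every attention layer created afterwards.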

Usage

Call replace_llama_attn_with_flash_attn() before loading LLaMA v1/v3 models. The function selects the attention backend automatically: the SDPA path when F.scaled_dot_product_attention is available (PyTorch 2.0+), and the Flash Attention path otherwise.

Code Reference

Source Location

Signature

def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ...

def forward_2(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ...

def replace_llama_attn_with_flash_attn(): ...

Import

from internvl.patch.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn

I/O Contract

Inputs

Name | Type | Required | Description
hidden_states | torch.Tensor | Yes | Input tensor of shape (batch_size, seq_len, hidden_size)
attention_mask | torch.Tensor | No | Key-padding mask (boolean for flash path, float for SDPA eval path)
position_ids | torch.Tensor | No | Position indices for rotary embeddings

Outputs

Name | Type | Description
attn_output | torch.Tensor | Attention output of shape (batch_size, seq_len, hidden_size)
attn_weights | Optional[torch.Tensor] | Attention weights (only in forward_2 eval mode, None otherwise)
past_key_value | Optional[Tuple] | Updated key/value cache (only forward_2 with use_cache, None for forward)

Usage Examples

Basic Usage

from internvl.patch.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn

# Patch LLaMA attention before loading
replace_llama_attn_with_flash_attn()

# Load LLaMA model - will use optimized attention
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
