Implementation:OpenGVLab InternVL LLaMA Flash Attention Patch
| Knowledge Sources | |
|---|---|
| Domains | Flash Attention, Monkey Patching, LLaMA, Performance |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Monkey-patches LLaMA (v1/v3) attention with one of two optimized implementations: a Flash Attention variant that operates on a packed QKV tensor, or a variant built on PyTorch's native scaled_dot_product_attention.
Description
This module provides two replacement forward functions, a pass-through mask helper, and a dispatcher:
forward (Flash Attention path):
- Projects Q, K, V and transposes to (bsz, num_heads, seq_len, head_dim).
- Applies rotary position embeddings using the standard HuggingFace apply_rotary_pos_emb.
- Stacks Q, K, V along a new dimension and transposes the result into a packed QKV tensor of shape (bsz, seq_len, 3, num_heads, head_dim).
- For unmasked inputs: flattens the batch and sequence dimensions into a single token dimension and calls the packed-QKV flash attention function with cumulative sequence lengths, which mark where each sequence starts and ends in the flattened tensor.
- For masked inputs: uses unpad_input to remove padding, calls flash attention on the unpadded tensor, and restores the padding with pad_input.
- Compatible with both flash_attn v1 (flash_attn_unpadded_qkvpacked_func) and v2, which provides the same operation as flash_attn_varlen_qkvpacked_func.
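The bookkeeping behind the unmasked and masked cases can be sketched without the flash_attn dependency. The snippet below is a minimal pure-Python illustration, not the module's code: the helper names (cu_seqlens_unmasked, cu_seqlens_from_mask, unpad, pad) are hypothetical stand-ins for what the tensor-based unpad_input and pad_input helpers are described as doing, and plain lists stand in for tensors.

```python
from itertools import accumulate


def cu_seqlens_unmasked(bsz, q_len):
    """Cumulative lengths when every sequence has the same length q_len."""
    return list(range(0, (bsz + 1) * q_len, q_len))


def cu_seqlens_from_mask(mask):
    """Cumulative lengths from a key-padding mask (1 = real token, 0 = padding)."""
    seqlens = [sum(row) for row in mask]
    return [0] + list(accumulate(seqlens))


def unpad(tokens, mask):
    """Flatten (batch, seq_len) tokens and keep only the unmasked positions.

    Returns the packed tokens and the flat indices needed to undo the packing.
    """
    seq_len = len(mask[0])
    indices = [
        b * seq_len + t
        for b, row in enumerate(mask)
        for t, keep in enumerate(row)
        if keep
    ]
    flat = [tok for row in tokens for tok in row]
    return [flat[i] for i in indices], indices


def pad(packed, indices, batch, seq_len, fill=0):
    """Scatter packed tokens back into a (batch, seq_len) layout."""
    flat = [fill] * (batch * seq_len)
    for value, i in zip(packed, indices):
        flat[i] = value
    return [flat[b * seq_len:(b + 1) * seq_len] for b in range(batch)]


# Two sequences of true length 3 and 2, padded to length 4.
tokens = [[11, 12, 13, 0], [21, 22, 0, 0]]
mask = [[1, 1, 1, 0], [1, 1, 0, 0]]

packed, indices = unpad(tokens, mask)
cu_seqlens = cu_seqlens_from_mask(mask)
restored = pad(packed, indices, batch=2, seq_len=4)
```

In this toy layout the attention kernel would see only the five real tokens, with cu_seqlens telling it that tokens 0 to 2 belong to the first sequence and tokens 3 to 4 to the second.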
forward_2 (SDPA path):
- Uses PyTorch's built-in F.scaled_dot_product_attention during training with is_causal=True and dropout_p=0.0.
- During evaluation, falls back to manual matmul-based attention with explicit masking and fp32 softmax upcasting.
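As a rough illustration of what causal scaled dot-product attention computes, here is a minimal single-head sketch in pure Python. It is an assumption-laden stand-in, not the module's code: causal_attention is a hypothetical name, lists of floats replace tensors, and Python floats are already double precision, so the fp32 upcasting step of the real eval path is not reproduced.

```python
import math


def causal_attention(q, k, v):
    """Single-head causal scaled dot-product attention on lists of vectors.

    q, k, v: lists of seq_len vectors (lists of floats) of equal head_dim.
    Returns (outputs, weights), both with seq_len rows.
    """
    head_dim = len(q[0])
    scale = 1.0 / math.sqrt(head_dim)
    outputs, weights = [], []
    for i, q_i in enumerate(q):
        # Causal mask: position i attends only to positions 0..i.
        scores = [
            scale * sum(a * b for a, b in zip(q_i, k[j]))
            for j in range(i + 1)
        ]
        # Numerically stable softmax over the visible positions.
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        row = [e / total for e in exps] + [0.0] * (len(k) - i - 1)
        weights.append(row)
        # Weighted sum of value vectors.
        outputs.append([
            sum(row[j] * v[j][d] for j in range(len(v)))
            for d in range(len(v[0]))
        ])
    return outputs, weights


q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
outputs, weights = causal_attention(q, k, v)
```

The first position can only see itself, so its output equals the first value vector; later positions mix in earlier values according to the softmax weights.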
_prepare_decoder_attention_mask returns the raw attention mask unchanged, so the Flash Attention path receives the original key-padding mask.
replace_llama_attn_with_flash_attn selects the implementation:
- Uses forward_2 (SDPA) if F.scaled_dot_product_attention is available (PyTorch 2.0+).
- Otherwise uses forward (Flash Attention) and replaces the mask preparation.
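The selection logic can be sketched with toy classes. Everything below is illustrative: make_toy_classes, ToyAttention, ToyModel, and patch are hypothetical names, a SimpleNamespace stands in for torch.nn.functional, and which real transformers classes the module patches is not stated here, so it is left as an assumption.

```python
from types import SimpleNamespace


def make_toy_classes():
    """Build fresh stand-ins for the attention and model classes being patched."""

    class ToyAttention:
        def forward(self, hidden_states):
            return "original"

    class ToyModel:
        def _prepare_decoder_attention_mask(self, attention_mask, *args):
            return "expanded-mask"

    return ToyAttention, ToyModel


def flash_forward(self, hidden_states):
    return "flash"  # placeholder for the Flash Attention forward


def sdpa_forward(self, hidden_states):
    return "sdpa"  # placeholder for forward_2


def passthrough_mask(self, attention_mask, *args):
    return attention_mask  # hand the raw mask to the attention layer


def patch(functional, attn_cls, model_cls):
    """Prefer SDPA when the functional namespace exposes it, else flash."""
    if hasattr(functional, "scaled_dot_product_attention"):
        attn_cls.forward = sdpa_forward
        return "sdpa"
    attn_cls.forward = flash_forward
    model_cls._prepare_decoder_attention_mask = passthrough_mask
    return "flash"


# Simulate a PyTorch build without scaled_dot_product_attention.
attn_cls, model_cls = make_toy_classes()
chosen = patch(SimpleNamespace(), attn_cls, model_cls)
```

Because the assignment replaces methods on the class, every instance created afterwards (and any existing instance) picks up the new forward, which is why the patch is applied once, globally.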
Usage
Call replace_llama_attn_with_flash_attn() before loading LLaMA v1/v3 models. The function selects the attention backend automatically: SDPA when the installed PyTorch provides F.scaled_dot_product_attention (PyTorch 2.0+), Flash Attention otherwise.
Code Reference
Source Location
- Repository: OpenGVLab_InternVL
- File: internvl_chat/internvl/patch/llama_flash_attn_monkey_patch.py
- Lines: 1-222
Signature
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ...

def forward_2(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ...

def replace_llama_attn_with_flash_attn(): ...
Import
from internvl.patch.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| hidden_states | torch.Tensor | Yes | Input tensor of shape (batch_size, seq_len, hidden_size) |
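| attention_mask | Optional[torch.Tensor] | No | Key-padding mask (boolean for the Flash Attention path, float for the SDPA eval path) |
| position_ids | Optional[torch.Tensor] | No | Position indices for rotary embeddings (annotated as torch.LongTensor in forward_2) |
| past_key_value | Optional[Tuple[torch.Tensor]] | No | Cached key/value states from earlier decoding steps |
| output_attentions | bool | No | Whether to return attention weights (default False) |
| use_cache | bool | No | Whether to return the updated key/value cache (default False) |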
Outputs
| Name | Type | Description |
|---|---|---|
| attn_output | torch.Tensor | Attention output of shape (batch_size, seq_len, hidden_size) |
| attn_weights | Optional[torch.Tensor] | Attention weights; returned only by forward_2 in eval mode, None otherwise |
| past_key_value | Optional[Tuple] | Updated key/value cache; returned only by forward_2 when use_cache is set, always None for forward |
Usage Examples
Basic Usage
from internvl.patch.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
# Patch LLaMA attention before loading
replace_llama_attn_with_flash_attn()
# Load LLaMA model - will use optimized attention
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")