Implementation:OpenGVLab InternVL MPT Attention
| Knowledge Sources | |
|---|---|
| Domains | Attention Mechanism, Language Model, GPU Optimization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Multi-head and multi-query attention implementations for the MPT language model, supporting three attention backends (torch, flash, triton), ALiBi positional biases, QK layer normalization, and QKV gradient clipping.
Description
This module implements the attention layer infrastructure for the MPT (MosaicML Pretrained Transformer) language model backend, providing multiple compute paths for different hardware and performance requirements:
Attention Functions:
- scaled_multihead_dot_product_attention: Standard PyTorch implementation using explicit matmul operations with support for causal masking, key padding masks, attention biases, and optional dropout. Reshapes inputs via einops for head-dimension manipulation.
- flash_attn_fn: Flash Attention v1 backend using flash_attn_interface.flash_attn_unpadded_func with BERT-style padding/unpadding. Requires CUDA tensors in fp16/bf16.
- triton_flash_attn_fn: Triton-based Flash Attention backend using the local flash_attn_triton module. Supports attention bias but not dropout or attention weight output.
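To make the computation behind the torch path concrete, here is a minimal single-head sketch in plain Python. It is not the module's code: the function name `attention_single_head` and the list-of-lists layout are assumptions chosen to keep the example dependency-free. It shows the scaled dot product, an optional additive bias, causal masking aligned to the end of the key sequence, and the softmax-weighted sum over values.

```python
import math

def attention_single_head(q, k, v, bias=None, is_causal=False, softmax_scale=None):
    """Hypothetical reference: q is [s_q][d], k is [s_k][d], v is [s_k][d_v].

    bias, if given, is an additive [s_q][s_k] matrix (for example ALiBi).
    Returns (output, weights) as nested lists.
    """
    d = len(q[0])
    scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(d)
    s_q, s_k = len(q), len(k)
    out, weights = [], []
    for i in range(s_q):
        scores = []
        for j in range(s_k):
            s = scale * sum(a * b for a, b in zip(q[i], k[j]))
            if bias is not None:
                s += bias[i][j]
            # Causal masking: queries are aligned to the end of the keys,
            # so query i may not attend to keys that come after it.
            if is_causal and j > i + (s_k - s_q):
                s = -math.inf
            scores.append(s)
        # Numerically stable softmax over the key axis.
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        w = [e / total for e in exps]
        weights.append(w)
        out.append([sum(w[j] * v[j][c] for j in range(s_k))
                    for c in range(len(v[0]))])
    return out, weights

q = k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
out, weights = attention_single_head(q, k, v, is_causal=True)
```

With `is_causal=True`, the first query can only see the first key, so its output row equals the first value row.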
Attention Modules:
- MultiheadAttention: Standard multi-head self-attention with a fused Wqkv projection producing 3 * d_model outputs, split into query, key, and value. Supports optional QK layer normalization and QKV clipping to stabilize training.
- MultiQueryAttention: Multi-query attention variant in which all query heads share a single key head and a single value head, so the fused Wqkv projection produces d_model + 2 * head_dim outputs. This reduces KV cache memory. Keys and values are expanded to match the query heads during computation.
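The difference between the two fused projections can be illustrated with a small sketch. The helper names `fused_width` and `split_fused_qkv` are hypothetical, and applying the clip to the fused vector before splitting is an assumption about ordering, not a statement about the module's code.

```python
def fused_width(d_model, n_heads, multiquery=False):
    """Output width of the fused Wqkv projection (hypothetical helper)."""
    head_dim = d_model // n_heads
    kv_width = head_dim if multiquery else d_model
    return d_model + 2 * kv_width

def split_fused_qkv(row, d_model, n_heads, multiquery=False, clip_qkv=None):
    """Split one fused QKV vector into (query, key, value) slices.

    Assumption: clipping, when enabled, clamps the fused vector to
    [-clip_qkv, clip_qkv] before the split.
    """
    head_dim = d_model // n_heads
    kv_width = head_dim if multiquery else d_model
    if len(row) != d_model + 2 * kv_width:
        raise ValueError("unexpected fused QKV width")
    if clip_qkv is not None:
        row = [max(-clip_qkv, min(clip_qkv, x)) for x in row]
    q = row[:d_model]
    k = row[d_model:d_model + kv_width]
    v = row[d_model + kv_width:]
    return q, k, v

# d_model=8, n_heads=4 gives head_dim=2, so the multi-query width is 8 + 2 * 2 = 12.
row = [float(i) for i in range(12)]
q, k, v = split_fused_qkv(row, d_model=8, n_heads=4, multiquery=True, clip_qkv=6.0)
```

For d_model=2048 and n_heads=16, the multi-head projection has 6144 outputs while the multi-query projection has 2304, which is where the KV cache saving comes from.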
Bias Infrastructure:
- build_alibi_bias: Constructs ALiBi (Attention with Linear Biases) position encoding tensors with configurable bias maximum.
- gen_slopes: Generates geometric slope values for ALiBi across heads.
- attn_bias_shape / build_attn_bias: Helper functions determining the required attention bias tensor shape based on configuration (alibi, prefix_lm, causal, sequence_id).
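As an illustration of the ALiBi helpers above, the following plain-Python sketch computes geometric slopes and the causal (non-full) bias rows. The names `alibi_slopes_sketch` and `alibi_bias_sketch`, the default maximum of 8, and the interleaving rule for head counts that are not a power of two follow the upstream MPT approach as commonly described. Treat them as assumptions, not as a copy of this file.

```python
import math

def alibi_slopes_sketch(n_heads, alibi_bias_max=8):
    """Geometric slopes 2^(-max * i / n) for i = 1..n, one per head."""
    n_pow2 = 2 ** math.ceil(math.log2(n_heads))
    exponents = [(i + 1) * alibi_bias_max / n_pow2 for i in range(n_pow2)]
    slopes = [1.0 / (2.0 ** e) for e in exponents]
    if n_pow2 != n_heads:
        # Assumed handling for non-power-of-two head counts:
        # interleave odd- and even-indexed slopes, then truncate.
        slopes = (slopes[1::2] + slopes[::2])[:n_heads]
    return slopes

def alibi_bias_sketch(n_heads, seq_len, alibi_bias_max=8):
    """Per-head bias rows over key positions: slope * (key_pos - last_pos)."""
    distances = range(1 - seq_len, 1)  # -(seq_len - 1), ..., -1, 0
    return [[slope * d for d in distances]
            for slope in alibi_slopes_sketch(n_heads, alibi_bias_max)]

slopes = alibi_slopes_sketch(4)
bias = alibi_bias_sketch(4, 3)
```

Each head gets one row of length seq_len, which corresponds to a (1, n_heads, 1, seq_len) tensor once stacked and reshaped. More distant keys receive more negative biases, and heads with larger slopes decay faster.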
The module registers available attention classes in ATTN_CLASS_REGISTRY for dynamic selection.
Usage
Use this module as the attention component within the MPT language model architecture. Select the attention implementation via the attn_impl parameter ("torch", "flash", or "triton") and choose between multi-head and multi-query variants based on memory constraints.
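The backend constraints described on this page can be collected into a small compatibility check. This is a hypothetical helper, not part of the module. It encodes only the stated restrictions (flash needs CUDA tensors in fp16/bf16; triton supports neither dropout nor attention weight output; weights are returned only by the torch backend), plus the assumption that the triton kernels also require a CUDA device.

```python
def compatible_backends(on_cuda, dtype_name, dropout_p=0.0, needs_weights=False):
    """Return the attn_impl values that satisfy the stated constraints.

    dtype_name is a plain string such as "fp16", "bf16", or "fp32".
    """
    backends = ["torch"]  # the torch path has no stated restrictions
    half_precision = dtype_name in ("fp16", "bf16")
    if on_cuda and half_precision and not needs_weights:
        backends.append("flash")
    # Assumption: triton kernels need a CUDA device as well.
    if on_cuda and dropout_p == 0.0 and not needs_weights:
        backends.append("triton")
    return backends

cpu_choice = compatible_backends(on_cuda=False, dtype_name="fp32")
gpu_choice = compatible_backends(on_cuda=True, dtype_name="bf16")
```

The returned list is ordered from most general to most specialized; which entry to prefer is a deployment decision that this sketch does not make.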
Code Reference
Source Location
- Repository: OpenGVLab_InternVL
- File: internvl_chat_llava/llava/model/language_model/mpt/attention.py
- Lines: 1-300
Signature
def scaled_multihead_dot_product_attention(query, key, value, n_heads,
        past_key_value=None, softmax_scale=None, attn_bias=None,
        key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False,
        needs_weights=False, multiquery=False):
    ...

class MultiheadAttention(nn.Module):
    def __init__(self, d_model, n_heads, attn_impl='triton', clip_qkv=None,
                 qk_ln=False, softmax_scale=None, attn_pdrop=0.0,
                 low_precision_layernorm=False, verbose=0, device=None):
        ...

    def forward(self, x, past_key_value=None, attn_bias=None,
                attention_mask=None, is_causal=True, needs_weights=False):
        ...

class MultiQueryAttention(nn.Module):
    def __init__(self, d_model, n_heads, attn_impl='triton', ...):
        ...

    def forward(self, x, past_key_value=None, attn_bias=None,
                attention_mask=None, is_causal=True, needs_weights=False):
        ...

ATTN_CLASS_REGISTRY = {
    'multihead_attention': MultiheadAttention,
    'multiquery_attention': MultiQueryAttention,
}
Import
from llava.model.language_model.mpt.attention import ATTN_CLASS_REGISTRY
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| x | torch.Tensor | Yes | Input tensor of shape (batch, seq_len, d_model) |
| past_key_value | Tuple[torch.Tensor] | No | Cached (key, value) tensors for autoregressive decoding |
| attn_bias | torch.Tensor | No | Additive attention bias (e.g., ALiBi), broadcastable to (batch, n_heads, seq_len, seq_len); causal ALiBi uses shape (1, n_heads, 1, seq_len) |
| attention_mask | torch.Tensor | No | Key padding mask of shape (batch, seq_len) |
| is_causal | bool | No | Whether to apply causal masking (default True) |
| needs_weights | bool | No | Whether to return attention weights (not supported by triton backend) |
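The expected attn_bias shape depends on the configuration flags handled by attn_bias_shape. The sketch below reflects the rule as it is usually described for upstream MPT. The function name `attn_bias_shape_sketch`, the exact branching, and the assumption that the flash path takes no bias tensor are not verified against this file.

```python
def attn_bias_shape_sketch(attn_impl, n_heads, seq_len, alibi,
                           prefix_lm, causal, use_sequence_id):
    """Return the assumed bias shape, or None when no bias tensor is needed."""
    if attn_impl == "flash":
        return None  # assumption: the flash path is not given a bias tensor
    if attn_impl in ("torch", "triton"):
        if alibi:
            # A full (seq, seq) bias is needed once masking is not purely causal.
            if prefix_lm or not causal or use_sequence_id:
                return (1, n_heads, seq_len, seq_len)
            return (1, n_heads, 1, seq_len)
        if prefix_lm or use_sequence_id:
            return (1, 1, seq_len, seq_len)
        return None
    raise ValueError(f"unknown attn_impl: {attn_impl!r}")

causal_alibi = attn_bias_shape_sketch("torch", 16, 128, alibi=True,
                                      prefix_lm=False, causal=True,
                                      use_sequence_id=False)
```

Under these assumptions, a causal ALiBi model needs one bias row per head, while prefix-LM or sequence-id masking needs the full square bias.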
Outputs
| Name | Type | Description |
|---|---|---|
| context | torch.Tensor | Attention output projected through out_proj, shape (batch, seq_len, d_model) |
| attn_weights | torch.Tensor or None | Attention weight matrix (only when needs_weights=True and using torch backend) |
| past_key_value | Tuple[torch.Tensor] | Updated cached key-value tensors |
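During autoregressive decoding, the returned past_key_value is fed back in on the next step so that only the new token's keys and values are computed. A minimal sketch of that update, using Python lists of per-token vectors in place of tensors, is shown below. `update_kv_cache` is a hypothetical name, and concatenation along the sequence axis is the assumed behaviour.

```python
def update_kv_cache(past_key_value, new_key, new_value):
    """Append new keys/values to the cache along the sequence axis.

    Each of key/value is a list of per-token vectors for one sequence.
    Returns the (key, value) pair to use for attention and to cache.
    """
    if past_key_value:  # None or an empty tuple means there is no cache yet
        past_key, past_value = past_key_value
        new_key = past_key + new_key
        new_value = past_value + new_value
    return new_key, new_value

# Step 1: no cache yet. Step 2: the single new token extends the cache.
cache = update_kv_cache(None, [[1.0, 0.0]], [[10.0, 0.0]])
cache = update_kv_cache(cache, [[0.0, 1.0]], [[0.0, 20.0]])
```

After two steps the cache holds two key vectors and two value vectors, so the next query attends over both positions.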
Usage Examples
Basic Usage
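import torch
from llava.model.language_model.mpt.attention import MultiheadAttention

batch, seq_len, d_model, n_heads = 2, 8, 2048, 16

attn = MultiheadAttention(
    d_model=d_model,
    n_heads=n_heads,
    attn_impl='torch',  # or 'flash', 'triton'
    clip_qkv=6.0,
    qk_ln=True,
)

# Placeholder inputs for illustration only
hidden_states = torch.randn(batch, seq_len, d_model)
mask = torch.ones(batch, seq_len, dtype=torch.bool)  # True marks non-padding tokens
# Zero bias as a stand-in; a real ALiBi bias has the same shape
alibi_bias = torch.zeros(1, n_heads, 1, seq_len)

# Forward pass
context, attn_weights, past_kv = attn(
    x=hidden_states,        # (batch, seq_len, 2048)
    attn_bias=alibi_bias,   # (1, 16, 1, seq_len) for ALiBi
    attention_mask=mask,    # (batch, seq_len)
    is_causal=True,
)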