Implementation:OpenGVLab InternVL MPT Attention

From Leeroopedia


Knowledge Sources
Domains Attention Mechanism, Language Model, GPU Optimization
Last Updated 2026-02-07 14:00 GMT

Overview

Multi-head and multi-query attention implementations for the MPT language model, supporting three attention backends (torch, flash, triton), ALiBi positional biases, QK layer normalization, and clipping of the projected QKV values.

Description

This module implements the attention layer infrastructure for the MPT (MosaicML Pretrained Transformer) language model backend, providing multiple compute paths for different hardware and performance requirements:

Attention Functions:

  • scaled_multihead_dot_product_attention: Standard PyTorch implementation using explicit matmul operations with support for causal masking, key padding masks, attention biases, and optional dropout. Reshapes inputs via einops for head-dimension manipulation.
  • flash_attn_fn: Flash Attention v1 backend using flash_attn_interface.flash_attn_unpadded_func with BERT-style padding/unpadding. Requires CUDA tensors in fp16/bf16.
  • triton_flash_attn_fn: Triton-based Flash Attention backend using the local flash_attn_triton module. Supports attention bias but not dropout or attention weight output.
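
The computation shared by all three backends can be illustrated with a dependency-free sketch. This is not the module's code: it handles a single head over plain Python lists, and the function name `attention` and its argument layout are assumptions made for illustration. It shows the scaling, the additive bias, the causal mask, and the softmax-weighted sum.

```python
import math

def attention(q, k, v, bias=None, is_causal=False, softmax_scale=None):
    """Single-head attention sketch (hypothetical helper, not the MPT API).

    q: [s_q][d], k: [s_k][d], v: [s_k][d_v], bias: optional [s_q][s_k].
    """
    d = len(q[0])
    scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(d)
    s_q, s_k = len(q), len(k)
    offset = s_k - s_q  # queries are aligned to the end of the keys
    out = []
    for i in range(s_q):
        scores = []
        for j in range(s_k):
            s = scale * sum(a * b for a, b in zip(q[i], k[j]))
            if bias is not None:
                s += bias[i][j]  # additive bias, e.g. ALiBi
            if is_causal and j > i + offset:
                s = -math.inf  # hide keys that lie after the query position
            scores.append(s)
        # Numerically stable softmax over the key axis.
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Weighted sum of the value vectors.
        out.append([sum(w * v[j][c] for j, w in enumerate(weights))
                    for c in range(len(v[0]))])
    return out
```

With causal masking enabled, the first query can only attend to the first key, so its output equals the first value vector.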

Attention Modules:

  • MultiheadAttention: Standard multi-head self-attention with a fused Wqkv projection producing 3 * d_model outputs, split into query, key, and value. Supports optional QK layer normalization and QKV clipping to stabilize training.
  • MultiQueryAttention: Multi-query attention variant in which all query heads share a single key head and a single value head, so the fused projection produces d_model + 2 * head_dim outputs instead of 3 * d_model. This reduces KV cache memory. Keys and values are expanded to match the query heads during computation.
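
The size difference between the two variants is simple arithmetic. The sketch below uses hypothetical helper names (`qkv_split_sizes`, `kv_cache_elems_per_token`) and assumes head_dim = d_model / n_heads; it computes the widths of the query, key, and value slices of the fused projection and the number of cached elements per token.

```python
def qkv_split_sizes(d_model, n_heads, multiquery=False):
    """Widths of the (query, key, value) slices of the fused projection output."""
    if d_model % n_heads != 0:
        raise ValueError("d_model must be divisible by n_heads")
    head_dim = d_model // n_heads
    # Multi-query keeps one key head and one value head for all query heads.
    kv_width = head_dim if multiquery else d_model
    return (d_model, kv_width, kv_width)

def kv_cache_elems_per_token(d_model, n_heads, multiquery=False):
    """Number of key plus value elements cached per token in one layer."""
    _, k_width, v_width = qkv_split_sizes(d_model, n_heads, multiquery)
    return k_width + v_width

mha = qkv_split_sizes(2048, 16)                   # 3 * d_model in total
mqa = qkv_split_sizes(2048, 16, multiquery=True)  # d_model + 2 * head_dim in total
```

For d_model=2048 and n_heads=16, the multi-head variant caches 4096 elements per token and the multi-query variant 256, a reduction by a factor of n_heads.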

Bias Infrastructure:

  • build_alibi_bias: Constructs ALiBi (Attention with Linear Biases) position encoding tensors with configurable bias maximum.
  • gen_slopes: Generates geometric slope values for ALiBi across heads.
  • attn_bias_shape / build_attn_bias: Helper functions determining the required attention bias tensor shape based on configuration (alibi, prefix_lm, causal, sequence_id).
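
As a rough illustration of the ALiBi idea, the sketch below computes geometric per-head slopes and multiplies them by the (non-positive) distance of each key from the last position. It is a simplified stand-in rather than the module's implementation: the names `alibi_slopes`, `alibi_bias`, and `bias_max` are assumptions, it only accepts power-of-two head counts, and it returns nested lists instead of tensors.

```python
def alibi_slopes(n_heads, bias_max=8):
    """Geometric slopes 2^(-bias_max * h / n_heads) for h = 1..n_heads."""
    if n_heads < 1 or n_heads & (n_heads - 1):
        raise ValueError("this sketch only handles power-of-two head counts")
    return [1.0 / 2 ** (h * bias_max / n_heads) for h in range(1, n_heads + 1)]

def alibi_bias(n_heads, seq_len, bias_max=8):
    """Per-head bias rows of shape [n_heads][seq_len].

    Each row corresponds to a (1, seq_len) slice that broadcasts over
    query positions; the last key gets bias 0, earlier keys get more
    negative values.
    """
    distances = range(1 - seq_len, 1)  # -(seq_len - 1), ..., -1, 0
    return [[slope * dist for dist in distances]
            for slope in alibi_slopes(n_heads, bias_max)]

bias = alibi_bias(n_heads=8, seq_len=4)
```

Heads with larger slopes penalize distant keys more strongly, which is what gives each head a different effective attention range.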

The module registers available attention classes in ATTN_CLASS_REGISTRY for dynamic selection.
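
The registry pattern can be sketched as follows. The stub classes and the `build_attention` helper are hypothetical stand-ins so the example runs without the MPT code; only the two registry keys come from this page.

```python
class MultiheadAttentionStub:
    """Stand-in for MultiheadAttention (records its constructor arguments)."""
    def __init__(self, d_model, n_heads, attn_impl="torch"):
        self.d_model, self.n_heads, self.attn_impl = d_model, n_heads, attn_impl

class MultiQueryAttentionStub(MultiheadAttentionStub):
    """Stand-in for MultiQueryAttention."""

# Same keys as ATTN_CLASS_REGISTRY, mapped to the stubs above.
REGISTRY = {
    "multihead_attention": MultiheadAttentionStub,
    "multiquery_attention": MultiQueryAttentionStub,
}

def build_attention(attn_type, **kwargs):
    """Look up the class by name and instantiate it (hypothetical helper)."""
    try:
        attn_class = REGISTRY[attn_type]
    except KeyError:
        raise ValueError(f"unknown attention type: {attn_type!r}") from None
    return attn_class(**kwargs)

layer = build_attention("multiquery_attention", d_model=64, n_heads=4)
```

Keeping the mapping in one dictionary lets a configuration string choose the attention class without conditional logic at the call site.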

Usage

Use this module as the attention component within the MPT language model architecture. Select the attention implementation via the attn_impl parameter ("torch", "flash", or "triton") and choose between multi-head and multi-query variants based on memory constraints.
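
The backend constraints listed on this page can be collected into a small pre-flight check. The helper below is hypothetical (`attn_impl_problems` is not part of the module) and encodes only the restrictions stated here: flash needs CUDA tensors in fp16/bf16, triton supports neither dropout nor attention weight output, and attention weights are returned only by the torch backend.

```python
def attn_impl_problems(attn_impl, on_cuda, dtype, dropout_p=0.0, needs_weights=False):
    """Return a list of reasons why the requested backend would not work.

    dtype is a short string such as "fp16", "bf16", or "fp32"
    (an assumption of this sketch, not the module's convention).
    """
    if attn_impl not in ("torch", "flash", "triton"):
        raise ValueError(f"unknown attn_impl: {attn_impl!r}")
    problems = []
    if attn_impl == "flash":
        if not on_cuda:
            problems.append("flash requires CUDA tensors")
        if dtype not in ("fp16", "bf16"):
            problems.append("flash requires fp16 or bf16")
    if attn_impl == "triton" and dropout_p > 0.0:
        problems.append("triton does not support dropout")
    if needs_weights and attn_impl != "torch":
        problems.append("only the torch backend returns attention weights")
    return problems
```

An empty list means none of the documented restrictions apply; the real module may enforce further requirements that this page does not describe.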

Code Reference

Source Location

Signature

def scaled_multihead_dot_product_attention(query, key, value, n_heads,
    past_key_value=None, softmax_scale=None, attn_bias=None,
    key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False,
    needs_weights=False, multiquery=False):
    ...

class MultiheadAttention(nn.Module):
    def __init__(self, d_model, n_heads, attn_impl='triton', clip_qkv=None,
                 qk_ln=False, softmax_scale=None, attn_pdrop=0.0,
                 low_precision_layernorm=False, verbose=0, device=None):
        ...
    def forward(self, x, past_key_value=None, attn_bias=None,
                attention_mask=None, is_causal=True, needs_weights=False):
        ...

class MultiQueryAttention(nn.Module):
    def __init__(self, d_model, n_heads, attn_impl='triton', ...):
        ...
    def forward(self, x, past_key_value=None, attn_bias=None,
                attention_mask=None, is_causal=True, needs_weights=False):
        ...

ATTN_CLASS_REGISTRY = {
    'multihead_attention': MultiheadAttention,
    'multiquery_attention': MultiQueryAttention,
}

Import

from llava.model.language_model.mpt.attention import ATTN_CLASS_REGISTRY

I/O Contract

Inputs

Name | Type | Required | Description
x | torch.Tensor | Yes | Input tensor of shape (batch, seq_len, d_model)
past_key_value | Tuple[torch.Tensor] | No | Cached (key, value) tensors for autoregressive decoding
attn_bias | torch.Tensor | No | Additive attention bias (e.g., ALiBi), broadcastable to (batch, n_heads, seq_len, seq_len); for example (1, n_heads, 1, seq_len) for ALiBi
attention_mask | torch.Tensor | No | Key padding mask of shape (batch, seq_len)
is_causal | bool | No | Whether to apply causal masking (default True)
needs_weights | bool | No | Whether to return attention weights (returned only by the torch backend)

Outputs

Name | Type | Description
context | torch.Tensor | Attention output projected through out_proj, shape (batch, seq_len, d_model)
attn_weights | torch.Tensor or None | Attention weight matrix (only when needs_weights=True and using torch backend)
past_key_value | Tuple[torch.Tensor] | Updated cached key-value tensors

Usage Examples

Basic Usage

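import torch
from llava.model.language_model.mpt.attention import MultiheadAttention

attn = MultiheadAttention(
    d_model=2048,
    n_heads=16,
    attn_impl='torch',  # or 'flash', 'triton'
    clip_qkv=6.0,       # clip projected Q/K/V values
    qk_ln=True,         # layer-normalize queries and keys
)

# Placeholder inputs for illustration
batch, seq_len = 2, 128
hidden_states = torch.randn(batch, seq_len, 2048)
alibi_bias = torch.zeros(1, 16, 1, seq_len)          # stand-in; a real bias comes from build_alibi_bias
mask = torch.ones(batch, seq_len, dtype=torch.bool)  # assumed: True marks non-padding tokens

# Forward pass
context, attn_weights, past_kv = attn(
    x=hidden_states,        # (batch, seq_len, 2048)
    attn_bias=alibi_bias,   # (1, 16, 1, seq_len) for ALiBi
    attention_mask=mask,    # (batch, seq_len)
    is_causal=True,
)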
