Implementation:OpenGVLab InternVL MPT Flash Attn Triton
| Knowledge Sources | |
|---|---|
| Domains | GPU Kernel, Flash Attention, Triton Compiler |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Triton-based FlashAttention implementation providing GPU-optimized fused forward and backward attention kernels with support for causal masking, attention bias (ALiBi), arbitrary sequence lengths, and head dimensions up to 128, ported from HazyResearch's flash-attention repository.
Description
This module provides a Triton-compiled FlashAttention implementation that fuses the entire attention computation (QK^T, softmax, and multiplication by V) into a single GPU kernel pass. Because the full seqlen_q x seqlen_k attention matrix is never materialized in GPU memory, the extra memory needed grows linearly with sequence length instead of quadratically.
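To see why the full attention matrix never has to be stored, the following is a minimal pure-Python sketch for a single query row. It is not the Triton kernel: the function names, the tiny block size, and the sample data are illustrative assumptions. It keeps a running maximum and a running sum while visiting key/value blocks, then compares the result with a reference that materializes every score first.

```python
import math


def naive_attention_row(q, keys, values, scale):
    """Reference: compute all scores, then softmax, then the weighted sum."""
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(dim)]


def blocked_attention_row(q, keys, values, scale, block_n=2):
    """Single pass over key/value blocks using a running max and running sum."""
    dim = len(values[0])
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * dim  # unnormalized output accumulator
    for start in range(0, len(keys), block_n):
        k_blk = keys[start:start + block_n]
        v_blk = values[start:start + block_n]
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale everything accumulated under the previous maximum.
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_blk):
            p = math.exp(s - new_max)
            running_sum += p
            acc = [a + p * vd for a, vd in zip(acc, v)]
        running_max = new_max
    out = [a / running_sum for a in acc]
    lse = running_max + math.log(running_sum)  # log-sum-exp of all scores
    return out, lse


# Illustrative data: one query of dimension 4, five keys, values of dimension 2.
q = [0.5, -1.0, 0.25, 2.0]
keys = [[0.1, 0.2, 0.3, 0.4], [1.0, -1.0, 0.5, 0.0], [-0.3, 0.8, 0.9, -1.2],
        [2.0, 0.1, -0.7, 0.6], [0.0, 0.0, 1.5, -0.5]]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0], [-1.0, 3.0], [0.5, 0.5]]
scale = 1.0 / math.sqrt(len(q))

reference = naive_attention_row(q, keys, values, scale)
blocked, lse = blocked_attention_row(q, keys, values, scale, block_n=2)
max_abs_diff = max(abs(a - b) for a, b in zip(blocked, reference))
```

Only the accumulator, the running maximum, and the running sum are held between blocks, so the per-row state is independent of the number of keys. The returned `lse` is the single per-row statistic that a backward pass needs to recompute the probabilities.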
Forward Kernel (_fwd_kernel): A Triton JIT-compiled kernel that implements the online softmax algorithm, tracking a running maximum and log-sum-exp per query row so that attention is computed in a single pass over key-value blocks. Triton heuristics specialize the kernel on whether the sequence lengths are exact multiples of the block sizes (EVEN_M, EVEN_N) and whether the head dimension equals its padded block size (EVEN_HEADDIM), so boundary checks are compiled in only when they are needed. It supports three attention bias types: none, vector (a per-key bias of shape (1, seqlen_k), as used for ALiBi), and matrix (a full (seqlen_q, seqlen_k) query-key bias). The kernel operates on blocks of size BLOCK_M x BLOCK_N (typically 128x128) with configurable warps.
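The specialization flags can be reproduced on the host side. The sketch below assumes the flags depend only on divisibility by the block sizes and on the head dimension padded to a power of two with a minimum of 16. The helper names are hypothetical, and the real kernel evaluates these conditions through Triton heuristics rather than a Python dictionary.

```python
def next_power_of_2(n):
    """Smallest power of two that is >= n (for n >= 1)."""
    return 1 << (n - 1).bit_length()


def specialization_flags(seqlen_q, seqlen_k, headdim, block_m=128, block_n=128):
    """Hypothetical helper: which boundary checks could the kernel skip?"""
    block_headdim = max(next_power_of_2(headdim), 16)  # assumed padding rule
    return {
        "BLOCK_HEADDIM": block_headdim,
        "EVEN_M": seqlen_q % block_m == 0,        # no partial query block
        "EVEN_N": seqlen_k % block_n == 0,        # no partial key block
        "EVEN_HEADDIM": headdim == block_headdim,  # no padded head columns
    }


aligned = specialization_flags(seqlen_q=512, seqlen_k=512, headdim=64)
ragged = specialization_flags(seqlen_q=500, seqlen_k=512, headdim=80)
```

With these assumptions, the 512/512/64 case needs no boundary masks at all, while the 500/512/80 case needs masks on the query rows and on the head dimension (80 is padded to 128) but not on the key columns.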
Backward Kernels: The backward pass runs in two steps. _bwd_preprocess_do_o_dot computes the delta values (for each query row, the dot product of the output and its gradient). The autotuned _bwd_kernel then calls _bwd_kernel_one_col_block, which computes the dQ, dK, and dV contributions for one block of key/value columns. In the optional sequence-parallel mode, column blocks run as separate programs and accumulate dQ with atomic adds, which improves parallelism when the batch size is small. Triton autotuning selects between the sequential and sequence-parallel configurations.
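The role of the delta values can be checked on a single query row. The sketch below is a pure-Python illustration with hypothetical names, not the Triton backward kernel, and it omits the softmax scale and any bias or masking. It relies on the identity that the dot product of the output with its gradient equals the probability-weighted sum of the gradients with respect to the probabilities, which is why the score gradient can be formed without storing the attention weights.

```python
import math


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def row_backward(scores, values, grad_out):
    """Gradient of one attention row with respect to its pre-softmax scores."""
    probs = softmax(scores)  # the real kernel recomputes these from the saved LSE
    dim = len(values[0])
    out = [sum(p * v[d] for p, v in zip(probs, values)) for d in range(dim)]
    delta = dot(out, grad_out)                     # preprocess step: o . do
    d_probs = [dot(grad_out, v) for v in values]   # dP_j = do . V_j
    d_scores = [p * (dp - delta) for p, dp in zip(probs, d_probs)]
    return probs, delta, d_probs, d_scores


# Illustrative data: four keys, value dimension 2.
scores = [0.3, -1.2, 0.8, 0.1]
values = [[1.0, 2.0], [0.0, -1.0], [3.0, 0.5], [-2.0, 1.0]]
grad_out = [0.7, -0.4]
probs, delta, d_probs, d_scores = row_backward(scores, values, grad_out)
```

Because `delta` is a single number per query row, it can be computed once up front and reused by every key/value column block, which is what makes the column-block decomposition of the backward pass possible.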
Autograd Functions: Three torch.autograd.Function wrappers provide clean PyTorch integration:
- FlashAttnFunc: Separate Q, K, V tensors of shape (batch, seqlen, nheads, headdim)
- FlashAttnQKVPackedFunc: Packed QKV tensor of shape (batch, seqlen, 3, nheads, headdim)
- FlashAttnKVPackedFunc: Separate Q with packed KV for cross-attention patterns
Key constraints: Requires fp16/bf16 inputs on CUDA, head dimensions up to 128, no dropout support, and does not support ragged/nested tensors. The implementation stores only the log-sum-exp (LSE) values instead of full attention weights for backward computation.
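These constraints can be checked before a call is made. The helper below is a hypothetical pre-flight check that works on shape tuples and a dtype name, so it runs without PyTorch. The function name, the string dtype encoding, and the message texts are assumptions for illustration and are not part of the module.

```python
MAX_HEADDIM = 128
SUPPORTED_DTYPES = {"float16", "bfloat16"}


def input_problems(q_shape, k_shape, v_shape, dtype, on_cuda):
    """Return a list of human-readable constraint violations (empty if none)."""
    if len(q_shape) != 4 or len(k_shape) != 4:
        return ["q and k must be 4-D: (batch, seqlen, nheads, headdim)"]
    problems = []
    batch, _, nheads, headdim = q_shape
    seqlen_k = k_shape[1]
    expected_kv = (batch, seqlen_k, nheads, headdim)
    if tuple(k_shape) != expected_kv or tuple(v_shape) != expected_kv:
        problems.append(f"k and v must both have shape {expected_kv}")
    if headdim > MAX_HEADDIM:
        problems.append(f"headdim {headdim} exceeds {MAX_HEADDIM}")
    if dtype not in SUPPORTED_DTYPES:
        problems.append(f"dtype {dtype} is not fp16 or bf16")
    if not on_cuda:
        problems.append("tensors must live on a CUDA device")
    return problems


ok = input_problems((2, 512, 16, 64), (2, 512, 16, 64), (2, 512, 16, 64),
                    "float16", True)
bad = input_problems((2, 512, 16, 256), (2, 512, 16, 256), (2, 512, 16, 256),
                     "float32", True)
```

Different query and key lengths are allowed by this check, which matches the cross-attention use of the packed-KV variant.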
Usage
Use this module as the Triton attention backend for the MPT model when attn_impl='triton' is configured. It is invoked through the triton_flash_attn_fn function in the MPT attention module and requires the triton_pre_mlir package.
Code Reference
Source Location
- Repository: OpenGVLab_InternVL
- File: internvl_chat_llava/llava/model/language_model/mpt/flash_attn_triton.py
- Lines: 1-484
Signature
```python
def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
    """Forward pass: returns (output, lse, softmax_scale)"""
    ...

def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None,
                         causal=False, softmax_scale=None):
    """Backward pass: computes dq, dk, dv in-place"""
    ...

class FlashAttnFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
        """q, k, v: (batch_size, seqlen, nheads, headdim)"""
        ...

class FlashAttnQKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
        """qkv: (batch, seqlen, 3, nheads, headdim)"""
        ...

class FlashAttnKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
        """q: (batch, seqlen_q, nheads, headdim), kv: (batch, seqlen_k, 2, nheads, headdim)"""
        ...

# Convenience aliases
flash_attn_func = FlashAttnFunc.apply
flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply
flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply
```
Import
```python
from llava.model.language_model.mpt.flash_attn_triton import flash_attn_func
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| q | torch.Tensor (fp16/bf16) | Yes | Query tensor of shape (batch, seqlen_q, nheads, headdim) |
| k | torch.Tensor (fp16/bf16) | Yes | Key tensor of shape (batch, seqlen_k, nheads, headdim) |
| v | torch.Tensor (fp16/bf16) | Yes | Value tensor of shape (batch, seqlen_k, nheads, headdim) |
| bias | torch.Tensor | No | Attention bias broadcastable to (batch, nheads, seqlen_q, seqlen_k); defaults to None |
| causal | bool | No | Whether to apply causal (lower-triangular) masking; defaults to False |
| softmax_scale | float | No | Scaling factor for QK^T; defaults to 1/sqrt(headdim) |
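How a bias shape maps onto the kernel's three bias types can be sketched from shapes alone. The helper below is hypothetical and assumes the classification looks only at the last two dimensions of a 4-D bias. The name and the error messages are illustrative.

```python
def classify_bias(bias_shape, seqlen_q, seqlen_k):
    """Map a bias shape to 'none', 'vector', or 'matrix' (assumed rule)."""
    if bias_shape is None:
        return "none"
    if len(bias_shape) != 4:
        raise ValueError("bias must be 4-D")
    tail = tuple(bias_shape[2:])
    if tail == (1, seqlen_k):
        return "vector"   # one bias value per key, shared by all queries
    if tail == (seqlen_q, seqlen_k):
        return "matrix"   # one bias value per (query, key) pair
    raise ValueError("last two bias dims must be (1, seqlen_k) or (seqlen_q, seqlen_k)")


no_bias = classify_bias(None, 512, 512)
alibi_like = classify_bias((1, 16, 1, 512), 512, 512)
full_bias = classify_bias((1, 16, 512, 512), 512, 512)
```

The leading batch and head dimensions may be 1 and are broadcast, so an ALiBi-style bias of shape (1, nheads, 1, seqlen_k) falls into the vector case.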
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Attention output of shape (batch, seqlen_q, nheads, headdim) |
Usage Examples
Basic Usage
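```python
import torch
from llava.model.language_model.mpt.flash_attn_triton import flash_attn_func
from llava.model.language_model.mpt.attention import build_alibi_bias  # assumed location of this helper

# q, k, v: (batch=2, seqlen=512, nheads=16, headdim=64), fp16 on a CUDA device
q, k, v = (torch.randn(2, 512, 16, 64, device='cuda', dtype=torch.float16) for _ in range(3))
# Arguments are positional (q, k, v, bias, causal): Function.apply does not reliably accept keywords
output = flash_attn_func(q, k, v, None, True)
# output shape: (2, 512, 16, 64)

# With ALiBi bias (the bias must also be on the CUDA device)
alibi_bias = build_alibi_bias(n_heads=16, seq_len=512, device=q.device)  # (1, 16, 1, 512)
output = flash_attn_func(q, k, v, alibi_bias, True)
```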
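For intuition about what a (1, 16, 1, 512) ALiBi bias contains, the following pure-Python sketch builds the per-head, per-key values. It assumes the standard ALiBi geometric slope schedule for a power-of-two head count and a maximum exponent of 8. The function names are hypothetical, and the repository's build_alibi_bias may differ in details such as dtype handling or head counts that are not powers of two.

```python
def alibi_slopes(n_heads, alibi_bias_max=8):
    """Geometric slope schedule; this sketch assumes n_heads is a power of two."""
    return [2.0 ** (-alibi_bias_max * (h + 1) / n_heads) for h in range(n_heads)]


def alibi_key_bias(n_heads, seq_len, alibi_bias_max=8):
    """Rows are heads, columns are key positions (the singleton dims are dropped)."""
    # Distance of each key from the last position: ..., -2, -1, 0
    distances = [j - (seq_len - 1) for j in range(seq_len)]
    return [[slope * d for d in distances]
            for slope in alibi_slopes(n_heads, alibi_bias_max)]


slopes = alibi_slopes(16)
bias = alibi_key_bias(n_heads=16, seq_len=512)
```

A bias indexed only by key position is enough under causal masking because softmax is invariant to adding a constant to a whole row: for a query at position i, slope * (j - i) and slope * (j - (seq_len - 1)) differ by a value that does not depend on the key index j.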