
Implementation:OpenGVLab InternVL MPT Flash Attn Triton

From Leeroopedia


Knowledge Sources
Domains: GPU Kernel, Flash Attention, Triton Compiler
Last Updated: 2026-02-07 14:00 GMT

Overview

Triton-based FlashAttention implementation, ported from HazyResearch's flash-attention repository, that provides fused GPU forward and backward attention kernels with support for causal masking, additive attention bias (e.g., ALiBi), arbitrary sequence lengths, and head dimensions up to 128.

Description

This module provides a high-performance, Triton-compiled FlashAttention implementation that fuses the attention computation (QK^T, softmax, and multiplication by V) into a single kernel. The full seqlen_q x seqlen_k attention matrix is never written to GPU memory, so the memory needed for intermediate results grows linearly rather than quadratically with sequence length.
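To put the memory savings in numbers, here is a back-of-the-envelope calculation. It is a sketch under stated assumptions: the sizes are illustrative, the materialized scores are taken to be fp16, the saved statistics are one fp32 log-sum-exp per query row, and any padding of buffers is ignored. The helper names are hypothetical.

```python
def attention_matrix_bytes(batch, nheads, seqlen_q, seqlen_k, bytes_per_elem=2):
    """Bytes needed to materialize the full (batch, nheads, seqlen_q, seqlen_k) score matrix."""
    return batch * nheads * seqlen_q * seqlen_k * bytes_per_elem


def lse_bytes(batch, nheads, seqlen_q, bytes_per_elem=4):
    """Bytes for one fp32 log-sum-exp value per query row."""
    return batch * nheads * seqlen_q * bytes_per_elem


full = attention_matrix_bytes(2, 16, 4096, 4096)  # fp16 scores: 1 GiB
lse = lse_bytes(2, 16, 4096)                      # fp32 LSE: 512 KiB
print(full, lse, full // lse)  # prints: 1073741824 524288 2048
```

The ratio grows with seqlen_k, which is why the savings matter most for long sequences.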

Forward Kernel (_fwd_kernel): A Triton JIT-compiled kernel that computes attention in a single pass over key/value blocks using the online softmax algorithm, which tracks a running maximum and log-sum-exp per query row. Triton heuristics specialize the kernel on whether the sequence lengths are multiples of the block sizes and whether the head dimension equals its padded block width (EVEN_M, EVEN_N, EVEN_HEADDIM), so boundary masks are emitted only when needed. It supports three attention bias types: none, vector (one value per key, as used for causal ALiBi), and matrix (a full query-key bias). The kernel operates on tiles of BLOCK_M x BLOCK_N (128 x 128 in the forward wrapper), with the number of warps chosen from the head dimension.
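The online-softmax recurrence the forward kernel runs over key/value blocks can be sketched in plain Python for a single query row. This is a simplified scalar illustration, not the Triton kernel: there is no masking or bias, the block size is tiny, and the function names are hypothetical. The point it shows is that rescaling the running sum and accumulator whenever the running maximum changes gives exactly the same result as a softmax over the full row, and that the row's log-sum-exp falls out for free.

```python
import math


def naive_attention_row(scores, values):
    """Reference: softmax over the full score row, then a weighted sum of value rows."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]


def online_attention_row(scores, values, block_n=2):
    """Single pass over key/value blocks, keeping a running max m, a running sum l
    of exp(score - m), and an unnormalized output accumulator acc."""
    dim = len(values[0])
    m = -math.inf        # running max of the scores seen so far
    l = 0.0              # running sum of exp(score - m)
    acc = [0.0] * dim    # running sum of exp(score - m) * value
    for start in range(0, len(scores), block_n):
        s_blk = scores[start:start + block_n]
        v_blk = values[start:start + block_n]
        m_new = max(m, max(s_blk))
        alpha = math.exp(m - m_new)  # rescales old statistics to the new max (0.0 on the first block)
        p = [math.exp(s - m_new) for s in s_blk]
        l = l * alpha + sum(p)
        acc = [a * alpha + sum(pi * v[d] for pi, v in zip(p, v_blk))
               for d, a in enumerate(acc)]
        m = m_new
    lse = m + math.log(l)  # log-sum-exp of the whole score row
    return [a / l for a in acc], lse


scores = [0.5, -1.0, 2.0, 0.0, 1.5]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 1.0], [1.0, 1.0], [-1.0, 3.0]]
out, lse = online_attention_row(scores, values, block_n=2)  # last block is ragged (1 key)
```

A block size that does not divide the number of keys is deliberate here: it mirrors the uneven-boundary case that the EVEN_N heuristic distinguishes.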

Backward Kernels: Gradient computation is split across two kernels. _bwd_preprocess_do_o_dot computes the per-row delta values D_i = sum_d O_id * dO_id (the dot product of the output and the output gradient). The main backward kernel (_bwd_kernel) then calls _bwd_kernel_one_col_block for each key/value column block to compute dQ, dK, and dV. In the optional sequence-parallel mode, each column block runs as a separate program and dQ is accumulated with atomic adds, which gives better parallelism when the batch size is small. Triton autotuning selects between the sequential and sequence-parallel configurations.
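Why the delta values are enough can be checked numerically. For softmax attention with P = softmax(S), O = P V and dP = dO V^T, the softmax gradient is dS_ij = P_ij * (dP_ij - D_i), where D_i = sum_j P_ij dP_ij. Because sum_j P_ij dP_ij = sum_d dO_id O_id, D_i can be computed cheaply from O and dO, which is what the preprocessing step does. The following is a small pure-Python sketch with arbitrary toy values; the helper names are hypothetical and this is not the kernel code.

```python
import math


def softmax(row):
    m = max(row)
    e = [math.exp(x - m) for x in row]
    z = sum(e)
    return [x / z for x in e]


def matmul(a, b):
    """Multiply list-of-rows matrices a (n x k) and b (k x m)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(col) for col in zip(*a)]


# Toy example: 2 query rows, 3 keys, head dim 2
S = [[0.1, 0.7, -0.3], [1.2, -0.5, 0.4]]     # scaled scores Q K^T
V = [[1.0, 2.0], [0.5, -1.0], [-2.0, 0.3]]
dO = [[0.2, -0.1], [0.7, 0.4]]               # upstream gradient of the output

P = [softmax(r) for r in S]
O = matmul(P, V)
dP = matmul(dO, transpose(V))

# Delta as a row-wise dot product of O and dO (cheap: no seqlen_k dimension)
delta = [sum(o * g for o, g in zip(O[i], dO[i])) for i in range(len(O))]
# The same quantity written in terms of P and dP
delta_from_p = [sum(p * g for p, g in zip(P[i], dP[i])) for i in range(len(P))]

# Softmax backward using the precomputed delta
dS = [[P[i][j] * (dP[i][j] - delta[i]) for j in range(len(P[0]))] for i in range(len(P))]
```

Each row of dS sums to zero, as expected for the gradient through a softmax.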

Autograd Functions: Three torch.autograd.Function wrappers expose the kernels to PyTorch autograd:

  • FlashAttnFunc: Separate Q, K, V tensors of shape (batch, seqlen, nheads, headdim)
  • FlashAttnQKVPackedFunc: Packed QKV tensor of shape (batch, seqlen, 3, nheads, headdim)
  • FlashAttnKVPackedFunc: Separate Q with packed KV of shape (batch, seqlen_k, 2, nheads, headdim), for cross-attention patterns
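The packed layouts relate to the separate-tensor layout by slicing the stacking axis; our understanding is that the packed wrappers pass qkv[:, :, 0], qkv[:, :, 1], and qkv[:, :, 2] (or the kv equivalents) to the same forward routine. The sketch below illustrates this with NumPy instead of torch so it runs without a GPU; the shapes are arbitrary and the indexing semantics match torch for this case.

```python
import numpy as np

batch, seqlen, nheads, headdim = 2, 8, 4, 16
rng = np.random.default_rng(0)

# Packed self-attention input: (batch, seqlen, 3, nheads, headdim)
qkv = rng.standard_normal((batch, seqlen, 3, nheads, headdim)).astype(np.float16)
# Slicing the stacking axis yields views in the (batch, seqlen, nheads, headdim) layout
q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

# Packed cross-attention input: separate q plus kv of shape (batch, seqlen_k, 2, nheads, headdim)
seqlen_k = 12
kv = rng.standard_normal((batch, seqlen_k, 2, nheads, headdim)).astype(np.float16)
k2, v2 = kv[:, :, 0], kv[:, :, 1]
```

The slices are views, so no data is copied, and the head dimension stays contiguous (its stride is one element), which is the property GPU kernels typically rely on when loading rows of Q, K, and V.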

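Key constraints: q, k, and v must be fp16 or bf16 CUDA tensors of the same dtype, with a head dimension of at most 128; dropout is not supported; and ragged/nested (variable-length) batches are not supported. For the backward pass, the implementation saves only the per-row log-sum-exp (LSE) values rather than the full attention weights, and recomputes the attention probabilities from Q, K, and the saved LSE.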
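Saving only the LSE works because every softmax probability can be rebuilt from its score and the row's LSE: P_ij = exp(s_ij - lse_i). A minimal pure-Python sketch with toy scores (hypothetical names; the real kernel recomputes s_ij = softmax_scale * q_i . k_j tile by tile):

```python
import math


def logsumexp(row):
    """Numerically stable log(sum(exp(x))) over one row."""
    m = max(row)
    return m + math.log(sum(math.exp(x - m) for x in row))


scores = [[0.3, -1.2, 2.1, 0.0], [1.0, 1.0, -0.5, 0.25]]  # scaled Q K^T, toy values

# Forward pass: keep a single scalar per query row
lse = [logsumexp(r) for r in scores]

# Backward pass: rebuild each probability from the recomputed score and the saved LSE
P = [[math.exp(s - l) for s in r] for r, l in zip(scores, lse)]
```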

Usage

This module serves as the Triton attention backend for the MPT model when attn_impl='triton' is set in the attention config. It is invoked through the triton_flash_attn_fn function in the MPT attention module and requires the triton_pre_mlir package.
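A minimal sketch of selecting this backend, assuming MPT's attention settings are passed as an attn_config dict with the keys shown (key names other than attn_impl are assumptions here). The check_triton_attn_config helper is hypothetical: it only restates the constraints listed on this page as an early, readable failure instead of an assertion deep inside the kernel wrapper.

```python
# Assumed MPT attn_config fragment; only attn_impl="triton" is stated on this page.
attn_config = {
    "attn_type": "multihead_attention",
    "attn_impl": "triton",  # route attention through triton_flash_attn_fn
    "attn_pdrop": 0.0,      # the Triton kernels have no dropout support
    "alibi": True,
    "alibi_bias_max": 8,
}


def check_triton_attn_config(cfg, dtype_name, headdim):
    """Hypothetical pre-flight check mirroring the documented Triton-backend constraints."""
    if cfg.get("attn_impl") != "triton":
        return []  # constraints only apply to the Triton backend
    problems = []
    if cfg.get("attn_pdrop", 0.0) > 0.0:
        problems.append("attn_pdrop must be 0.0: dropout is not supported")
    if dtype_name not in ("float16", "bfloat16"):
        problems.append(f"dtype {dtype_name} is not supported; use float16 or bfloat16")
    if headdim > 128:
        problems.append(f"head dimension {headdim} exceeds the maximum of 128")
    return problems
```

Returning a list of problems rather than raising on the first one lets a caller report every misconfiguration at once.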

Code Reference

Source Location

Signature

import torch

def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
    """Forward pass: returns (output, lse, softmax_scale); lse holds one log-sum-exp per query row"""
    ...

def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None,
                          causal=False, softmax_scale=None):
    """Backward pass: computes dq, dk, dv in-place"""
    ...

class FlashAttnFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
        """q: (batch_size, seqlen_q, nheads, headdim); k, v: (batch_size, seqlen_k, nheads, headdim)"""
        ...

class FlashAttnQKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
        """qkv: (batch, seqlen, 3, nheads, headdim)"""
        ...

class FlashAttnKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
        """q: (batch, seqlen_q, nheads, headdim), kv: (batch, seqlen_k, 2, nheads, headdim)"""
        ...

# Convenience aliases
flash_attn_func = FlashAttnFunc.apply
flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply
flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply

Import

from llava.model.language_model.mpt.flash_attn_triton import flash_attn_func

I/O Contract

Inputs

Name | Type | Required | Description
q | torch.Tensor (fp16/bf16, CUDA) | Yes | Query tensor of shape (batch, seqlen_q, nheads, headdim), with headdim <= 128
k | torch.Tensor (fp16/bf16, CUDA) | Yes | Key tensor of shape (batch, seqlen_k, nheads, headdim)
v | torch.Tensor (fp16/bf16, CUDA) | Yes | Value tensor of shape (batch, seqlen_k, nheads, headdim)
bias | torch.Tensor | No | Additive attention bias, 4-D and broadcastable to (batch, nheads, seqlen_q, seqlen_k); its last two dimensions must be (1, seqlen_k) or (seqlen_q, seqlen_k)
causal | bool | No | Whether to apply causal (lower-triangular) masking
softmax_scale | float | No | Scaling factor applied to QK^T; defaults to 1/sqrt(headdim)
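How a bias shape maps to the kernel's bias types, and how the default scale is derived, can be sketched as two small helpers. These are hypothetical functions that restate the rules on this page (bias type decided by the last two dimensions; softmax_scale defaulting to 1/sqrt(headdim)); they are assumed to mirror the forward wrapper's checks, not copied from it.

```python
import math


def classify_bias(bias_shape, seqlen_q, seqlen_k):
    """Map a 4-D bias shape to 'none', 'vector', or 'matrix' using its last two dims."""
    if bias_shape is None:
        return "none"
    if len(bias_shape) != 4:
        raise ValueError("bias must be 4-D, broadcastable to (batch, nheads, seqlen_q, seqlen_k)")
    if tuple(bias_shape[2:]) == (1, seqlen_k):
        return "vector"  # one value per key, e.g. causal ALiBi of shape (1, nheads, 1, seqlen_k)
    if tuple(bias_shape[2:]) == (seqlen_q, seqlen_k):
        return "matrix"  # a full query-key bias
    raise ValueError("last two bias dims must be (1, seqlen_k) or (seqlen_q, seqlen_k)")


def default_softmax_scale(headdim, softmax_scale=None):
    """Use the given scale, else 1/sqrt(headdim)."""
    return softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(headdim)
```

For example, a causal ALiBi bias of shape (1, 16, 1, 512) with 512 keys is a vector bias, and headdim 64 gives a default scale of 0.125.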

Outputs

Name | Type | Description
output | torch.Tensor | Attention output of shape (batch, seqlen_q, nheads, headdim)

Usage Examples

Basic Usage

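import torch
from llava.model.language_model.mpt.flash_attn_triton import flash_attn_func
# Assumption: build_alibi_bias lives in the MPT attention module next to triton_flash_attn_fn
from llava.model.language_model.mpt.attention import build_alibi_bias

# q, k, v: (batch=2, seqlen=512, nheads=16, headdim=64), fp16 on CUDA
q, k, v = (torch.randn(2, 512, 16, 64, device='cuda', dtype=torch.float16) for _ in range(3))

# flash_attn_func is FlashAttnFunc.apply, so pass arguments positionally:
# (q, k, v, bias, causal, softmax_scale)
output = flash_attn_func(q, k, v, None, True)
# output shape: (2, 512, 16, 64)

# With ALiBi bias of shape (1, 16, 1, 512); move it to q's device and dtype
alibi_bias = build_alibi_bias(16, 512).to(device=q.device, dtype=q.dtype)
output = flash_attn_func(q, k, v, alibi_bias, True)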
