Implementation:Predibase Lorax Flash Attn Triton
| Knowledge Sources | |
|---|---|
| Domains | GPU_Kernels, Attention |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Triton implementation of the Flash Attention v2 algorithm providing fused, memory-efficient multi-head attention computation with support for causal masking, variable-length sequences, dropout, and grouped-query attention (GQA/MQA).
Description
This module implements Flash Attention v2 (Tri Dao, 2023) entirely in Triton, based on contributions from the OpenAI kernel team and AMD ML Frameworks Triton team. The algorithm computes softmax(Q * K^T / sqrt(d)) * V in a tiled, online fashion that avoids materializing the full attention matrix in GPU memory, reducing the memory needed for attention scores from O(N^2) to O(N) in the sequence length N.
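The tiled, online computation described above can be illustrated for a single query row in plain Python. This is a minimal sketch and not the Triton kernel: the function names, the `block_n` parameter, and the use of Python lists in place of GPU tiles are assumptions made for illustration. It shows how a running maximum, a running denominator, and an unnormalized accumulator are enough to reproduce the full softmax without storing every score.

```python
import math


def naive_attention_row(q, keys, values, sm_scale):
    """Reference: materialize every score for one query row, then softmax."""
    scores = [sm_scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(dim)]


def tiled_attention_row(q, keys, values, sm_scale, block_n=2):
    """Online softmax: visit K/V in blocks of block_n rows (hypothetical name),
    keeping only O(1) state per query row besides the output accumulator."""
    dim = len(values[0])
    m_i = float("-inf")  # running maximum of the scores seen so far
    l_i = 0.0            # running sum of exp(score - m_i)
    acc = [0.0] * dim    # running sum of exp(score - m_i) * v
    for start in range(0, len(keys), block_n):
        k_blk = keys[start:start + block_n]
        v_blk = values[start:start + block_n]
        scores = [sm_scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
        m_new = max(m_i, max(scores))
        # Factor that rescales the old state to the new maximum.
        # On the first block this is exp(-inf) == 0.0.
        alpha = math.exp(m_i - m_new)
        p = [math.exp(s - m_new) for s in scores]
        l_i = l_i * alpha + sum(p)
        acc = [a * alpha + sum(pj * v[d] for pj, v in zip(p, v_blk))
               for d, a in enumerate(acc)]
        m_i = m_new
    # Normalize once at the end instead of after every block.
    return [a / l_i for a in acc]
```

Both functions return the same row up to floating-point rounding; the tiled version only ever holds one block of scores at a time.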
The implementation consists of several key components:
Core attention kernel (attn_fwd): The main forward kernel uses Triton's @triton.autotune decorator with 9 configurations varying BLOCK_M (16-256), BLOCK_N (16-128), waves_per_eu, and PRE_LOAD_V settings. The kernel operates in two phases: first processing full blocks (no masking needed), then masked blocks (causal or padding masking). It uses the numerically stable online softmax and evaluates 2^x instead of e^x: Q is pre-scaled by sm_scale * log2(e), which leaves the result unchanged because e^y = 2^(y * log2(e)).
Inner loop (_attn_fwd_inner): Processes blocks of K and V, computing QK dot products, applying causal and padding masks, computing softmax probabilities with online rescaling, optionally applying dropout, and accumulating the weighted V output.
Helper kernels: The module includes Triton JIT helper functions for dropout mask generation (dropout_offsets, dropout_rng, dropout_mask) using Philox RNG, a flexible tensor loader (load_fn) with configurable boundary checks, and utility functions (cdiv_fn, max_fn).
Variable-length support (VARLEN mode): When enabled, the kernel reads cumulative sequence length arrays (cu_seqlens_q, cu_seqlens_k) to handle packed sequences of different lengths without padding. It supports different Q and K sequence lengths and handles MQA/GQA through head group size computation (GROUP_SIZE = HQ // HK).
Python wrapper (_attention class): A torch.autograd.Function that validates inputs, computes strides and grid dimensions, pads head dimensions to the next power of 2 (minimum 16), and launches the Triton kernel. The public API is exposed as triton_attention.
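Three of the scalar computations named in the component list can be sketched in plain Python: the base-2 softmax rewrite, the head-dimension padding, and the GQA head grouping. The function names are hypothetical, and the query-to-KV head mapping by integer division is an assumption based on the conventional GQA layout rather than something quoted from the source file.

```python
import math

LOG2_E = math.log2(math.e)


def softmax_base_e(scores, sm_scale):
    """Conventional softmax of sm_scale * scores using e**x."""
    scaled = [s * sm_scale for s in scores]
    m = max(scaled)
    weights = [math.exp(x - m) for x in scaled]
    total = sum(weights)
    return [w / total for w in weights]


def softmax_base_2(scores, sm_scale):
    """Same softmax, with log2(e) folded into the scale so 2**x can be used."""
    qk_scale = sm_scale * LOG2_E
    scaled = [s * qk_scale for s in scores]
    m = max(scaled)
    weights = [2.0 ** (x - m) for x in scaled]
    total = sum(weights)
    return [w / total for w in weights]


def padded_head_dim(head_size):
    """Next power of two >= head_size, never below 16 (head_size >= 1)."""
    return max(16, 1 << (head_size - 1).bit_length())


def kv_head_for(q_head, hq, hk):
    """Assumed GQA mapping: GROUP_SIZE consecutive query heads share a KV head."""
    group_size = hq // hk
    return q_head // group_size
```

With `hq == hk` the mapping is the identity (standard multi-head attention), and with `hk == 1` every query head reads the single KV head (MQA).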
Usage
This kernel is used for attention computation when a Triton-based Flash Attention implementation is preferred over CUDA-based alternatives (e.g., on AMD GPUs via ROCm). It supports the full inference pipeline with variable-length sequences, causal masking for autoregressive generation, and grouped-query attention patterns used in modern LLM architectures.
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/utils/flash_attn_triton.py
- Lines: 1-791
Signature
@triton.autotune(configs=[...], key=["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"])
@triton.jit
def attn_fwd(
    Q, K, V, bias, sm_scale, L, Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    stride_bz, stride_bh, stride_bm, stride_bn,
    cu_seqlens_q, cu_seqlens_k,
    dropout_p, philox_seed, philox_offset_base, encoded_softmax,
    HQ: tl.constexpr, HK: tl.constexpr,
    ACTUAL_BLOCK_DMODEL: tl.constexpr,
    MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr,
    VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
    PRE_LOAD_V: tl.constexpr, BIAS_TYPE: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr,
):
    ...  # kernel body omitted

class _attention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, o, cu_seqlens_q, cu_seqlens_k,
                max_seqlens_q, max_seqlens_k, causal=False,
                sm_scale=1.0, bias=None):
        ...  # wrapper body omitted

triton_attention = _attention.apply
Import
from lorax_server.utils.flash_attn_triton import triton_attention
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| q | torch.Tensor | Yes | Query tensor of shape (total_q, nheads_q, head_size) for variable-length mode. Must be contiguous. |
| k | torch.Tensor | Yes | Key tensor of shape (total_k, nheads_k, head_size) for variable-length mode. Must match v shape. |
| v | torch.Tensor | Yes | Value tensor of shape (total_k, nheads_k, head_size) for variable-length mode. Must match k shape. |
| o | torch.Tensor or None | Yes | Output tensor with the same shape as q. The argument must be passed, but if it is None the output is allocated as empty_like(q). |
| cu_seqlens_q | torch.Tensor | Yes | Cumulative sequence lengths for queries, shape (batch+1,). E.g., [0, 4, 10] for sequences of length 4 and 6. |
| cu_seqlens_k | torch.Tensor | Yes | Cumulative sequence lengths for keys, shape (batch+1,). Must have same length as cu_seqlens_q. |
| max_seqlens_q | int | Yes | Maximum query sequence length in the batch. |
| max_seqlens_k | int | Yes | Maximum key sequence length in the batch. |
| causal | bool | No | Enable causal (autoregressive) attention masking (default False). |
| sm_scale | float | No | Softmax scaling factor, typically 1/sqrt(head_size) (default 1.0). |
| bias | torch.Tensor | No | Optional attention bias tensor of shape (batch, nheads, seqlen_q, seqlen_k). |
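The cumulative-length bookkeeping in the table can be sketched as follows. This is an illustrative helper only: `build_cu_seqlens` and `sequence_slices` are hypothetical names, and plain Python lists stand in for the torch.Tensor arguments the kernel expects.

```python
from itertools import accumulate


def build_cu_seqlens(seq_lens):
    """Prefix sums with a leading 0, giving batch + 1 entries."""
    return [0] + list(accumulate(seq_lens))


def sequence_slices(cu_seqlens):
    """Row range of each sequence inside the packed (total, nheads, head_size) tensor."""
    return [slice(cu_seqlens[i], cu_seqlens[i + 1])
            for i in range(len(cu_seqlens) - 1)]


seq_lens = [4, 6]                      # two sequences packed back to back
cu_seqlens = build_cu_seqlens(seq_lens)
max_seqlen = max(seq_lens)             # value to pass as max_seqlens_q / max_seqlens_k
total_tokens = cu_seqlens[-1]          # first dimension of the packed tensor
```

For the lengths 4 and 6 used in the table, `cu_seqlens` is `[0, 4, 10]`, so the second sequence occupies rows 4 through 9 of the packed tensor.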
Outputs
| Name | Type | Description |
|---|---|---|
| o | torch.Tensor | The attention output tensor, with the same shape as q, containing softmax(Q*K^T * sm_scale) * V. |
| encoded_softmax | None | Reserved for encoded softmax output (currently None in this configuration). |
Usage Examples
from lorax_server.utils.flash_attn_triton import triton_attention

# Compute variable-length causal attention
output, _ = triton_attention(
    q,             # (total_q, nheads_q, head_size)
    k,             # (total_k, nheads_k, head_size)
    v,             # (total_k, nheads_k, head_size)
    o,             # (total_q, nheads_q, head_size), pre-allocated output
    cu_seqlens_q,  # (batch+1,)
    cu_seqlens_k,  # (batch+1,)
    max_seqlens_q,
    max_seqlens_k,
    causal=True,
    sm_scale=1.0 / (head_size ** 0.5),
)