Implementation:Predibase Lorax Flash Attn Triton

From Leeroopedia


Knowledge Sources
Domains GPU_Kernels, Attention
Last Updated 2026-02-08 00:00 GMT

Overview

Triton implementation of the Flash Attention v2 algorithm providing fused, memory-efficient multi-head attention computation with support for causal masking, variable-length sequences, dropout, and grouped-query attention (GQA/MQA).

Description

This module implements Flash Attention v2 (Tri Dao, 2023) entirely in Triton, based on contributions from the OpenAI kernel team and the AMD ML Frameworks Triton team. The algorithm computes softmax(Q * K^T / sqrt(d)) * V in a tiled, online fashion that never materializes the full N x N attention matrix in GPU memory, reducing the memory needed for attention scores from O(N^2) to O(N) in the sequence length N.
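The tiled, online computation can be illustrated with a minimal pure-Python sketch for a single query row. This is not the Triton kernel: the function names and the block_n parameter are hypothetical, and the point is only to show that visiting K and V block by block, while keeping a running maximum, a running denominator, and a running output, gives the same result as a softmax over all scores at once.

```python
import math

def naive_attention(q, keys, values, scale):
    """Reference: materialize every score for one query row, then softmax."""
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(dim)]

def tiled_attention(q, keys, values, scale, block_n=2):
    """Online softmax: only one block of scores exists at any time."""
    dim = len(values[0])
    m_i = float("-inf")  # running maximum of the scores seen so far
    l_i = 0.0            # running softmax denominator
    acc = [0.0] * dim    # running, unnormalized weighted sum of V rows
    for start in range(0, len(keys), block_n):
        k_blk = keys[start:start + block_n]
        v_blk = values[start:start + block_n]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
        m_new = max(m_i, max(scores))
        alpha = math.exp(m_i - m_new)  # rescales everything accumulated so far
        p = [math.exp(s - m_new) for s in scores]
        l_i = l_i * alpha + sum(p)
        acc = [a * alpha + sum(pj * v[d] for pj, v in zip(p, v_blk))
               for d, a in enumerate(acc)]
        m_i = m_new
    return [a / l_i for a in acc]
```

Per query row, the working state is one block of scores plus three running quantities, which is where the linear memory footprint comes from.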

The implementation consists of several key components:

Core attention kernel (attn_fwd): The main forward kernel uses Triton's @triton.autotune decorator with 9 configurations varying BLOCK_M (16-256), BLOCK_N (16-128), waves_per_eu, and PRE_LOAD_V settings. The kernel operates in two phases: first processing full blocks (no masking needed), then masked blocks (causal or padding masking). It uses the numerically stable online softmax trick, computing 2^x instead of e^x by pre-scaling Q with sm_scale * log2(e).
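The base-2 exponent trick mentioned above rests on the identity e^x = 2^(x * log2(e)). The following sketch, with hypothetical function names, folds log2(e) into the scale once and then uses powers of two throughout. It scales the scores rather than Q, which is equivalent because the scores are linear in Q.

```python
import math

LOG2_E = math.log2(math.e)  # approximately 1.4427

def softmax_base_e(scores, sm_scale):
    """Conventional stable softmax using e**x."""
    scaled = [s * sm_scale for s in scores]
    m = max(scaled)
    w = [math.exp(x - m) for x in scaled]
    total = sum(w)
    return [x / total for x in w]

def softmax_base_2(scores, sm_scale):
    """Same softmax, but with log2(e) folded into the scale and 2**x used."""
    qk_scale = sm_scale * LOG2_E
    scaled = [s * qk_scale for s in scores]
    m = max(scaled)
    w = [2.0 ** (x - m) for x in scaled]
    total = sum(w)
    return [x / total for x in w]
```

Both functions return the same probabilities up to floating-point rounding; the base-2 form is attractive on GPUs because a base-2 exponential is typically a cheap hardware operation.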

Inner loop (_attn_fwd_inner): Processes blocks of K and V, computing QK dot products, applying causal and padding masks, computing softmax probabilities with online rescaling, optionally applying dropout, and accumulating the weighted V output.
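As a rough illustration of the masking step in the inner loop, the sketch below masks one tile of scores before softmax. It assumes the simplest causal convention, in which a query at position i may attend to keys at positions j <= i; how the kernel aligns the causal boundary when the Q and K lengths differ is not covered here. The function name and arguments are hypothetical.

```python
NEG_INF = float("-inf")

def mask_tile(scores, start_m, start_n, seqlen_k, causal):
    """Return a copy of a BLOCK_M x BLOCK_N score tile with disallowed
    positions set to -inf, so that they receive zero softmax weight."""
    out = []
    for i, row in enumerate(scores):
        q_pos = start_m + i  # absolute query position of this row
        new_row = []
        for j, s in enumerate(row):
            k_pos = start_n + j                # absolute key position
            padded = k_pos >= seqlen_k         # column lies past the real sequence
            future = causal and k_pos > q_pos  # key comes after the query
            new_row.append(NEG_INF if (padded or future) else s)
        out.append(new_row)
    return out
```

Tiles that lie entirely inside the sequence and entirely below the causal diagonal need neither check, which is why full blocks and masked blocks can be processed in separate phases.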

Helper kernels: The module includes Triton JIT helper functions for dropout mask generation (dropout_offsets, dropout_rng, dropout_mask) using Philox RNG, a flexible tensor loader (load_fn) with configurable boundary checks, and utility functions (cdiv_fn, max_fn).
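The helpers can be approximated in plain Python as follows. This is a sketch under stated assumptions: Python's random module stands in for the Philox generator, the keep rule (random value greater than dropout_p) and the 1 / (1 - p) rescaling follow standard inverted dropout rather than anything confirmed by this page, and all names are hypothetical.

```python
import random

def cdiv(x, y):
    """Ceiling division: number of size-y blocks needed to cover x items."""
    return (x + y - 1) // y

def dropout_keep_mask(n_rows, n_cols, dropout_p, seed):
    """True where a probability is kept. A seeded generator makes the mask
    reproducible (Python's random is a stand-in for Philox)."""
    rng = random.Random(seed)
    return [[rng.random() > dropout_p for _ in range(n_cols)]
            for _ in range(n_rows)]

def apply_dropout(probs, keep, dropout_p):
    """Zero the dropped entries and rescale the kept ones (inverted dropout)."""
    scale = 1.0 / (1.0 - dropout_p)
    return [[p * scale if k else 0.0 for p, k in zip(p_row, k_row)]
            for p_row, k_row in zip(probs, keep)]
```

Deriving the mask from a seed means the same mask can be regenerated later without storing it.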

Variable-length support (VARLEN mode): When enabled, the kernel reads cumulative sequence length arrays (cu_seqlens_q, cu_seqlens_k) to handle packed sequences of different lengths without padding. It supports different Q and K sequence lengths and handles MQA/GQA through head group size computation (GROUP_SIZE = HQ // HK).
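The packed-sequence bookkeeping can be sketched in a few lines. The helper names are hypothetical, and the mapping from a query head to its shared key/value head (integer division by the group size) is the usual GQA convention, assumed here rather than quoted from the kernel.

```python
from itertools import accumulate

def build_cu_seqlens(lengths):
    """Cumulative offsets for packed sequences, e.g. [4, 6] -> [0, 4, 10]."""
    return [0] + list(accumulate(lengths))

def sequence_bounds(cu_seqlens, batch_idx):
    """Start offset and length of one sequence inside the packed tensor."""
    start = cu_seqlens[batch_idx]
    end = cu_seqlens[batch_idx + 1]
    return start, end - start

def kv_head_for(q_head, hq, hk):
    """Key/value head used by a query head, assuming hq is a multiple of hk."""
    group_size = hq // hk
    return q_head // group_size
```

With hq == hk this reduces to ordinary multi-head attention, and with hk == 1 every query head shares a single key/value head (MQA).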

Python wrapper (_attention class): A torch.autograd.Function that validates inputs, computes strides and grid dimensions, pads head dimensions to the next power of 2 (minimum 16), and launches the Triton kernel. The public API is exposed as triton_attention.
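The head-dimension padding rule can be written as a one-line bit trick. The function below is a hypothetical stand-in for whatever the wrapper does internally; it shows only the rule stated above (next power of 2, minimum 16).

```python
def padded_head_size(head_size, minimum=16):
    """Round head_size up to the next power of two, but never below minimum."""
    next_pow2 = 1 << (head_size - 1).bit_length()
    return max(next_pow2, minimum)
```

For example, a head size of 80 is padded to 128, while 64 and 128 are already powers of two and stay unchanged.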

Usage

This kernel is used for attention computation when a Triton-based Flash Attention implementation is preferred over CUDA-based alternatives (e.g., on AMD GPUs via ROCm). It supports the full inference pipeline with variable-length sequences, causal masking for autoregressive generation, and grouped-query attention patterns used in modern LLM architectures.

Code Reference

Source Location

  • Repository: Predibase_Lorax
  • File: server/lorax_server/utils/flash_attn_triton.py
  • Lines: 1-791

Signature

@triton.autotune(configs=[...], key=["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"])
@triton.jit
def attn_fwd(
    Q, K, V, bias, sm_scale, L, Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    stride_bz, stride_bh, stride_bm, stride_bn,
    cu_seqlens_q, cu_seqlens_k,
    dropout_p, philox_seed, philox_offset_base, encoded_softmax,
    HQ: tl.constexpr, HK: tl.constexpr,
    ACTUAL_BLOCK_DMODEL: tl.constexpr,
    MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr,
    VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
    PRE_LOAD_V: tl.constexpr, BIAS_TYPE: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr,
):

class _attention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, o, cu_seqlens_q, cu_seqlens_k,
                max_seqlens_q, max_seqlens_k, causal=False,
                sm_scale=1.0, bias=None):

triton_attention = _attention.apply

Import

from lorax_server.utils.flash_attn_triton import triton_attention

I/O Contract

Inputs

| Name | Type | Required | Description |
| --- | --- | --- | --- |
| q | torch.Tensor | Yes | Query tensor of shape (total_q, nheads_q, head_size) for variable-length mode. Must be contiguous. |
| k | torch.Tensor | Yes | Key tensor of shape (total_k, nheads_k, head_size) for variable-length mode. Must match the shape of v. |
| v | torch.Tensor | Yes | Value tensor of shape (total_k, nheads_k, head_size) for variable-length mode. Must match the shape of k. |
| o | torch.Tensor or None | Yes (may be None) | Output tensor, same shape as q. The argument must be supplied, but if it is None the output is allocated as empty_like(q). |
| cu_seqlens_q | torch.Tensor | Yes | Cumulative sequence lengths for queries, shape (batch+1,). E.g., [0, 4, 10] for sequences of length 4 and 6. |
| cu_seqlens_k | torch.Tensor | Yes | Cumulative sequence lengths for keys, shape (batch+1,). Must have the same length as cu_seqlens_q. |
| max_seqlens_q | int | Yes | Maximum query sequence length in the batch. |
| max_seqlens_k | int | Yes | Maximum key sequence length in the batch. |
| causal | bool | No | Enable causal (autoregressive) attention masking (default False). |
| sm_scale | float | No | Softmax scaling factor, typically 1/sqrt(head_size) (default 1.0). |
| bias | torch.Tensor | No | Optional attention bias tensor of shape (batch, nheads, seqlen_q, seqlen_k). |

Outputs

| Name | Type | Description |
| --- | --- | --- |
| o | torch.Tensor | The attention output tensor of same shape as q, containing softmax(Q*K^T * sm_scale) * V. |
| encoded_softmax | None | Reserved for encoded softmax output (currently None in this configuration). |

Usage Examples

from lorax_server.utils.flash_attn_triton import triton_attention

# Compute variable-length causal attention.
# triton_attention is _attention.apply, so causal and sm_scale are passed
# positionally here rather than as keyword arguments.
output, _ = triton_attention(
    q,                         # (total_q, nheads_q, head_size)
    k,                         # (total_k, nheads_k, head_size)
    v,                         # (total_k, nheads_k, head_size)
    o,                         # (total_q, nheads_q, head_size), pre-allocated output (or None)
    cu_seqlens_q,              # (batch+1,)
    cu_seqlens_k,              # (batch+1,)
    max_seqlens_q,
    max_seqlens_k,
    True,                      # causal
    1.0 / (head_size ** 0.5),  # sm_scale
)
