Heuristic:Sail sg LongSpec Triton Block Size Tuning

From Leeroopedia
Revision as of 10:41, 16 February 2026 by Admin (talk | contribs) (Auto-imported from heuristics/Sail_sg_LongSpec_Triton_Block_Size_Tuning.md)
Knowledge Sources
Domains GPU_Kernels, Optimization, Attention_Mechanisms
Last Updated 2026-02-14 06:00 GMT

Overview

GPU architecture-specific block size and warp configuration for the Triton tree attention kernel, with optimized settings for A100 (SM 8.0) and RTX 3090 (SM 8.6) and a conservative fallback for unknown GPUs.

Description

The custom Triton tree attention kernel in LongSpec uses a hardware-aware configuration function (`get_fwd_config`) that selects hand-tuned `BLOCK_M`, `BLOCK_N`, `num_stages`, and `num_warps` values based on the GPU's CUDA compute capability and the problem dimensions. `BLOCK_M` and `BLOCK_N` set the tile sizes for the attention computation, `num_stages` sets the software pipelining depth, and `num_warps` sets the number of warps per program instance. A configuration that is too conservative leaves performance unused, while one that is too aggressive can exceed the available shared memory or registers and cause the kernel to fail at compile or launch time.

Additionally, the kernel uses two micro-optimizations:

  • Dot-I trick: For head dimensions < 128, queries are multiplied by an identity matrix via `tl.dot(q, I)` to force placement in registers instead of shared memory.
  • exp2 instead of exp: Uses `tl.math.exp2` with log2(e) prescaling because CSE and LICM compiler optimizations do not work as expected with `exp` in loops.

Usage

Apply this heuristic when deploying LongSpec inference on different GPU architectures. The tuned branches match the compute capability exactly, so any GPU that does not report (8, 0) (A100) or (8, 6) (for example RTX 3090) falls back to 32x32 blocks, which may be 2-4x slower. This includes newer cards such as the RTX 4090, which reports compute capability (8, 9). The `get_fwd_config` function can be overridden at runtime with custom tuned parameters.
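A minimal sketch of such a runtime override follows. It assumes the kernel module is importable under the name `triton_tree_attn` and that the kernel wrapper looks up `get_fwd_config` through the module at call time; neither is confirmed by the excerpts on this page. A `SimpleNamespace` stands in for the real module so the sketch runs without a GPU, and the tuned values in `my_get_fwd_config` are hypothetical placeholders, not measured results.

```python
import types

# Stand-in for the real kernel module (assumed name: triton_tree_attn).
# With the real package you would `import` the module instead.
triton_tree_attn = types.SimpleNamespace()


def default_get_fwd_config(B, H, M, N, D, causal):
    # Mirrors only the fallback branch shown in the source excerpt.
    return (32, 32, 1, 4)


triton_tree_attn.get_fwd_config = default_get_fwd_config


def my_get_fwd_config(B, H, M, N, D, causal):
    """Hypothetical tuned config; replace the numbers with values you benchmarked."""
    if D <= 64:
        return (64, 64, 2, 4)
    return (64, 32, 2, 4)


# Override by rebinding the module attribute before the first kernel call.
triton_tree_attn.get_fwd_config = my_get_fwd_config

config = triton_tree_attn.get_fwd_config(1, 8, 2048, 2048, 128, True)
print(config)
```

Rebinding the attribute only takes effect if callers resolve the function through the module; code that imported the name directly with `from ... import get_fwd_config` would keep the original.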

The Insight (Rule of Thumb)

  • Action: Select Triton block sizes based on GPU compute capability and problem dimensions.
  • Value:
    • A100 (SM 8.0), causal, D > 64, M > 1024: `BLOCK_M=128, BLOCK_N=128, num_stages=3, num_warps=8`
    • A100 (SM 8.0), causal, D > 64, M <= 1024: `BLOCK_M=128, BLOCK_N=32, num_stages=2, num_warps=4`
    • A100 (SM 8.0), causal, D <= 64: `BLOCK_M=128, BLOCK_N=64, num_stages=4, num_warps=4`
    • RTX 3090 (SM 8.6), causal, D <= 64: `BLOCK_M=64, BLOCK_N=64, num_stages=3, num_warps=4`
    • RTX 3090 (SM 8.6), causal, D > 64: `BLOCK_M=128, BLOCK_N=32, num_stages=2, num_warps=4`
    • Fallback (other GPUs): `BLOCK_M=32, BLOCK_N=32, num_stages=1, num_warps=4`
  • Trade-off: Larger blocks improve throughput on high-end GPUs but may cause register pressure or shared memory issues on smaller GPUs. The fallback is safe but significantly slower.
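The values listed above can be restated as a small lookup. This is a sketch covering the causal cases from the list only; `lookup_causal_config` is a hypothetical helper name, and the compute capability is passed in as an argument (rather than queried with `torch.cuda.get_device_capability()`) so that it runs without a GPU.

```python
def lookup_causal_config(capability, D, M):
    """Return (BLOCK_M, BLOCK_N, num_stages, num_warps) for causal attention.

    `capability` is a (major, minor) tuple, D is the head dimension and
    M is the problem dimension compared against 1024 in the list above.
    """
    if capability == (8, 0):  # A100
        if D <= 64:
            return (128, 64, 4, 4)
        if M <= 1024:
            return (128, 32, 2, 4)
        return (128, 128, 3, 8)
    if capability == (8, 6):  # RTX 3090
        if D <= 64:
            return (64, 64, 3, 4)
        return (128, 32, 2, 4)
    # Any other capability gets the conservative fallback.
    return (32, 32, 1, 4)


print(lookup_causal_config((8, 0), D=128, M=4096))
print(lookup_causal_config((8, 9), D=128, M=4096))
```

Because the comparison is an exact tuple match, a capability such as (8, 9) or (9, 0) lands in the fallback even though the hardware is newer.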

Reasoning

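Different GPU architectures provide different amounts of shared memory per SM, different register budgets, and different warp scheduler arrangements. A100 GPUs (SM 8.0) have 192 KB of combined L1 cache and shared memory per SM, of which up to 164 KB can be configured as shared memory, so they can hold larger tiles with deeper pipelining (more stages). RTX 3090 GPUs (SM 8.6) have 128 KB of combined L1 cache and shared memory per SM, with up to 100 KB usable as shared memory, which favors smaller tiles or fewer stages. The conservative 32x32 fallback keeps resource usage low enough to be a safe default but underutilizes the hardware.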

The `NOTE` comment in the source code explicitly states this function "can be overwritten at runtime to use your custom config", indicating the authors expect users on non-A100/3090 hardware to tune these parameters.

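The exp2 optimization relies on the identity exp(x) = 2^(x * log2(e)). NVIDIA GPUs evaluate exponentials on their special function units through a base-2 instruction (`ex2.approx` in PTX, `MUFU.EX2` in SASS), so a natural exponential is typically lowered to a multiplication by log2(e) followed by a base-2 exponential. Folding `log2(e) = 1.4426950408889634` into `sm_scale` once, outside the loop, lets the kernel call `exp2` directly. The source comment notes that CSE and LICM do not work as expected with `exp` in the loop, which is why the scaling is folded in by hand.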
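The equivalence can be checked numerically on the CPU. The sketch below is plain Python rather than Triton; `softmax_exp` and `softmax_exp2` are illustrative names that do not appear in the source, and `2.0 ** x` stands in for `tl.math.exp2`.

```python
import math

LOG2E = 1.4426950408889634  # log2(e), the same constant as in the kernel excerpt


def softmax_exp(scores, sm_scale):
    """Reference softmax of sm_scale * scores using the natural exponential."""
    m = max(scores)
    weights = [math.exp((s - m) * sm_scale) for s in scores]
    total = sum(weights)
    return [w / total for w in weights]


def softmax_exp2(scores, sm_scale):
    """Same softmax, with log2(e) folded into the scale and a base-2 exponential."""
    qk_scale = sm_scale * LOG2E  # computed once, outside the per-element loop
    m = max(scores)
    weights = [2.0 ** (s * qk_scale - m * qk_scale) for s in scores]
    total = sum(weights)
    return [w / total for w in weights]


scores = [0.5, -1.25, 3.0, 2.0]
ref = softmax_exp(scores, sm_scale=0.125)
alt = softmax_exp2(scores, sm_scale=0.125)
max_diff = max(abs(a - b) for a, b in zip(ref, alt))
print(max_diff < 1e-12)
```

The two versions agree up to floating-point rounding; the only structural difference is where the multiplication by log2(e) happens.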

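The Dot-I trick multiplies queries by an identity matrix, which leaves their values unchanged but, according to the source comment, places `q` in registers and saves shared memory. It is applied only for `BLOCK_DMODEL < 128`. The source does not state why; a plausible reason is that at a head dimension of 128 the additional register pressure would outweigh the benefit.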
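The sketch below illustrates, in plain Python rather than Triton, why the trick is numerically a no-op: the identity matrix is built by comparing row and column offsets, in the same spirit as the `tl.where(offs_k[:, None] == offs_k, ...)` expression in the kernel, and the product returns the original queries. The helper names are hypothetical, and the sketch says nothing about register placement, which is a compiler effect that only shows up on the GPU.

```python
def identity_from_offsets(block_dmodel):
    """Build an identity matrix by comparing row and column offsets."""
    offs_k = range(block_dmodel)
    return [[1.0 if row == col else 0.0 for col in offs_k] for row in offs_k]


def matmul(a, b):
    """Naive dense matrix product of two lists of rows."""
    inner, cols = len(b), len(b[0])
    return [
        [sum(a_row[k] * b[k][j] for k in range(inner)) for j in range(cols)]
        for a_row in a
    ]


q = [[0.5, -1.0, 2.0, 3.25],
     [1.5, 0.0, -0.75, 4.0]]
I = identity_from_offsets(4)
q_again = matmul(q, I)
print(q_again == q)
```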

Code Evidence

Hardware-aware configuration from `triton_tree_attn.py:80-112`:

# NOTE: this function can be overwritten at runtime to use your custom config
def get_fwd_config(B, H, M, N, D, causal):
    if torch.cuda.get_device_capability() == (8, 0):
        if not causal:
            if D <= 64:
                BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 64, 3, 4
            else:
                if M <= 1024:
                    BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 32, 3, 4
                else:
                    BLOCK_M, BLOCK_N, num_stages, num_warps = 128, 128, 3, 8
        else:
            ...
    elif torch.cuda.get_device_capability() == (8, 6):
        ...
    else:
        BLOCK_M, BLOCK_N, num_stages, num_warps = 32, 32, 1, 4
    return (BLOCK_M, BLOCK_N, num_stages, num_warps)

exp2 optimization from `triton_tree_attn.py:138-142,218-219`:

# scale sm_scale by log_2(e) and use
# 2^x instead of exp in the loop because CSE and LICM
# don't work as expected with `exp` in the loop
log2e: tl.constexpr = 1.4426950408889634
qk_scale = sm_scale * log2e
...
alpha = tl.math.exp2((m_i - m_i_new) * qk_scale)
p = tl.math.exp2(s * qk_scale - m_i_new[:, None] * qk_scale)

Dot-I trick from `triton_tree_attn.py:175-180`:

#Dot I trick: to place q in registers, it saves shared memory
if BLOCK_DMODEL < 128:
    I = tl.where(offs_k[:, None] == offs_k,
                 tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 1.0, dtype=input_dtype),
                 tl.full((BLOCK_DMODEL, BLOCK_DMODEL), 0.0, dtype=input_dtype))
    q = tl.dot(q, I).to(input_dtype)

Contiguity requirement from `triton_tree_attn.py:10-13`:

def maybe_contiguous(x):
    # only when the inner most dimension is contiguous can LDGSTS be used
    # so inner-dimension contiguity is enforced.
    return x.contiguous() if x.stride(-1) != 1 else x
