Implementation:Sail sg LongSpec Triton Tree Attn Kernel

From Leeroopedia
Knowledge Sources
Domains GPU_Kernels, Attention_Mechanisms
Last Updated 2026-02-14 05:00 GMT

Overview

A concrete tool that computes tree-masked attention with a custom Triton JIT kernel, featuring hardware-aware block-size tuning and an optional log-sum-exp (LSE) normalizer output.

Description

The attention() function provides a Python API for the Triton tree attention kernel. It computes scaled dot-product attention with an arbitrary tree-structured mask, supporting:

  • Custom tree masks of shape (batch_size, M, N) that encode which key positions each tree node may attend to (its ancestors in the tree)
  • Grouped-query attention (GQA), where num_kv_heads can differ from num_heads
  • Log-sum-exp normalizer output for prefix attention merging
  • Hardware-aware configs optimized for A100 (SM 8.0) and RTX 3090 (SM 8.6)
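The log-sum-exp output is what makes prefix merging possible: two attention results computed over disjoint key sets can be combined exactly if each comes with its normalizer. The following is a minimal pure-Python sketch for a single query row. The helper names `softmax_attend` and `merge_partial_attention` are hypothetical and not part of LongSpec, and the sketch assumes the normalizer is a natural-log LSE, which this page does not confirm.

```python
import math

def softmax_attend(scores, values):
    """Return (output, lse) for one query row over one set of keys."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]
    return out, m + math.log(total)

def merge_partial_attention(out_a, lse_a, out_b, lse_b):
    """Combine two normalized partial outputs using their LSE values."""
    m = max(lse_a, lse_b)
    w_a = math.exp(lse_a - m)  # relative softmax mass of key set A
    w_b = math.exp(lse_b - m)  # relative softmax mass of key set B
    denom = w_a + w_b
    merged = [(w_a * a + w_b * b) / denom for a, b in zip(out_a, out_b)]
    return merged, m + math.log(denom)

# Illustrative data: a "prefix" key set and a "tree" key set for one query.
prefix_scores, prefix_values = [0.5, -1.0, 2.0], [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
tree_scores, tree_values = [1.5, 0.0], [[-1.0, 3.0], [0.5, 0.5]]

out_p, lse_p = softmax_attend(prefix_scores, prefix_values)
out_t, lse_t = softmax_attend(tree_scores, tree_values)
merged, merged_lse = merge_partial_attention(out_p, lse_p, out_t, lse_t)

# Reference: attention over the concatenated key set in one pass.
full, full_lse = softmax_attend(prefix_scores + tree_scores,
                                prefix_values + tree_values)
```

Merging the two partial results reproduces the single-pass result up to floating-point error, which is why the kernel only needs to cover the tree part when a prefix result is available separately.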

The underlying _fwd_kernel is a Triton JIT-compiled kernel that processes the attention in tiles, applying the tree mask and computing online softmax normalization.
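The tiled, online-softmax idea can be illustrated without Triton. The sketch below handles one query row in pure Python: it walks the keys in tiles, skips masked positions, and keeps a running maximum, running sum, and rescaled accumulator. It is an illustration of the general technique, not a transcription of `_fwd_kernel`; the function name and the `block_n` parameter are assumptions made for this example.

```python
import math

def tiled_masked_attention_row(scores, mask, values, block_n=2):
    """One query row: masked attention computed tile by tile with online softmax."""
    dim = len(values[0])
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * dim
    for start in range(0, len(scores), block_n):
        stop = min(start + block_n, len(scores))
        # Keep only the positions the mask marks as visible.
        tile = [(scores[j], values[j]) for j in range(start, stop) if mask[j]]
        if not tile:
            continue  # every key in this tile is masked out
        new_max = max(running_max, max(s for s, _ in tile))
        # Rescale earlier partial results to the new maximum (0.0 on the first tile).
        rescale = math.exp(running_max - new_max)
        running_sum *= rescale
        acc = [a * rescale for a in acc]
        for s, v in tile:
            p = math.exp(s - new_max)
            running_sum += p
            acc = [a + p * x for a, x in zip(acc, v)]
        running_max = new_max
    if running_sum == 0.0:
        raise ValueError("query row has no visible keys")
    out = [a / running_sum for a in acc]
    lse = running_max + math.log(running_sum)
    return out, lse

scores = [0.2, 1.0, -0.5, 3.0, 0.7]
mask = [1, 0, 1, 1, 0]
values = [[1.0, 2.0], [9.0, 9.0], [0.0, 1.0], [2.0, 0.0], [9.0, 9.0]]
out, lse = tiled_masked_attention_row(scores, mask, values)
```

The result does not depend on `block_n`; the tile size only changes how the work is chunked, which is what allows block sizes to be tuned per GPU.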

Usage

Called internally by GlideAttention.triton_tree_part_fwd() during tree decoding. Not typically called directly by users.

Code Reference

Source Location

  • Repository: LongSpec
  • File: longspec/test/triton_tree_attn.py
  • Lines: L19-270

Signature

def attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    tree_mask: torch.Tensor,
    sm_scale: Optional[float] = None,
    return_log_normalizer: bool = True,
) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
    """
    Compute tree-masked attention using custom Triton kernel.

    Args:
        q: Query tensor (batch_size, num_heads, M, head_dim)
        k: Key tensor (batch_size, num_kv_heads, N, head_dim)
        v: Value tensor (batch_size, num_kv_heads, N, head_dim)
        tree_mask: Binary mask (batch_size, M, N). 1 = attend, 0 = mask out.
        sm_scale: Softmax scaling factor (default: 1/sqrt(head_dim))
        return_log_normalizer: If True, also returns LSE normalizers.

    Returns:
        If return_log_normalizer:
            (output, log_normalizer) where:
                output: (batch_size, num_heads, M, head_dim)
                log_normalizer: (batch_size, num_heads, M)
        Else:
            output: (batch_size, num_heads, M, head_dim)
    """

Import

from longspec.test.triton_tree_attn import attention as tree_attention

I/O Contract

Inputs

Name | Type | Required | Description
q | torch.Tensor | Yes | Query tensor (B, H, M, D) where H=num_heads, M=num_tree_nodes, D=head_dim
k | torch.Tensor | Yes | Key tensor (B, H_kv, N, D) where H_kv=num_kv_heads, N=total_key_positions
v | torch.Tensor | Yes | Value tensor (B, H_kv, N, D) matching key dimensions
tree_mask | torch.Tensor | Yes | Binary mask (B, M, N) encoding tree structure (1=visible, 0=masked)
sm_scale | float | No | Softmax scale factor (default: 1/sqrt(head_dim))
return_log_normalizer | bool | No | Whether to return LSE for prefix merging (default: True)
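When H differs from H_kv, each key/value head is shared by a group of query heads. The sketch below shows the common contiguous-grouping convention, in which consecutive query heads share one key/value head. Whether the kernel indexes heads exactly this way is an assumption, since the page does not show the kernel's indexing; `kv_head_for_query_head` is a hypothetical helper.

```python
def kv_head_for_query_head(query_head, num_heads, num_kv_heads):
    """Map a query head index to the key/value head it reads (contiguous groups)."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    group_size = num_heads // num_kv_heads
    return query_head // group_size

# 8 query heads sharing 2 key/value heads: heads 0-3 read KV head 0, heads 4-7 read KV head 1.
mapping = [kv_head_for_query_head(h, 8, 2) for h in range(8)]
```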

Outputs

Name | Type | Description
output | torch.Tensor | Attention output (B, H, M, D): tree-masked attention over candidate tokens
log_normalizer | torch.Tensor | Log-sum-exp normalizers (B, H, M) for merging with prefix attention; returned only when return_log_normalizer=True

Usage Examples

Direct Kernel Call

import torch
from longspec.test.triton_tree_attn import attention as tree_attention

batch_size, num_heads, head_dim = 1, 32, 128
M = 69   # Number of tree nodes (4 + 16 + 16 + 16 + 16 + 1)
N = 69   # Same as M for tree-only attention

q = torch.randn(batch_size, num_heads, M, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(batch_size, num_heads, N, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(batch_size, num_heads, N, head_dim, device="cuda", dtype=torch.float16)

# Binary tree mask: tree_mask[b, i, j] = 1 if node j is an ancestor of node i
# (typically including node i itself), else 0.
tree_mask = torch.zeros(batch_size, M, N, device="cuda", dtype=torch.float16)
# ... fill tree_mask based on tree structure ...

output, log_normalizer = tree_attention(q, k, v, tree_mask)
# output: (1, 32, 69, 128)
# log_normalizer: (1, 32, 69)
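The mask-filling step elided above depends on how the tree is represented. As one possibility, the sketch below builds an ancestor mask from a parent-pointer list in pure Python. The parent-list representation, the `build_tree_mask` name, and the choice to let each node see itself are assumptions for illustration, not something this page specifies.

```python
def build_tree_mask(parents):
    """parents[i] is the parent index of node i, or -1 for a root.

    Returns an n x n list of 0/1 where mask[i][j] = 1 if node j is node i
    itself or one of its ancestors.
    """
    n = len(parents)
    mask = [[0] * n for _ in range(n)]
    for i in range(n):
        j = i
        while j != -1:  # walk from node i up to the root
            mask[i][j] = 1
            j = parents[j]
    return mask

# Small tree: node 0 is the root, 1 and 2 are its children,
# 3 and 4 are children of 1, and 5 is a child of 2.
parents = [-1, 0, 0, 1, 1, 2]
mask = build_tree_mask(parents)
```

The nested list can then be converted with `torch.tensor(mask)`, cast to the desired dtype, and given a leading batch dimension to obtain the (B, M, N) layout the function expects.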

Hardware-Aware Configuration

# The kernel auto-selects block sizes based on GPU:
# A100 (SM 8.0): BLOCK_M=128, BLOCK_N=64-128, num_warps=4-8
# RTX 3090 (SM 8.6): BLOCK_M=64-128, BLOCK_N=32-64, num_warps=4
# Other: BLOCK_M=32, BLOCK_N=32, num_warps=4 (conservative default)
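The selection logic can be pictured as a lookup keyed on the GPU's compute capability. The sketch below only restates the ranges listed above as (min, max) pairs; it is not the kernel's actual selection code, the page does not say what picks a value within each range, and `documented_block_ranges` is a hypothetical name. On a real machine the capability tuple could come from `torch.cuda.get_device_capability()`.

```python
def documented_block_ranges(capability):
    """Return the (min, max) ranges listed on this page for a (major, minor) capability."""
    if capability == (8, 0):  # A100
        return {"BLOCK_M": (128, 128), "BLOCK_N": (64, 128), "num_warps": (4, 8)}
    if capability == (8, 6):  # RTX 3090
        return {"BLOCK_M": (64, 128), "BLOCK_N": (32, 64), "num_warps": (4, 4)}
    # Conservative default for any other GPU.
    return {"BLOCK_M": (32, 32), "BLOCK_N": (32, 32), "num_warps": (4, 4)}

a100 = documented_block_ranges((8, 0))
fallback = documented_block_ranges((7, 5))
```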

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
