Implementation:Sail sg LongSpec Triton Tree Attn Kernel
| Knowledge Sources | |
|---|---|
| Domains | GPU_Kernels, Attention_Mechanisms |
| Last Updated | 2026-02-14 05:00 GMT |
Overview
Concrete tool for computing tree-masked attention using a custom Triton JIT kernel with hardware-aware block size tuning and optional log-sum-exp normalizer output.
Description
The attention() function provides a Python API for the Triton tree attention kernel. It computes scaled dot-product attention with an arbitrary tree-structured mask, supporting:
- Custom tree masks (batch_size, M, N) encoding tree parent-child relationships
- GQA support where num_kv_heads can differ from num_heads
- Log-sum-exp normalizer output for prefix attention merging
- Hardware-aware block-size configurations tuned for A100 (SM 8.0) and RTX 3090 (SM 8.6), with a conservative fallback for other GPUs
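For GQA, each KV head is shared by a group of query heads. The helper below is a hypothetical illustration, not part of LongSpec: it assumes the common contiguous grouping `h // (num_heads // num_kv_heads)`. Whether `_fwd_kernel` groups heads exactly this way is an assumption to verify against the source.

```python
def kv_head_for_query_head(h: int, num_heads: int, num_kv_heads: int) -> int:
    """Index of the KV head read by query head h, assuming contiguous GQA groups."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    group_size = num_heads // num_kv_heads  # query heads that share one KV head
    return h // group_size


# 32 query heads over 8 KV heads: each KV head serves 4 consecutive query heads.
mapping = [kv_head_for_query_head(h, num_heads=32, num_kv_heads=8) for h in range(32)]
print(mapping[:8])  # [0, 0, 0, 0, 1, 1, 1, 1]
```

Under the same assumption, a dense PyTorch reference could expand `k` and `v` to `num_heads` heads with `k.repeat_interleave(num_heads // num_kv_heads, dim=1)`, which produces the same query-to-KV head mapping.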
The underlying _fwd_kernel is a Triton JIT-compiled kernel that processes queries and keys in BLOCK_M x BLOCK_N tiles, applies the tree mask within each tile, and normalizes with an online softmax (a running max and a running sum of exponentials per query row).
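The online-softmax recurrence can be illustrated in plain Python for a single query row. This is a sketch under stated assumptions, not the Triton kernel: `tree_attn_row_online` and `block_n` are hypothetical names, tiles run only along the key axis, and the log-sum-exp is returned in natural log (the real kernel may store it in another base, e.g. base 2). It assumes every row has at least one visible key.

```python
import math
from typing import List, Sequence, Tuple


def tree_attn_row_online(
    q: Sequence[float],
    keys: Sequence[Sequence[float]],
    values: Sequence[Sequence[float]],
    mask_row: Sequence[int],
    sm_scale: float,
    block_n: int = 2,
) -> Tuple[List[float], float]:
    """Tree-masked attention for ONE query row, visiting keys in tiles of block_n.

    Returns (output, lse), where lse is the natural-log log-sum-exp of the
    scaled scores over visible keys. Assumes at least one visible key.
    """
    n = len(keys)
    run_max = -math.inf           # running max of visible scores
    denom = 0.0                   # running sum of exp(score - run_max)
    acc = [0.0] * len(values[0])  # running sum of exp(score - run_max) * v_j
    for start in range(0, n, block_n):
        visible = [j for j in range(start, min(start + block_n, n)) if mask_row[j]]
        if not visible:
            continue  # tile fully masked by the tree: skip it
        scores = [sm_scale * sum(a * b for a, b in zip(q, keys[j])) for j in visible]
        new_max = max(run_max, max(scores))
        alpha = math.exp(run_max - new_max)  # rescales old stats; 0.0 on the first tile
        probs = [math.exp(s - new_max) for s in scores]
        denom = denom * alpha + sum(probs)
        acc = [
            alpha * acc[c] + sum(p * values[j][c] for p, j in zip(probs, visible))
            for c in range(len(acc))
        ]
        run_max = new_max
    return [a / denom for a in acc], run_max + math.log(denom)


out, lse = tree_attn_row_online(
    q=[1.0, 0.0],
    keys=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
    values=[[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]],
    mask_row=[1, 0, 1],  # key 1 is not visible from this node
    sm_scale=1.0,
)
print(out)  # [1.5, 1.0]
```

Because the running max and sum are kept anyway, the log-sum-exp costs only one extra `log` per row at the end, which is why exposing it as an optional output is cheap.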
Usage
Called internally by GlideAttention.triton_tree_part_fwd() during tree decoding; users do not normally call it directly.
Code Reference
Source Location
- Repository: LongSpec
- File: longspec/test/triton_tree_attn.py
- Lines: L19-270
Signature
```python
from typing import Optional, Tuple, Union

import torch


def attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    tree_mask: torch.Tensor,
    sm_scale: Optional[float] = None,
    return_log_normalizer: bool = True,
) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
    """
    Compute tree-masked attention using a custom Triton kernel.

    Args:
        q: Query tensor (batch_size, num_heads, M, head_dim)
        k: Key tensor (batch_size, num_kv_heads, N, head_dim)
        v: Value tensor (batch_size, num_kv_heads, N, head_dim)
        tree_mask: Binary mask (batch_size, M, N). 1 = attend, 0 = mask out.
        sm_scale: Softmax scaling factor (default: 1/sqrt(head_dim))
        return_log_normalizer: If True, also return the LSE normalizers.

    Returns:
        If return_log_normalizer:
            (output, log_normalizer) where:
                output: (batch_size, num_heads, M, head_dim)
                log_normalizer: (batch_size, num_heads, M)
        Else:
            output: (batch_size, num_heads, M, head_dim)
    """
```
Import
```python
from longspec.test.triton_tree_attn import attention as tree_attention
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| q | torch.Tensor | Yes | Query tensor (B, H, M, D) where H=num_heads, M=num_tree_nodes, D=head_dim |
| k | torch.Tensor | Yes | Key tensor (B, H_kv, N, D) where H_kv=num_kv_heads, N=total_key_positions |
| v | torch.Tensor | Yes | Value tensor (B, H_kv, N, D) matching key dimensions |
| tree_mask | torch.Tensor | Yes | Binary mask (B, M, N) encoding tree structure (1=visible, 0=masked) |
| sm_scale | float | No | Softmax scale factor (default: 1/sqrt(head_dim)) |
| return_log_normalizer | bool | No | Whether to return LSE for prefix merging (default: True) |
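The input contract above can be checked before launching the kernel. The helper below is hypothetical (not part of LongSpec) and works on plain shape tuples, e.g. `tuple(q.shape)`. The requirement that `num_heads` be a multiple of `num_kv_heads` is an assumption based on the usual GQA grouping, not something the table states.

```python
from typing import Sequence


def check_tree_attn_shapes(
    q_shape: Sequence[int],
    k_shape: Sequence[int],
    v_shape: Sequence[int],
    mask_shape: Sequence[int],
) -> None:
    """Raise ValueError if shapes violate the (B, H, M, D) / (B, H_kv, N, D) / (B, M, N) contract."""
    b, h, m, d = q_shape
    bk, h_kv, n, dk = k_shape
    if tuple(v_shape) != tuple(k_shape):
        raise ValueError(f"v shape {tuple(v_shape)} must match k shape {tuple(k_shape)}")
    if bk != b or dk != d:
        raise ValueError("k must share batch size and head_dim with q")
    if h % h_kv != 0:
        raise ValueError("num_heads must be a multiple of num_kv_heads (assumed GQA grouping)")
    if tuple(mask_shape) != (b, m, n):
        raise ValueError(f"tree_mask must have shape {(b, m, n)}, got {tuple(mask_shape)}")


# 32 query heads, 8 KV heads, 69 tree nodes attending over 69 keys: passes silently.
check_tree_attn_shapes((1, 32, 69, 128), (1, 8, 69, 128), (1, 8, 69, 128), (1, 69, 69))
```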
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Attention output (B, H, M, D): tree-masked attention of each tree-node query over its visible keys |
| log_normalizer | torch.Tensor | Log-sum-exp normalizers (B, H, M) for merging with prefix attention |
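The log-sum-exp output is what allows two partial attention results over disjoint key sets (for example, the cached prefix and the tree candidates) to be combined exactly. The sketch below shows the standard LSE merge for one query row; `merge_with_lse` is a hypothetical name, and it assumes natural-log normalizers. If the kernel stores base-2 values, multiply them by ln 2 first. That GlideAttention merges in exactly this way is an assumption.

```python
import math
from typing import List, Sequence, Tuple


def merge_with_lse(
    o1: Sequence[float], lse1: float, o2: Sequence[float], lse2: float
) -> Tuple[List[float], float]:
    """Merge two attention outputs computed over disjoint key sets (one query row).

    Each o_i is already softmax-normalized over its own keys, and lse_i is the
    natural-log log-sum-exp of that set's scaled scores.
    """
    mx = max(lse1, lse2)  # subtract the max for numerical stability
    lse = mx + math.log(math.exp(lse1 - mx) + math.exp(lse2 - mx))
    w1 = math.exp(lse1 - lse)  # fraction of total softmax mass in set 1
    w2 = math.exp(lse2 - lse)  # fraction in set 2 (w1 + w2 sums to 1)
    return [w1 * a + w2 * b for a, b in zip(o1, o2)], lse


# Prefix: one key with score 0 -> value [1, 0]; tree: one key with score ln 3 -> [0, 1].
# A joint softmax over both keys gives weights 1/4 and 3/4.
out, lse = merge_with_lse([1.0, 0.0], 0.0, [0.0, 1.0], math.log(3.0))
print(out)  # approximately [0.25, 0.75]
```

The merged `lse` can itself be merged again, so the same function folds in any number of key partitions.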
Usage Examples
Direct Kernel Call
```python
import torch
from longspec.test.triton_tree_attn import attention as tree_attention

batch_size, num_heads, head_dim = 1, 32, 128
M = 69  # Number of tree nodes (4 + 16 + 16 + 16 + 16 + 1)
N = 69  # Same as M for tree-only attention

q = torch.randn(batch_size, num_heads, M, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(batch_size, num_heads, N, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(batch_size, num_heads, N, head_dim, device="cuda", dtype=torch.float16)

# Binary tree mask: tree_mask[b, i, j] = 1 if node j is node i itself or one of its ancestors
tree_mask = torch.zeros(batch_size, M, N, device="cuda", dtype=torch.float16)
# ... fill tree_mask based on tree structure ...

output, log_normalizer = tree_attention(q, k, v, tree_mask)
# output: (1, 32, 69, 128)
# log_normalizer: (1, 32, 69)
```
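The example above leaves the mask construction elided. One common way to build a tree mask is from a parent-index array, where row i marks node i itself and every ancestor on the path to the root. The helper and the `parents` layout below are hypothetical illustrations, not LongSpec's own tree representation.

```python
from typing import List, Sequence


def tree_mask_from_parents(parents: Sequence[int]) -> List[List[int]]:
    """Row i is 1 at column i and at every ancestor of node i, 0 elsewhere.

    parents[i] is the parent index of node i, or -1 for a root; the parent
    links must form a forest (no cycles), otherwise the walk never ends.
    """
    m = len(parents)
    mask = [[0] * m for _ in range(m)]
    for i in range(m):
        j = i
        while j != -1:  # walk from node i up to its root
            mask[i][j] = 1
            j = parents[j]
    return mask


# Root 0 with children 1 and 2; node 3 is a child of node 1.
for row in tree_mask_from_parents([-1, 0, 0, 1]):
    print(row)
```

To pass such a mask to the kernel, one option is `torch.tensor(mask, device="cuda", dtype=torch.float16).unsqueeze(0).repeat(batch_size, 1, 1)`; the float16 dtype follows the example above and has not been checked against what `_fwd_kernel` expects.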
Hardware-Aware Configuration
```python
# The kernel auto-selects block sizes based on GPU:
# A100 (SM 8.0): BLOCK_M=128, BLOCK_N=64-128, num_warps=4-8
# RTX 3090 (SM 8.6): BLOCK_M=64-128, BLOCK_N=32-64, num_warps=4
# Other: BLOCK_M=32, BLOCK_N=32, num_warps=4 (conservative default)
```
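A selection keyed on compute capability can be sketched as a lookup table with a fallback. This is a hypothetical illustration: only the value ranges come from the comments above, while the specific (BLOCK_M, BLOCK_N, num_warps) pairings are assumptions, and it is not stated here whether the kernel autotunes over several candidates or picks a single one.

```python
from typing import Dict, List, Tuple

Config = Tuple[int, int, int]  # (BLOCK_M, BLOCK_N, num_warps)

# Hypothetical candidate table; pairings within each range are assumed.
_CANDIDATES: Dict[Tuple[int, int], List[Config]] = {
    (8, 0): [(128, 64, 4), (128, 128, 8)],  # A100
    (8, 6): [(64, 32, 4), (128, 64, 4)],    # RTX 3090
}
_FALLBACK: List[Config] = [(32, 32, 4)]     # conservative default for other GPUs


def candidate_configs(capability: Tuple[int, int]) -> List[Config]:
    """Return candidate tile configs for a (major, minor) compute capability."""
    return _CANDIDATES.get(capability, _FALLBACK)


print(candidate_configs((8, 6)))  # [(64, 32, 4), (128, 64, 4)]
print(candidate_configs((9, 0)))  # [(32, 32, 4)]
```

On a live system the capability tuple would come from `torch.cuda.get_device_capability()`. If the kernel does autotune, each triple could be wrapped as `triton.Config({"BLOCK_M": bm, "BLOCK_N": bn}, num_warps=w)` for `triton.autotune`, though whether LongSpec does so is not confirmed here.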