Implementation:Bitsandbytes foundation Bitsandbytes Matmul Perf Model
| Knowledge Sources | |
|---|---|
| Domains | Performance_Modeling, Triton, Autotuning |
| Last Updated | 2026-02-07 13:31 GMT |
Overview
Analytical performance model for Triton matmul kernels that estimates execution time and prunes autotuning configurations based on hardware constraints and compute-memory overlap.
Description
This module provides an analytical performance model adapted from the Triton kernels repository.
estimate_matmul_time calculates the expected execution time of a tiled matrix multiplication kernel from three components: compute time (based on tensor core TFLOPS), memory load time (DRAM and L2 cache bandwidth, scaled by CTA occupancy), and store time. Compute and loads are assumed to overlap, while the store is added on top, so the total time is max(compute, load) + store.
early_config_prune complements this model with hardware constraints (shared memory limits, supported dtypes, and the optimal number of pipeline stages on Ampere and newer GPUs) to reduce the Triton autotuning search space.
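The max(compute, load) + store rule can be sketched in a few lines. This is a simplified illustration rather than the module's code: compute_time_ms and combine_ms are hypothetical helper names, the 100 TFLOPS peak and the load and store times are made-up placeholders, and SI units are assumed throughout.

```python
def compute_time_ms(M: int, N: int, K: int, tflops: float) -> float:
    """Time to execute the 2*M*N*K operations of a matmul at `tflops` TFLOPS."""
    flops = 2 * M * N * K  # one multiply and one add per inner-product term
    return flops / (tflops * 1e12) * 1e3  # seconds -> milliseconds


def combine_ms(compute_ms: float, load_ms: float, store_ms: float) -> float:
    """Total time when compute and load overlap and the store follows."""
    return max(compute_ms, load_ms) + store_ms


# Hypothetical inputs: 100 TFLOPS peak, 0.5 ms of loads, 0.2 ms of stores.
compute = compute_time_ms(4096, 4096, 4096, tflops=100.0)
total = combine_ms(compute, load_ms=0.5, store_ms=0.2)
print(f"compute={compute:.3f} ms, total={total:.3f} ms")
```

With these placeholder numbers the kernel is compute-bound, so the load time is hidden behind compute and only the store adds to the total.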
Usage
The Triton INT8 matmul kernels use this module internally to speed up autotuning by pruning unpromising configurations before benchmarking. Pruning avoids testing configurations that would exceed shared memory or that have a suboptimal pipeline depth.
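The shared-memory check can be illustrated as follows. This is a sketch under stated assumptions: TileConfig is a hypothetical stand-in for a Triton autotuning config, the 100 KiB limit is a placeholder, and the footprint formula assumes each pipeline stage buffers one BLOCK_M x BLOCK_K tile of A and one BLOCK_K x BLOCK_N tile of B. The module's actual check may differ in detail.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class TileConfig:
    """Hypothetical stand-in for a Triton autotuning config."""
    block_m: int
    block_n: int
    block_k: int
    num_stages: int


def required_smem_bytes(cfg: TileConfig, elem_size: int) -> int:
    """Bytes of shared memory needed to buffer A and B tiles for all stages."""
    return (cfg.block_m + cfg.block_n) * cfg.block_k * cfg.num_stages * elem_size


def prune_by_smem(configs: List[TileConfig], elem_size: int, max_smem_bytes: int) -> List[TileConfig]:
    """Keep only the configs whose tile buffers fit in shared memory."""
    return [c for c in configs if required_smem_bytes(c, elem_size) <= max_smem_bytes]


small = TileConfig(128, 128, 32, 3)  # (128 + 128) * 32 * 3 * 1 = 24576 bytes
large = TileConfig(256, 256, 64, 5)  # (256 + 256) * 64 * 5 * 1 = 163840 bytes
# elem_size=1 corresponds to INT8 inputs; the 100 KiB limit is a placeholder.
kept = prune_by_smem([small, large], elem_size=1, max_smem_bytes=100 * 1024)
print(kept)
```

Filtering like this is cheap compared with compiling and benchmarking a kernel, which is why it is worth doing before autotuning starts.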
Code Reference
Source Location
- Repository: bitsandbytes
- File: bitsandbytes/triton/matmul_perf_model.py
- Lines: 1-211
Signature
def estimate_matmul_time(
    num_warps: int, num_stages: int,
    A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
    M: int, N: int, K: int,
    BLOCK_M: int, BLOCK_N: int, BLOCK_K: int, SPLIT_K: int,
    debug: bool = False, **kwargs
) -> float:
    """Return estimated kernel running time in milliseconds."""

def early_config_prune(configs, named_args, **kwargs) -> list:
    """Prune Triton autotuning configs based on hardware constraints."""

def get_tflops(device, num_ctas, num_warps, dtype) -> float:
    """Return compute throughput in TFLOPS for given occupancy."""
Import
from bitsandbytes.triton.matmul_perf_model import (
    estimate_matmul_time,
    early_config_prune,
)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| M, N, K | int | Yes | Matrix dimensions |
| BLOCK_M, BLOCK_N, BLOCK_K | int | Yes | Tiling block sizes |
| SPLIT_K | int | Yes | K-dimension split factor |
| num_warps | int | Yes | Warps per CTA |
| num_stages | int | Yes | Pipeline stages |
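| A, B | torch.Tensor | Yes | Input tensors (for dtype and element_size) |
| C | torch.Tensor | Yes | Output tensor (the usage example passes None) |
| debug | bool | No | Debug flag (default: False) |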
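The tiling inputs determine how many CTAs a kernel launches, which is what occupancy is measured against. The sketch below assumes the usual tiled-matmul launch grid, with one CTA per BLOCK_M x BLOCK_N output tile and SPLIT_K multiplying the count. cdiv and num_ctas are hypothetical helper names, not functions exported by this module.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division for positive integers."""
    return (a + b - 1) // b


def num_ctas(M: int, N: int, BLOCK_M: int, BLOCK_N: int, SPLIT_K: int) -> int:
    """CTAs launched when each CTA owns one output tile and K is split SPLIT_K ways."""
    return cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N) * SPLIT_K


print(num_ctas(4096, 4096, 128, 128, 1))  # 32 * 32 * 1 = 1024
print(num_ctas(100, 100, 64, 64, 2))      # 2 * 2 * 2 = 8
```

A CTA count well below the number of SMs leaves part of the GPU idle, which is one reason a model of this kind scales throughput and bandwidth by occupancy.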
Outputs
| Name | Type | Description |
|---|---|---|
| time_ms | float | Estimated kernel execution time in milliseconds (returned by estimate_matmul_time) |
| pruned_configs | list | Filtered autotuning configurations (returned by early_config_prune) |
Usage Examples
Estimating Matmul Time
import torch
from bitsandbytes.triton.matmul_perf_model import estimate_matmul_time

# Requires a CUDA device; the tensors supply dtype and element size to the model.
A = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
B = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)

time_ms = estimate_matmul_time(
    num_warps=4, num_stages=3,
    A=A, B=B, C=None,
    M=4096, N=4096, K=4096,
    BLOCK_M=128, BLOCK_N=128, BLOCK_K=32, SPLIT_K=1,
    debug=True,
)
print(f"Estimated time: {time_ms:.3f} ms")
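To compare several tile shapes, one approach is to evaluate a cost function for each candidate and keep the cheapest few. The sketch below is an illustration only: top_k_configs is a hypothetical helper, and toy_cost_ms is a made-up placeholder so the example runs without a GPU. In practice an estimator such as estimate_matmul_time, with its remaining arguments bound, could play that role.

```python
import heapq
from typing import Callable, Dict, List

Config = Dict[str, int]


def top_k_configs(configs: List[Config], cost_ms: Callable[..., float], k: int) -> List[Config]:
    """Return the k configurations with the lowest estimated cost."""
    return heapq.nsmallest(k, configs, key=lambda cfg: cost_ms(**cfg))


def toy_cost_ms(BLOCK_M: int, BLOCK_N: int, BLOCK_K: int, SPLIT_K: int) -> float:
    """Hypothetical placeholder cost, NOT the real model.

    It simply favors large output tiles and SPLIT_K=1; BLOCK_K is ignored.
    """
    return 1.0 / (BLOCK_M * BLOCK_N) + 0.001 * SPLIT_K


candidates = [
    {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1},
    {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1},
    {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 4},
]
best = top_k_configs(candidates, toy_cost_ms, k=2)
for cfg in best:
    print(cfg)
```

Ranking by an analytical estimate does not replace benchmarking; it only narrows the set of configurations that need to be timed.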