Implementation:Predibase Lorax GPTQ Utils Custom Autotune
| Knowledge Sources | |
|---|---|
| Domains | Quantization, Inference |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Provides a custom Triton kernel autotuner optimized for GPTQ matrix multiplication, using fewer benchmark repetitions and optional nearest-power-of-two key rounding to reduce tuning overhead.
Description
This module is adapted from the GPTQ-triton project and provides a modified version of Triton's built-in autotuner with several GPTQ-specific optimizations.
Autotuner (triton.KernelInterface): The main autotuner class that manages kernel configuration selection. Key differences from Triton's default autotuner include:
- Reduced benchmarking: Uses 40 repetitions instead of Triton's default of 100 for kernel timing; empirical testing indicated that this stays accurate enough while shortening tuning.
- Nearest power-of-two caching: When nearest_power_of_two is enabled, key arguments (M, N, K) are rounded to the nearest power of two before looking up cached configurations. Shapes that round to the same values share one cache entry, so far fewer unique tuning runs are needed.
- Config pruning: Supports early config pruning via early_config_prune, performance model-based pruning via perf_model, and top-k selection via top_k. The prune_configs() method applies these in sequence.
- Reset hooks: Supports zeroing out specified tensors between configuration benchmark runs via reset_to_zero.
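The nearest-power-of-two key rounding from the list above can be sketched as follows. This is a minimal illustration rather than the module's verbatim code: the helper name `round_key_to_pow2` is hypothetical, and rounding on a logarithmic scale is an assumption about how "nearest" is defined.

```python
import math


def round_key_to_pow2(key):
    """Round each positive integer in `key` to a nearby power of two.

    Hypothetical helper: rounding happens on log2(x), so each value snaps
    to whichever power of two is closer on a logarithmic scale.
    """
    return tuple(2 ** int(math.log2(x) + 0.5) for x in key)


# Nearby shapes collapse onto one cache key, so they share one tuning run.
cache = {}
for shape in [(100, 4096, 4096), (120, 4096, 4096), (128, 4096, 4096)]:
    cache.setdefault(round_key_to_pow2(shape), "tuned-config-placeholder")

print(len(cache), list(cache))
```

With this rounding, all three example shapes map to the same key, so only the first would trigger benchmarking.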
The run() method checks the cache for a previously tuned configuration, and if not found, benchmarks all pruned configs and selects the fastest. The _bench() method uses triton.testing.do_bench with quantile reporting (median, 20th, 80th percentiles).
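The cache-then-benchmark flow of run() can be illustrated with a small, Triton-free sketch. The names `select_config` and `fake_bench` are hypothetical, and the timings are made-up placeholders; the only assumption carried over from the text is that each benchmark returns a (median, 20th percentile, 80th percentile) tuple.

```python
def select_config(cache, key, configs, bench):
    """Return the fastest config for `key`, benchmarking only on a cache miss.

    `bench(config)` is a hypothetical callable returning a
    (median, p20, p80) timing tuple.
    """
    if key not in cache:
        timings = {config: bench(config) for config in configs}
        # Tuples compare element-wise, so the median is compared first.
        cache[key] = min(timings, key=timings.get)
    return cache[key]


calls = []  # records which configs were actually benchmarked
fake_timings = {"cfg_a": (0.30, 0.28, 0.35), "cfg_b": (0.21, 0.20, 0.25)}


def fake_bench(config):
    calls.append(config)
    return fake_timings[config]


cache = {}
first = select_config(cache, (128, 4096, 4096), list(fake_timings), fake_bench)
second = select_config(cache, (128, 4096, 4096), list(fake_timings), fake_bench)
print(first, second, len(calls))
```

The second call is served from the cache, so the benchmark function runs only once per config.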
autotune(): A decorator function that wraps a Triton JIT kernel with the Autotuner, matching the API of triton.autotune.
matmul248_kernel_config_pruner(): A config pruning function specific to the GPTQ matmul kernel. It shrinks BLOCK_SIZE_M, BLOCK_SIZE_N, and BLOCK_SIZE_K when the corresponding matrix dimension (M, N, or K) is smaller than the configured block size, and skips configurations that become duplicates after shrinking.
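The shrink-and-deduplicate idea can be sketched without Triton. Everything named here is hypothetical: `TileConfig` stands in for triton.Config, and the power-of-two cap with a minimum tile size of 16 is an assumption made for illustration, not a statement about the module's exact rule.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class TileConfig:
    """Hypothetical stand-in for triton.Config so the sketch runs anywhere."""
    block_m: int
    block_n: int
    block_k: int
    num_stages: int = 4
    num_warps: int = 4


def next_pow2_at_least(x, floor=16):
    """Smallest power of two >= x, never below `floor` (assumed minimum tile)."""
    return max(2 ** math.ceil(math.log2(x)), floor)


def prune_tile_configs(configs, m, n, k):
    """Yield configs with block sizes capped to the problem size, skipping duplicates."""
    m_cap, n_cap, k_cap = (next_pow2_at_least(d) for d in (m, n, k))
    seen = set()
    for cfg in configs:
        shrunk = TileConfig(
            min(cfg.block_m, m_cap),
            min(cfg.block_n, n_cap),
            min(cfg.block_k, k_cap),
            cfg.num_stages,
            cfg.num_warps,
        )
        if shrunk not in seen:  # two configs may collapse to the same tile
            seen.add(shrunk)
            yield shrunk


configs = [TileConfig(64, 256, 32), TileConfig(128, 128, 32), TileConfig(64, 128, 32)]
pruned = list(prune_tile_configs(configs, m=1, n=4096, k=4096))
print(pruned)
```

For a single-row input (M = 1), the last two configs shrink to the same tile, so only two distinct candidates remain to be benchmarked.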
Usage
This module is used internally by quant_linear.py to autotune the matmul_248_kernel Triton kernel. The autotune decorator is applied with eight different tile configurations and the matmul248_kernel_config_pruner to efficiently select the best configuration for the given matrix dimensions at runtime.
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/utils/gptq/custom_autotune.py
- Lines: 1-252
Signature
class Autotuner(triton.KernelInterface):
    def __init__(self, fn, arg_names, configs, key, reset_to_zero,
                 prune_configs_by: Dict = None, nearest_power_of_two: bool = False)
    def _bench(self, *args, config, **meta)
    def run(self, *args, **kwargs)
    def prune_configs(self, kwargs)
    def warmup(self, *args, **kwargs)

def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False)

def matmul248_kernel_config_pruner(configs, nargs)
Import
from lorax_server.utils.gptq.custom_autotune import autotune, matmul248_kernel_config_pruner
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| configs | list[triton.Config] | Yes | List of Triton kernel configurations to evaluate |
| key | list[str] | Yes | Argument names whose values determine when to re-tune (e.g., ["M", "N", "K"]) |
| prune_configs_by | Dict | No | Dictionary with optional keys: early_config_prune, perf_model, top_k |
| reset_to_zero | list[str] | No | Argument names of tensors to zero before each benchmark run |
| nearest_power_of_two | bool | No | Whether to round key arguments to nearest power of two for cache lookup |
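How the three prune_configs_by entries might combine can be sketched as below. This is an illustration under stated assumptions, not the module's code: configs are plain dicts instead of triton.Config objects, the perf_model signature `(config, nargs)` is simplified for the example, and treating a float top_k of at most 1.0 as a fraction of the config list is an assumption.

```python
def prune_configs(configs, nargs, early_config_prune=None, perf_model=None, top_k=None):
    """Apply the optional pruning stages in order and return the survivors."""
    pruned = list(configs)
    if early_config_prune is not None:
        pruned = list(early_config_prune(pruned, nargs))
    if perf_model is not None and top_k is not None:
        # Assumption: a float top_k <= 1.0 means "this fraction of all configs".
        if isinstance(top_k, float) and top_k <= 1.0:
            top_k = int(len(configs) * top_k)
        if len(pruned) > top_k:
            # Keep the top_k configs with the lowest estimated cost.
            pruned = sorted(pruned, key=lambda cfg: perf_model(cfg, nargs))[:top_k]
    return pruned


configs = [{"BLOCK_SIZE_M": bm} for bm in (16, 32, 64, 128)]


def drop_oversized(cfgs, nargs):
    """Toy early pruner: discard tiles taller than the matrix."""
    return (c for c in cfgs if c["BLOCK_SIZE_M"] <= nargs["M"])


def estimated_cost(cfg, nargs):
    """Toy performance model: pretend larger tiles are always cheaper."""
    return 1.0 / cfg["BLOCK_SIZE_M"]


survivors = prune_configs(configs, {"M": 64}, drop_oversized, estimated_cost, top_k=2)
print(survivors)
```

When perf_model and top_k are left as None, only the early pruner runs and every config it returns is benchmarked.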
Outputs
| Name | Type | Description |
|---|---|---|
| Autotuner | Autotuner | From autotune(): the decorated kernel, wrapped in an interface that automatically selects the best configuration at runtime |
| pruned_configs | generator[triton.Config] | From matmul248_kernel_config_pruner: yields pruned configs with adjusted block sizes |
Usage Examples
# Internal usage in quant_linear.py
import triton

from lorax_server.utils.gptq import custom_autotune


@custom_autotune.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
            num_stages=4,
            num_warps=4,
        ),
        # ... more configs
    ],
    key=["M", "N", "K"],  # re-tune when any of these argument values changes
    nearest_power_of_two=True,  # round key values before the cache lookup
    prune_configs_by={
        "early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
        "perf_model": None,
        "top_k": None,
    },
)
@triton.jit
def matmul_248_kernel(a_ptr, b_ptr, c_ptr, ...):
    ...