Implementation:Predibase Lorax GPTQ Custom Autotune
| Knowledge Sources | |
|---|---|
| Domains | Quantization, Inference |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Provides a custom Triton autotuner for GPTQ quantized matrix multiplication kernels, with reduced benchmarking iterations and configuration pruning based on matrix dimensions.
Description
This module is adapted from the GPTQ-triton project and provides a modified version of the Triton autotuner optimized for quantized inference workloads. It contains:
Autotuner: A class extending triton.KernelInterface that manages kernel configuration selection. Key features include:
- Cache-based tuning: Stores the best configuration per key tuple (typically M, N, K dimensions), with an option to round keys to the nearest power of two (nearest_power_of_two) to reduce the number of unique tuning runs.
- Reduced benchmarking: Uses 40 repetitions instead of the default 100; this was found to give sufficiently accurate results while reducing tuning time.
- Configuration pruning: Supports three pruning strategies: an early_config_prune function applied before benchmarking, a perf_model for performance prediction, and top_k to limit the number of configs actually benchmarked.
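- _bench method: Benchmarks a single config with triton.testing.do_bench, requesting the 0.5, 0.2, and 0.8 timing quantiles (median, 20th, and 80th percentile). Returns infinity when the config raises OutOfResources, so it cannot beat a config that ran successfully.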
- warmup method: Warms up all pruned configurations without benchmarking.
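The effect of key rounding on the tuning cache can be sketched in plain Python. The rounding formula and the `TuningCache` helper below are illustrative assumptions, not code copied from the module; the point is only that nearby dimension values collapse onto a single cache entry, so tuning runs once for all of them.

```python
import math
from typing import Callable, Dict, Tuple


def round_key_to_power_of_two(key: Tuple[int, ...]) -> Tuple[int, ...]:
    """Round each positive integer in the key to the nearest power of two (in log space)."""
    return tuple(2 ** int(math.log2(x) + 0.5) for x in key)


class TuningCache:
    """Hypothetical helper: caches one 'best config' per (optionally rounded) key."""

    def __init__(self, tune: Callable[[Tuple[int, ...]], str],
                 nearest_power_of_two: bool = False):
        self._tune = tune
        self._round = nearest_power_of_two
        self._cache: Dict[Tuple[int, ...], str] = {}
        self.tuning_runs = 0  # counts how often the expensive tuning step ran

    def best_config(self, key: Tuple[int, ...]) -> str:
        if self._round:
            key = round_key_to_power_of_two(key)
        if key not in self._cache:
            self.tuning_runs += 1
            self._cache[key] = self._tune(key)
        return self._cache[key]


# The tune callback is a placeholder for the real benchmarking step.
cache = TuningCache(tune=lambda key: f"config-for-{key}", nearest_power_of_two=True)
for m in (100, 120, 128, 150):
    cache.best_config((m, 4096, 4096))
print(cache.tuning_runs)  # 1: all four M values round to 128
```

Without rounding, the same four calls would trigger four separate tuning runs, which is the cost the nearest_power_of_two option is meant to avoid.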
autotune: A decorator function that wraps a triton.jit kernel with the Autotuner class. It accepts configs (list of triton.Config), key (argument names that trigger re-evaluation), prune_configs_by (pruning strategy dict), reset_to_zero (tensor arguments to zero before each trial), and nearest_power_of_two.
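The decorator-factory shape described above can be illustrated without Triton. This is a minimal sketch: `SketchAutotuner` and `autotune_sketch` are hypothetical stand-ins that only store the options and forward calls, whereas the real Autotuner selects a config before launching the kernel, and a real Triton kernel is launched through a grid rather than called directly.

```python
class SketchAutotuner:
    """Hypothetical stand-in for Autotuner: keeps the tuning options, forwards calls."""

    def __init__(self, fn, configs, key, reset_to_zero=None,
                 prune_configs_by=None, nearest_power_of_two=False):
        self.fn = fn
        self.configs = configs
        self.key = key
        self.reset_to_zero = reset_to_zero or []
        self.prune_configs_by = prune_configs_by or {}
        self.nearest_power_of_two = nearest_power_of_two

    def __call__(self, *args, **kwargs):
        # A real autotuner would pick (or look up) the best config here.
        return self.fn(*args, **kwargs)


def autotune_sketch(configs, key, prune_configs_by=None, reset_to_zero=None,
                    nearest_power_of_two=False):
    """Return a decorator that wraps fn in the stand-in autotuner."""

    def decorator(fn):
        return SketchAutotuner(fn, configs, key, reset_to_zero,
                               prune_configs_by, nearest_power_of_two)

    return decorator


@autotune_sketch(configs=[{"BLOCK_SIZE_M": 64}], key=["M", "N", "K"],
                 nearest_power_of_two=True)
def add(a, b):
    return a + b


print(type(add).__name__, add(2, 3))
```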
matmul248_kernel_config_pruner: A generator function that shrinks BLOCK_SIZE_M, BLOCK_SIZE_N, and BLOCK_SIZE_K when the corresponding matrix dimensions (M, N, K) are smaller than the configured block sizes. It deduplicates configurations to avoid redundant benchmarking.
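The shrink-and-deduplicate behavior can be sketched as a plain-Python generator. `SketchConfig` is a hypothetical stand-in for triton.Config, and two details are assumptions made for illustration: each dimension is rounded up to a power of two before clamping, and block sizes are never shrunk below 16.

```python
import math
from dataclasses import dataclass
from typing import Dict, Iterable, Iterator


@dataclass
class SketchConfig:
    """Hypothetical stand-in for triton.Config."""
    kwargs: Dict[str, int]
    num_stages: int = 4
    num_warps: int = 4


def _ceil_pow2(x: int, floor: int = 16) -> int:
    """Round x (>= 1) up to a power of two, never going below `floor` (assumed minimum)."""
    return max(2 ** math.ceil(math.log2(x)), floor)


def prune_block_sizes(configs: Iterable[SketchConfig],
                      nargs: Dict[str, int]) -> Iterator[SketchConfig]:
    """Yield configs with block sizes clamped to the problem size, skipping duplicates."""
    limits = {
        "BLOCK_SIZE_M": _ceil_pow2(nargs["M"]),
        "BLOCK_SIZE_N": _ceil_pow2(nargs["N"]),
        "BLOCK_SIZE_K": _ceil_pow2(nargs["K"]),
    }
    seen = set()
    for cfg in configs:
        shrunk = dict(cfg.kwargs)
        for name, limit in limits.items():
            shrunk[name] = min(limit, cfg.kwargs[name])
        signature = (tuple(sorted(shrunk.items())), cfg.num_stages, cfg.num_warps)
        if signature in seen:
            continue  # an identical config was already yielded
        seen.add(signature)
        yield SketchConfig(shrunk, cfg.num_stages, cfg.num_warps)


configs = [
    SketchConfig({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}),
    SketchConfig({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}),
]
pruned = list(prune_block_sizes(configs, {"M": 8, "N": 4096, "K": 4096}))
print(len(pruned), pruned[0].kwargs)
```

With M = 8, both configs shrink to the same BLOCK_SIZE_M, so only one of them remains to be benchmarked.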
Usage
This module is used as a decorator on the matmul_248_kernel Triton kernel in quant_linear.py. It automatically tunes kernel launch parameters for each unique combination of matrix dimensions encountered during inference.
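The per-key tuning step can be sketched as a small selection loop: prune, optionally keep the top_k configs ranked by a performance model, benchmark what is left, and pick the fastest. All names below are hypothetical, the way perf_model and top_k are combined is an assumption, and the timing is a single float here rather than the quantile timings the real _bench reports.

```python
from typing import Callable, Hashable, Iterable, Optional


class FakeOutOfResources(Exception):
    """Hypothetical stand-in for Triton's OutOfResources error."""


def bench(run_config: Callable[[Hashable], float], config: Hashable) -> float:
    """Time one config; a config that does not fit is reported as infinitely slow."""
    try:
        return run_config(config)
    except FakeOutOfResources:
        return float("inf")


def select_config(configs: Iterable[Hashable],
                  run_config: Callable[[Hashable], float],
                  early_config_prune: Optional[Callable] = None,
                  perf_model: Optional[Callable[[Hashable], float]] = None,
                  top_k: Optional[int] = None) -> Hashable:
    """Return the fastest surviving config after the optional pruning stages."""
    candidates = list(early_config_prune(configs)) if early_config_prune else list(configs)
    if perf_model is not None and top_k is not None:
        # Keep only the top_k configs the model predicts to be fastest.
        candidates = sorted(candidates, key=perf_model)[:top_k]
    timings = {cfg: bench(run_config, cfg) for cfg in candidates}
    return min(timings, key=timings.get)


measured_ms = {"small": 2.0, "medium": 1.2}  # made-up timings for the demo


def run(config: str) -> float:
    if config == "large":
        raise FakeOutOfResources()  # pretend this config exceeds shared memory
    return measured_ms[config]


best = select_config(["small", "medium", "large"], run)
print(best)  # medium
```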
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/layers/gptq/custom_autotune.py
- Lines: 1-252
Signature
class Autotuner(triton.KernelInterface):
    def __init__(self, fn, arg_names, configs, key, reset_to_zero,
                 prune_configs_by: Dict = None, nearest_power_of_two: bool = False):

def autotune(configs, key, prune_configs_by=None, reset_to_zero=None,
             nearest_power_of_two=False):

def matmul248_kernel_config_pruner(configs, nargs):
Import
from lorax_server.layers.gptq.custom_autotune import autotune, matmul248_kernel_config_pruner
I/O Contract
Inputs (autotune decorator)
| Name | Type | Required | Description |
|---|---|---|---|
| configs | list[triton.Config] | Yes | List of Triton kernel configurations to evaluate |
| key | list[str] | Yes | Argument names; a change in the value of any of them triggers re-evaluation of all configs |
| prune_configs_by | dict or None | No | Dict with 'early_config_prune', 'perf_model', and 'top_k' for config pruning |
| reset_to_zero | list[str] or None | No | Argument names whose tensors are zeroed before each trial |
| nearest_power_of_two | bool | No | Whether to round key values to nearest power of two for caching |
Outputs
| Name | Type | Description |
|---|---|---|
| decorator | Callable | A decorator that wraps a Triton JIT function with the Autotuner |
Usage Examples
# Used as a decorator on Triton GPTQ kernels
import triton

from lorax_server.layers.gptq import custom_autotune


@custom_autotune.autotune(
    # Remaining block-size entries are elided in this example.
    configs=[triton.Config({"BLOCK_SIZE_M": 64, ...}, num_stages=4, num_warps=4)],
    key=["M", "N", "K"],
    nearest_power_of_two=True,
    prune_configs_by={
        "early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
        # Listed explicitly so that all three documented keys are present.
        "perf_model": None,
        "top_k": None,
    },
)
@triton.jit
def my_kernel(...):
    ...