
Implementation:Predibase Lorax GPTQ Utils Custom Autotune

From Leeroopedia


Knowledge Sources
Domains: Quantization, Inference
Last Updated: 2026-02-08 00:00 GMT

Overview

Provides a custom Triton kernel autotuner optimized for GPTQ matrix multiplication. It reduces tuning overhead by using a shorter benchmarking budget and by optionally rounding the tuning key to the nearest power of two.

Description

This module is adapted from the GPTQ-triton project and provides a modified version of Triton's built-in autotuner with several GPTQ-specific optimizations.

Autotuner (triton.KernelInterface): The main autotuner class that manages kernel configuration selection. Key differences from Triton's default autotuner include:

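  • Reduced benchmarking: Calls triton.testing.do_bench with rep=40 instead of Triton's default of 100. In do_bench, rep is a benchmarking time budget in milliseconds rather than a repetition count, so each configuration is timed for roughly 40 ms instead of 100 ms. Empirical testing showed this to be a good accuracy-to-speed tradeoff.
  • Nearest power-of-two caching: When nearest_power_of_two is enabled, the key arguments (M, N, K) are rounded to the nearest power of two before the cache lookup. Shapes that differ only slightly then share one cache entry, which greatly reduces the number of tuning runs.
  • Config pruning: Supports three kinds of pruning: early pruning via early_config_prune, performance-model-based pruning via perf_model, and top-k selection via top_k. The prune_configs() method applies early_config_prune first. If a perf_model is given, it then keeps the top_k configs with the lowest predicted runtime.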
  • Reset hooks: Supports zeroing out specified tensors between configuration benchmark runs via reset_to_zero.
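The following sketch illustrates nearest-power-of-two rounding. It is not the module's code. round_key_to_power_of_two is a hypothetical helper, and the sketch assumes the rounding is done on the base-2 logarithm, so 40 maps to 32 while 48 maps to 64. Sweeping M from 1 to 128 with N = K = 4096 shows how many distinct cache keys remain.

```python
import math


def round_key_to_power_of_two(key):
    """Round each key value (e.g. M, N, K) to a power of two.

    Assumption: rounding happens in log2 space, i.e. the exponent is
    rounded to the nearest integer (40 -> 32, 48 -> 64).
    """
    return tuple(2 ** int(math.log2(x) + 0.5) for x in key)


# Hypothetical workload: only M (tokens in the batch) varies between calls.
shapes = [(m, 4096, 4096) for m in range(1, 129)]

exact_keys = set(shapes)                                     # one tuning run per exact shape
rounded_keys = {round_key_to_power_of_two(s) for s in shapes}  # one per rounded shape

print(len(exact_keys), len(rounded_keys))  # 128 8
```

Under these assumptions, the 128 exact keys collapse to 8 rounded keys. At most 8 tuning runs would therefore be needed for this sweep instead of 128.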

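The run() method builds a cache key from the key arguments, rounding them first when nearest_power_of_two is set. On a cache hit, it reuses the stored configuration. On a miss, it benchmarks every config returned by prune_configs() and caches the fastest one. If only a single config was supplied, no tuning is performed. The _bench() method times each candidate with triton.testing.do_bench and reports the median together with the 20th and 80th percentiles.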
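The cache-then-benchmark flow of run() can be modeled on the CPU with a small toy class. ToyAutotuner, fake_bench, and the config names are hypothetical. A lookup table replaces triton.testing.do_bench, Python lists stand in for the tensors named in reset_to_zero, and both pruning and the final kernel launch are left out. This sketches the control flow described above; it is not the module's implementation.

```python
from typing import Callable, Dict, Hashable, List, Optional, Sequence, Tuple

Timing = Tuple[float, float, float]  # (median, p20, p80), like do_bench's quantiles


class ToyAutotuner:
    """Hypothetical CPU-only model of Autotuner.run(); not the module's API."""

    def __init__(self, configs: Sequence[Hashable], key_idx: List[int],
                 bench: Callable[[Hashable, tuple], Timing],
                 reset_idx: Optional[List[int]] = None):
        self.configs = list(configs)
        self.key_idx = key_idx
        self.bench = bench
        self.reset_idx = reset_idx or []
        self.cache: Dict[tuple, Hashable] = {}
        self.bench_calls = 0

    def _reset(self, args: tuple) -> None:
        # Zero the stand-in "tensors" named by reset_idx before a benchmark run.
        for i in self.reset_idx:
            args[i][:] = [0] * len(args[i])

    def select(self, *args) -> Hashable:
        if len(self.configs) == 1:
            return self.configs[0]  # a single config needs no tuning
        key = tuple(args[i] for i in self.key_idx)
        if key not in self.cache:  # cache miss: benchmark every candidate once
            timings: Dict[Hashable, Timing] = {}
            for cfg in self.configs:
                self._reset(args)
                self.bench_calls += 1
                timings[cfg] = self.bench(cfg, args)
            # Tuples compare element-wise, so the median decides first.
            self.cache[key] = min(timings, key=timings.get)
        return self.cache[key]


fake_times = {"64x256": (2.0, 1.9, 2.1), "128x128": (1.5, 1.4, 1.6)}


def fake_bench(cfg, args):
    buf = args[1]
    buf[0] += 1  # pretend the kernel accumulates into its output
    return fake_times[cfg]


out = [5, 5]
tuner = ToyAutotuner(list(fake_times), key_idx=[0], bench=fake_bench, reset_idx=[1])
print(tuner.select(64, out))   # 128x128
print(tuner.select(64, out))   # 128x128 (cache hit, no new benchmarks)
print(tuner.bench_calls, out)  # 2 [1, 0]
```

The second call with the same key is served from the cache, so only two benchmarks run in total. The output list is zeroed before each of them, so it ends at [1, 0] rather than accumulating.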

autotune(): A decorator factory that wraps a Triton JIT kernel in an Autotuner. It mirrors the signature of triton.autotune and adds the nearest_power_of_two flag.
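The decorator pattern behind autotune() can be illustrated with a plain Python function in place of a @triton.jit kernel. toy_autotune and ToyTuned are hypothetical names. inspect.signature plays the role of the arg_names attribute that a Triton JIT function exposes. The sketch assumes key and reset_to_zero names are resolved to argument positions once, at decoration time.

```python
import inspect
from typing import Callable, List, Optional, Sequence


class ToyTuned:
    """Hypothetical stand-in for the Autotuner object that autotune() returns."""

    def __init__(self, fn: Callable, configs: Sequence, key: List[str],
                 reset_to_zero: Optional[List[str]] = None):
        self.fn = fn
        self.configs = list(configs)
        # A Triton JIT function exposes arg_names; inspect.signature stands in here.
        self.arg_names = list(inspect.signature(fn).parameters)
        # Names are resolved to positional indices up front;
        # list.index raises ValueError if a name is not a kernel argument.
        self.key_idx = [self.arg_names.index(k) for k in key]
        self.reset_idx = [self.arg_names.index(k) for k in (reset_to_zero or [])]


def toy_autotune(configs, key, reset_to_zero=None):
    """Decorator factory with the same call shape as autotune(configs, key, ...)."""
    def decorator(fn):
        return ToyTuned(fn, configs, key, reset_to_zero)
    return decorator


@toy_autotune(configs=["cfg_a", "cfg_b"], key=["M", "N", "K"], reset_to_zero=["c_ptr"])
def toy_kernel(a_ptr, b_ptr, c_ptr, M, N, K):
    pass


print(toy_kernel.key_idx, toy_kernel.reset_idx)  # [3, 4, 5] [2]
```

In this sketch, names are looked up in the kernel's argument list. A typo in key therefore surfaces as a ValueError when the decorator is applied rather than at launch time.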

matmul248_kernel_config_pruner(): A config pruning function specific to the GPTQ matmul kernel. When a matrix dimension (M, N, or K) is smaller than the corresponding tile size, it reduces BLOCK_SIZE_M, BLOCK_SIZE_N, or BLOCK_SIZE_K to fit. It then removes configurations that have become identical after this adjustment.
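The following CPU-only sketch illustrates the pruning rule described above. ToyConfig is a hypothetical stand-in for triton.Config. The exact clamp, where each dimension is rounded up to a power of two with a floor of 16, is an assumption made for illustration. For a single-token call (M = 1), every BLOCK_SIZE_M shrinks to 16. This makes two of the three example configs identical, and the duplicate is dropped.

```python
import math
from dataclasses import dataclass
from typing import Dict, Iterable, Iterator


@dataclass
class ToyConfig:
    """Hypothetical stand-in for triton.Config: tile kwargs plus launch settings."""
    kwargs: Dict[str, int]
    num_stages: int
    num_warps: int


def padded_dim(x: int) -> int:
    # Assumption: a dimension is rounded up to a power of two, with a floor of 16.
    return max(2 ** math.ceil(math.log2(x)), 16)


def toy_matmul248_pruner(configs: Iterable[ToyConfig],
                         nargs: Dict[str, int]) -> Iterator[ToyConfig]:
    m, n, k = (padded_dim(nargs[d]) for d in ("M", "N", "K"))
    seen = set()
    for cfg in configs:
        # Shrink each tile edge that exceeds the (padded) problem dimension.
        bm = min(m, cfg.kwargs["BLOCK_SIZE_M"])
        bn = min(n, cfg.kwargs["BLOCK_SIZE_N"])
        bk = min(k, cfg.kwargs["BLOCK_SIZE_K"])
        gm = cfg.kwargs["GROUP_SIZE_M"]
        signature = (bm, bn, bk, gm, cfg.num_stages, cfg.num_warps)
        if signature in seen:  # shrinking can make two configs identical
            continue
        seen.add(signature)
        yield ToyConfig({"BLOCK_SIZE_M": bm, "BLOCK_SIZE_N": bn,
                         "BLOCK_SIZE_K": bk, "GROUP_SIZE_M": gm},
                        cfg.num_stages, cfg.num_warps)


def tile(bm: int, bn: int, bk: int) -> ToyConfig:
    return ToyConfig({"BLOCK_SIZE_M": bm, "BLOCK_SIZE_N": bn,
                      "BLOCK_SIZE_K": bk, "GROUP_SIZE_M": 8}, 4, 4)


configs = [tile(64, 256, 32), tile(128, 128, 32), tile(64, 128, 32)]
pruned = list(toy_matmul248_pruner(configs, {"M": 1, "N": 4096, "K": 4096}))
print([(c.kwargs["BLOCK_SIZE_M"], c.kwargs["BLOCK_SIZE_N"]) for c in pruned])
# [(16, 256), (16, 128)]
```

Like the pruner described above, this sketch is a generator. Configs are produced lazily and deduplicated as they are consumed.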

Usage

This module is used internally by quant_linear.py to autotune the matmul_248_kernel Triton kernel. The autotune decorator is applied with eight different tile configurations and the matmul248_kernel_config_pruner to efficiently select the best configuration for the given matrix dimensions at runtime.

Code Reference

Source Location

  • Repository: Predibase_Lorax
  • File: server/lorax_server/utils/gptq/custom_autotune.py
  • Lines: 1-252

Signature

class Autotuner(triton.KernelInterface):
    def __init__(self, fn, arg_names, configs, key, reset_to_zero,
                 prune_configs_by: Dict = None, nearest_power_of_two: bool = False)
    def _bench(self, *args, config, **meta)
    def run(self, *args, **kwargs)
    def prune_configs(self, kwargs)
    def warmup(self, *args, **kwargs)

def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False)

def matmul248_kernel_config_pruner(configs, nargs)

Import

from lorax_server.utils.gptq.custom_autotune import autotune, matmul248_kernel_config_pruner

I/O Contract

Inputs

Name | Type | Required | Description
configs | list[triton.Config] | Yes | Triton kernel configurations to evaluate
key | list[str] | Yes | Kernel argument names whose values determine when to re-tune (e.g., ["M", "N", "K"])
prune_configs_by | Dict | No | Dictionary with the keys early_config_prune, perf_model, and top_k (the usage example passes all three, using None for unused entries)
reset_to_zero | list[str] | No | Names of tensor arguments to zero before each benchmark run
nearest_power_of_two | bool | No | Whether to round key argument values to the nearest power of two before the cache lookup (default False)
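The sketch below shows how the three prune_configs_by entries combine. toy_prune_configs is hypothetical, integers stand in for triton.Config objects, and the perf_model signature is simplified to (config, nargs). The sketch assumes that early_config_prune runs first and that top_k applies only when a perf_model is provided. With perf_model and top_k set to None, as in the usage example, only the early pruner has any effect.

```python
from typing import Any, Dict, List


def toy_prune_configs(configs: List[Any], nargs: Dict[str, Any],
                      prune_configs_by: Dict[str, Any]) -> List[Any]:
    """Hypothetical sketch of applying the prune_configs_by entries in order."""
    early = prune_configs_by.get("early_config_prune")
    perf_model = prune_configs_by.get("perf_model")
    top_k = prune_configs_by.get("top_k")

    # 1. Cheap, rule-based pruning (the role matmul248_kernel_config_pruner plays).
    pruned = list(early(configs, nargs)) if early else list(configs)

    # 2. Optional model-based pruning: keep the top_k configs with the lowest
    #    predicted runtime (assumed to apply only when perf_model is given).
    if perf_model is not None and top_k is not None and len(pruned) > top_k:
        pruned = sorted(pruned, key=lambda c: perf_model(c, nargs))[:top_k]
    return pruned


tiles = [16, 32, 64, 128]  # integers standing in for triton.Config objects
nargs = {"M": 40}
spec = {
    "early_config_prune": lambda cs, a: (c for c in cs if c <= 64),
    "perf_model": lambda c, a: abs(c - a["M"]),  # pretend the closest tile is fastest
    "top_k": 2,
}
print(toy_prune_configs(tiles, nargs, spec))  # [32, 16]
print(toy_prune_configs(tiles, nargs, {**spec, "perf_model": None, "top_k": None}))
# [16, 32, 64]
```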

Outputs

Name | Type | Description
Autotuner | Autotuner | Wrapped kernel interface returned by the autotune decorator; selects the best configuration at runtime
pruned_configs | generator[triton.Config] | Returned by matmul248_kernel_config_pruner; yields deduplicated configs with adjusted block sizes

Usage Examples

# Internal usage in quant_linear.py
import triton

from lorax_server.utils.gptq import custom_autotune

@custom_autotune.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
                       num_stages=4, num_warps=4),
        # ... more configs
    ],
    key=["M", "N", "K"],
    nearest_power_of_two=True,
    prune_configs_by={
        "early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
        "perf_model": None,
        "top_k": None,
    },
)
@triton.jit
def matmul_248_kernel(a_ptr, b_ptr, c_ptr, ...):  # argument list elided; it must include M, N and K, the names used in key
    ...
