

Implementation:Predibase Lorax GPTQ Custom Autotune

From Leeroopedia
Revision as of 16:20, 16 February 2026 by Admin (Auto-imported from implementations/Predibase_Lorax_GPTQ_Custom_Autotune.md)


Knowledge Sources
Domains: Quantization, Inference
Last Updated: 2026-02-08 00:00 GMT

Overview

Provides a custom Triton autotuner for GPTQ quantized matrix multiplication kernels, with reduced benchmarking iterations and configuration pruning based on matrix dimensions.

Description

This module is adapted from the GPTQ-triton project and provides a modified version of the Triton autotuner optimized for quantized inference workloads. It contains:

Autotuner: A class extending triton.KernelInterface that manages kernel configuration selection. Key features include:

  • Cache-based tuning: Stores the best configuration per key tuple (typically M, N, K dimensions), with an option to round keys to the nearest power of two (nearest_power_of_two) to reduce the number of unique tuning runs.
  • Reduced benchmarking: Passes rep=40 to triton.testing.do_bench instead of its default of 100 (do_bench treats rep as a target benchmarking time in milliseconds). This was found to give sufficiently accurate results while reducing tuning time.
  • Configuration pruning: Supports three pruning mechanisms: an early_config_prune function applied before benchmarking, a perf_model that predicts each config's run time, and top_k, which limits how many of the best-predicted configs are actually benchmarked.
  • _bench method: Benchmarks a single config with triton.testing.do_bench, requesting the 0.5, 0.2, and 0.8 quantiles (median, 20th and 80th percentiles) of the measured run time. Returns infinity if the config raises OutOfResources.
  • warmup method: Warms up all pruned configurations without benchmarking.
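The pruning mechanisms listed above can be read as a pipeline: early pruning first, then ranking by a performance model, then truncation to top_k. The following is a minimal plain-Python sketch of that order, not the module's code. It assumes configs are plain dicts instead of triton.Config objects; prune_configs, drop_oversized_m, and toy_perf_model are hypothetical names, and the cost model is purely illustrative.

```python
import math
from typing import Callable, Dict, Iterable, List, Optional

# Hypothetical stand-in for triton.Config: a plain dict of meta-parameters.
Config = Dict[str, int]


def prune_configs(
    configs: List[Config],
    nargs: Dict[str, int],
    early_config_prune: Optional[Callable[[List[Config], Dict[str, int]], Iterable[Config]]] = None,
    perf_model: Optional[Callable[..., float]] = None,
    top_k: Optional[int] = None,
) -> List[Config]:
    """Early-prune configs, then keep the top_k ranked by a performance model."""
    pruned = list(configs)
    if early_config_prune is not None:
        pruned = list(early_config_prune(pruned, nargs))
    if perf_model is not None and top_k is not None and len(pruned) > top_k:
        # Rank configs by predicted run time; only the best top_k get benchmarked.
        pruned = sorted(pruned, key=lambda cfg: perf_model(**nargs, **cfg))[:top_k]
    return pruned


def drop_oversized_m(configs: List[Config], nargs: Dict[str, int]) -> Iterable[Config]:
    # Hypothetical early pruner: skip configs whose M tile exceeds the M dimension.
    return (cfg for cfg in configs if cfg["BLOCK_SIZE_M"] <= nargs["M"])


def toy_perf_model(M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N):
    # Hypothetical cost model (K is ignored): fewer output tiles = lower predicted time.
    return math.ceil(M / BLOCK_SIZE_M) * math.ceil(N / BLOCK_SIZE_N)


configs = [
    {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32},
    {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64},
    {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64},
]
nargs = {"M": 64, "N": 256, "K": 4096}
best = prune_configs(configs, nargs, drop_oversized_m, toy_perf_model, top_k=1)
print(best)  # [{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}]
```

Here the early pruner removes the 128-wide tile because M is only 64, and the performance model then prefers the 64x64 tile over the 16x32 one, so only a single config would reach the benchmarking stage.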

autotune: A decorator factory that wraps a triton.jit kernel with the Autotuner class. It accepts configs (list of triton.Config), key (names of arguments whose values, when changed, trigger re-tuning), prune_configs_by (pruning strategy dict), reset_to_zero (names of tensor arguments to zero before each trial), and nearest_power_of_two.
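The decorator-factory pattern and the per-key cache can be illustrated without a GPU. The following is a minimal pure-Python sketch under stated assumptions, not the module's code: ToyAutotuner and chunked_sum are hypothetical, configs are plain dicts rather than triton.Config, timing uses time.perf_counter instead of triton.testing.do_bench, and pruning, reset_to_zero, and nearest_power_of_two are omitted.

```python
import time
from typing import Any, Callable, Dict, List, Sequence, Tuple


class ToyAutotuner:
    """Hypothetical pure-Python analogue of a kernel autotuner (not Triton)."""

    def __init__(self, fn: Callable[..., Any], configs: List[Dict[str, int]], key: Sequence[str]):
        self.fn = fn
        self.configs = configs
        self.key = list(key)
        self.cache: Dict[Tuple[Any, ...], Dict[str, int]] = {}
        self.bench_runs = 0  # how many full tuning sweeps have been run

    def _bench(self, config: Dict[str, int], **kwargs: Any) -> float:
        # Time one call with the config's meta-parameters merged into the arguments.
        start = time.perf_counter()
        self.fn(**kwargs, **config)
        return time.perf_counter() - start

    def __call__(self, **kwargs: Any) -> Any:
        key = tuple(kwargs[name] for name in self.key)
        if key not in self.cache:
            # Unseen key: benchmark every config and cache the fastest one.
            self.bench_runs += 1
            timings = [self._bench(cfg, **kwargs) for cfg in self.configs]
            self.cache[key] = self.configs[timings.index(min(timings))]
        return self.fn(**kwargs, **self.cache[key])


def autotune(configs: List[Dict[str, int]], key: Sequence[str]):
    # Decorator factory: autotune(...) returns a decorator that wraps the function.
    def decorator(fn: Callable[..., Any]) -> ToyAutotuner:
        return ToyAutotuner(fn, configs, key)

    return decorator


@autotune(configs=[{"BLOCK": 1}, {"BLOCK": 64}], key=["n"])
def chunked_sum(n: int, BLOCK: int) -> int:
    # Toy "kernel": sums range(n) in chunks of BLOCK elements.
    total = 0
    for start in range(0, n, BLOCK):
        total += sum(range(start, min(start + BLOCK, n)))
    return total


print(chunked_sum(n=1000), chunked_sum(n=1000), chunked_sum(n=10))  # 499500 499500 45
print(chunked_sum.bench_runs)  # 2: one sweep each for n=1000 and n=10
```

The second call with n=1000 hits the cache, so only two tuning sweeps run. Which config wins depends on measured timings, so the sketch does not assume a particular choice.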

matmul248_kernel_config_pruner: A generator function that shrinks BLOCK_SIZE_M, BLOCK_SIZE_N, and BLOCK_SIZE_K when the corresponding matrix dimensions (M, N, K) are smaller than the configured block sizes. It deduplicates configurations to avoid redundant benchmarking.
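The shrink-and-deduplicate idea can be sketched as a generator over plain dicts. This is a hypothetical analogue, not the module's function: shrink_block_sizes and next_pow2 are illustrative names, and capping each block size at the dimension rounded up to a power of two with a floor of 16 is an assumption of this sketch.

```python
import math
from typing import Dict, Iterator, List


def next_pow2(x: int, floor: int = 16) -> int:
    # Round x up to a power of two, never below `floor` (assumed minimum tile size).
    return max(2 ** math.ceil(math.log2(x)), floor)


def shrink_block_sizes(configs: List[Dict[str, int]], nargs: Dict[str, int]) -> Iterator[Dict[str, int]]:
    """Hypothetical analogue of matmul248_kernel_config_pruner using plain dicts."""
    limits = {
        "BLOCK_SIZE_M": next_pow2(nargs["M"]),
        "BLOCK_SIZE_N": next_pow2(nargs["N"]),
        "BLOCK_SIZE_K": next_pow2(nargs["K"]),
    }
    seen = set()
    for cfg in configs:
        # Cap each block size at the (rounded) size of its matrix dimension.
        shrunk = {**cfg, **{name: min(cfg[name], limit) for name, limit in limits.items()}}
        signature = tuple(sorted(shrunk.items()))
        if signature in seen:
            continue  # identical after shrinking: benchmarking it again is wasted work
        seen.add(signature)
        yield shrunk


configs = [
    {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "num_warps": 4},
    {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "num_warps": 4},
    {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "num_warps": 4},
]
pruned = list(shrink_block_sizes(configs, {"M": 5, "N": 4096, "K": 4096}))
print(len(pruned))  # 2: the last two configs collapse to the same shrunk config
```

With only 5 rows in M, every BLOCK_SIZE_M is capped at 16, which makes the second and third configs identical; the generator yields that shape once.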

Usage

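The autotune decorator from this module is applied to the matmul_248_kernel Triton kernel in quant_linear.py. It tunes the kernel's launch parameters once for each unique combination of matrix dimensions encountered during inference (or for each unique combination of their power-of-two roundings, when nearest_power_of_two is set).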
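Because results are cached per key tuple, the number of tuning runs equals the number of distinct (M, N, K) keys seen. The sketch below counts distinct keys for a hypothetical layer with N = K = 4096 while M (for example, the number of tokens in a batch) ranges from 1 to 64. round_key is a hypothetical helper, and rounding to the nearest power of two in log2 space is an assumption of this sketch.

```python
import math
from typing import Iterable, Tuple


def round_key(values: Iterable[int], nearest_power_of_two: bool) -> Tuple[int, ...]:
    # Optionally round each key value to the nearest power of two (in log2 space).
    if not nearest_power_of_two:
        return tuple(values)
    return tuple(2 ** int(math.log2(v) + 0.5) for v in values)


token_counts = range(1, 65)  # hypothetical M values seen across inference steps
N, K = 4096, 4096  # hypothetical fixed layer shape
exact = {round_key((m, N, K), False) for m in token_counts}
rounded = {round_key((m, N, K), True) for m in token_counts}
print(len(exact), len(rounded))  # 64 7
```

Rounding collapses 64 distinct shapes into the 7 keys M = 1, 2, 4, ..., 64, so the autotuner would benchmark its configs 7 times instead of 64, at the cost of reusing one tuned config for nearby shapes.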

Code Reference

Source Location

  • Repository: Predibase_Lorax
  • File: server/lorax_server/layers/gptq/custom_autotune.py
  • Lines: 1-252

Signature

class Autotuner(triton.KernelInterface):
    def __init__(self, fn, arg_names, configs, key, reset_to_zero,
                 prune_configs_by: Dict = None, nearest_power_of_two: bool = False):
        ...

def autotune(configs, key, prune_configs_by=None, reset_to_zero=None,
             nearest_power_of_two=False):
    ...

def matmul248_kernel_config_pruner(configs, nargs):
    ...

Import

from lorax_server.layers.gptq.custom_autotune import autotune, matmul248_kernel_config_pruner

I/O Contract

Inputs (autotune decorator)

  • configs (list[triton.Config], required): List of Triton kernel configurations to evaluate
  • key (list[str], required): Argument names whose value changes trigger re-evaluation of all configs
  • prune_configs_by (dict or None, optional): Dict with 'early_config_prune', 'perf_model', and 'top_k' for config pruning
  • reset_to_zero (list[str] or None, optional): Argument names whose tensors are zeroed before each trial
  • nearest_power_of_two (bool, optional): Whether to round key values to the nearest power of two for caching

Outputs

  • decorator (Callable): A decorator that wraps a Triton JIT function with the Autotuner

Usage Examples

# Used as a decorator on Triton GPTQ kernels
import triton

from lorax_server.layers.gptq import custom_autotune

@custom_autotune.autotune(
    configs=[triton.Config({"BLOCK_SIZE_M": 64, ...}, num_stages=4, num_warps=4)],
    key=["M", "N", "K"],
    nearest_power_of_two=True,
    prune_configs_by={
        "early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
        # perf_model and top_k are read whenever prune_configs_by is given,
        # so pass them explicitly; None disables model-based pruning.
        "perf_model": None,
        "top_k": None,
    },
)
@triton.jit
def my_kernel(...):
    ...
