
Heuristic:Predibase Lorax LoRA Kernel Selection By Rank

From Leeroopedia



Knowledge Sources
Domains Optimization, LLMs
Last Updated 2026-02-08 02:30 GMT

Overview

Rank-based kernel dispatch strategy that selects among SGMV (ranks 8-15, prefill phase), custom CUTLASS kernels (ranks 16-128), and BGMV (decode phase, rank ≤ 128) LoRA kernels according to adapter rank, maximizing GPU throughput.

Description

LoRAX supports multiple LoRA kernel implementations optimized for different adapter rank ranges. The kernel selection is governed by rank thresholds defined in the Punica wrapper. Additionally, tensors are padded to the nearest multiple of the SGMV block size (16) for alignment, and the minimum rank for efficient SGMV operation scales with tensor parallelism world size.

The three kernel tiers are:

  • SGMV (Segmented Gather Matrix-Vector): Batched LoRA for prefill phase. Handles variable-length sequences with segment-based indexing. Efficient for ranks 8-15 but not optimized for larger ranks.
  • Custom CUTLASS kernels: Optimized CUDA kernels for ranks 16-128. Require transposed weight tensors (`orient_for_rank` transposes in this range). Peak performance tier.
  • BGMV (Batched Gather Matrix-Vector): Used during decode phase within CUDA graphs. Hard limit of rank 128 (BGMV_MAX_RANK).

Usage

This heuristic applies when loading LoRA adapters and dispatching LoRA kernel operations during inference. The rank of the loaded adapter determines which kernel path is used. Practitioners should be aware of these thresholds when training LoRA adapters for deployment on LoRAX:

  • Rank < 8: Falls back to generic loop (slowest)
  • Rank 8-15: Uses SGMV but below optimal range
  • Rank 16-128: Optimal performance with CUTLASS kernels
  • Rank > 128: Exceeds the BGMV hard limit (BGMV_MAX_RANK); not supported
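The four threshold bands above can be collected into a small dispatch sketch. This is a hypothetical helper, not part of LoRAX; only the constants mirror `punica.py`:

```python
# Rank thresholds mirrored from server/lorax_server/utils/punica.py
MIN_SGMV_RANK = 8
MIN_RANK_CUSTOM = 16
MAX_RANK_CUSTOM = 128  # also the BGMV hard limit

def kernel_path(rank: int) -> str:
    """Hypothetical helper: name the kernel tier a LoRA rank falls into."""
    if rank < MIN_SGMV_RANK:
        return "generic-loop"   # too small for batched kernels
    if rank < MIN_RANK_CUSTOM:
        return "sgmv"           # batched, but below CUTLASS tile sizes
    if rank <= MAX_RANK_CUSTOM:
        return "cutlass"        # optimal tier
    return "unsupported"        # above BGMV_MAX_RANK

for r in (4, 8, 16, 64, 128, 256):
    print(r, kernel_path(r))
```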

The Insight (Rule of Thumb)

  • Action: Train LoRA adapters with rank 16, 32, or 64 for optimal LoRAX serving performance.
  • Value: Rank 16 = minimum for CUTLASS optimization. Rank 64 = default CUDA graph max. Rank 128 = absolute BGMV hard limit.
  • Trade-off: Higher ranks can improve adapter quality but consume more VRAM per adapter and increase kernel computation time. Ranks 8-15 fall back to the slower SGMV path, ranks below 8 to a generic loop, and ranks above 128 are unsupported.
  • Tensor Parallelism: Minimum effective rank scales with world_size. For 2-GPU TP, min_rank = 16 (8 * 2).
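The tensor-parallel scaling rule can be checked numerically. The constant comes from `punica.py`; the helper name is illustrative:

```python
MIN_SGMV_RANK = 8  # from server/lorax_server/utils/punica.py

def min_effective_rank(world_size: int) -> int:
    # TP shards the adapter across GPUs, dividing the effective rank by
    # world_size, so the minimum rank is scaled up to compensate.
    return MIN_SGMV_RANK * world_size

for ws in (1, 2, 4, 8):
    print(ws, min_effective_rank(ws))  # 8, 16, 32, 64
```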

Reasoning

The kernel selection is driven by hardware efficiency at different matrix sizes:

  • Small ranks (< 8): Matrix operations are too small for efficient GPU parallelism; a generic loop is acceptable.
  • Medium ranks (8-15): SGMV provides batched operation, but the matrices are still too small to fill CUTLASS tile sizes.
  • Optimal ranks (16-128): CUTLASS kernels use optimized tile sizes that match GPU warp/thread block dimensions. Weight tensors are transposed to match the kernel's expected layout.
  • Over-large ranks (> 128): The BGMV kernel has a hard-coded configuration limit of 128. This is a compile-time constant in the CUDA kernel headers.

Rank padding to multiples of SGMV_BLOCK_SIZE (16) ensures cache-line alignment and efficient memory access patterns.
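The multiple-of-block-size rule can be sketched in pure Python. This is a simplified stand-in for the tensor-level padding helper, which operates on `torch.Tensor` shapes rather than bare integers:

```python
SGMV_BLOCK_SIZE = 16  # from server/lorax_server/utils/punica.py

def padded_rank(rank: int) -> int:
    # Round the rank up to the nearest multiple of the SGMV block size
    # so segment offsets stay aligned for efficient memory access.
    return -(-rank // SGMV_BLOCK_SIZE) * SGMV_BLOCK_SIZE

for r in (20, 33, 64):
    print(r, padded_rank(r))  # 32, 48, 64
```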

Code evidence from `server/lorax_server/utils/punica.py:35-86`:

import torch

MIN_SGMV_RANK = 8
MIN_RANK_CUSTOM = 16
MAX_RANK_CUSTOM = 128
SGMV_BLOCK_SIZE = 16
BGMV_MAX_RANK = 128

def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor:
    # tensor parallelism will result in effective rank being divided by world_size,
    # so we need to scale the min rank to offset that effect
    min_rank = MIN_SGMV_RANK * world_size
    return pad_to_min_rank(t, dim, min_rank)

def use_cutlass_shrink(lora_rank: int) -> bool:
    return lora_rank < MIN_RANK_CUSTOM

def orient_for_rank(t: torch.Tensor, rank: int) -> torch.Tensor:
    if MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM:
        return t.transpose(0, 1)
    return t
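For intuition, the orientation rule can be exercised without GPU kernels by tracking only tensor shapes. This is a pure-Python sketch; the real `orient_for_rank` transposes a `torch.Tensor` in place of swapping a shape tuple:

```python
MIN_RANK_CUSTOM = 16
MAX_RANK_CUSTOM = 128

def orient_shape(shape: tuple, rank: int) -> tuple:
    # Mirror orient_for_rank on a (dim0, dim1) shape tuple: the custom
    # CUTLASS kernels for ranks 16-128 expect the transposed layout.
    if MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM:
        return (shape[1], shape[0])
    return shape

print(orient_shape((32, 4096), 32))  # transposed for the CUTLASS tier
print(orient_shape((8, 4096), 8))    # left as-is for the SGMV tier
```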

Punica kernel import fallback from `server/lorax_server/utils/punica.py:20-27`:

import os
import warnings

try:
    import punica_kernels as _kernels
    HAS_SGMV = not bool(int(os.environ.get("DISABLE_SGMV", "0")))
except ImportError:
    warnings.warn("Could not import SGMV kernel from Punica, falling back to loop.")
    _kernels = None
    HAS_SGMV = False
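Downstream code can then gate on `HAS_SGMV`. A hedged sketch of the pattern, where `apply_lora` is a placeholder name, not a LoRAX function:

```python
import os

# Same pattern as the excerpt: DISABLE_SGMV="1" forces the loop
# fallback even when the Punica kernels imported successfully.
HAS_SGMV = not bool(int(os.environ.get("DISABLE_SGMV", "0")))

def apply_lora(x, use_sgmv: bool = HAS_SGMV) -> str:
    # Placeholder dispatch: real code would invoke the SGMV kernel or
    # iterate over adapters in a Python loop here.
    return "sgmv" if use_sgmv else "loop"

print(apply_lora(None))
```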
