
Heuristic:Predibase Lorax CUDA Graph Batch Size Caching

From Leeroopedia




Knowledge Sources
Domains Optimization, GPU_Computing
Last Updated 2026-02-08 02:30 GMT

Overview

Memory-latency optimization that caches CUDA graphs for a discrete set of batch sizes [1, 2, 3, 4, 8, 16, 32, ...] and LoRA ranks [0, 8, 16, 32, 64, 128], reducing decode latency by up to 50% at the cost of 3-10 GB of additional VRAM.

Description

CUDA graphs capture a sequence of GPU operations (kernels, memory copies) into a reusable executable that bypasses CPU-side launch overhead. LoRAX pre-traces graphs for specific batch size and LoRA rank combinations during server warmup. Instead of tracing a graph for every possible batch size (which would exhaust GPU memory), it uses a quantized set of cached sizes. Each incoming batch is padded up to the smallest cached size that can hold it.

The batch size quantization strategy uses fine granularity for small batches (1, 2, 3, 4, 8, 16) where latency is most sensitive, and coarser 32-element increments for larger batches where throughput matters more.
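The padding step can be sketched as follows. This sketch makes two assumptions: a decode batch is a plain list with one token ID per request, and padded slots are filled with a dummy ID. `pad_to_cached_size` and `PAD_TOKEN_ID` are hypothetical names. A real server would pad GPU tensors rather than Python lists.

```python
from bisect import bisect_left
from typing import List, Sequence, Tuple

# Cached sizes for the default LORAX_COMPILE_BATCH_SIZE=32.
CACHED_BATCH_SIZES = [1, 2, 3, 4, 8, 16, 32]
PAD_TOKEN_ID = 0  # hypothetical filler ID for padded slots


def pad_to_cached_size(
    token_ids: Sequence[int], cached_sizes: Sequence[int] = CACHED_BATCH_SIZES
) -> Tuple[List[int], int]:
    """Hypothetical helper: pad one decode step up to the smallest cached size that fits."""
    batch_size = len(token_ids)
    idx = bisect_left(cached_sizes, batch_size)  # index of first cached size >= batch_size
    if idx == len(cached_sizes):
        raise ValueError(f"batch_size={batch_size} exceeds the largest cached size")
    target = cached_sizes[idx]
    padded = list(token_ids) + [PAD_TOKEN_ID] * (target - batch_size)
    return padded, target


padded, size = pad_to_cached_size([101, 102, 103, 104, 105])
print(size, padded)  # 8 [101, 102, 103, 104, 105, 0, 0, 0]
```

The `bisect` lookup is only a generic stand-in. The `get_cached_batch_size` function quoted below implements the rounding with hard-coded branches instead.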

A critical workaround exists for SGMV kernels: segment sizes must be powers of 2 to avoid kernel hangs at certain batch sizes (observed at batch_size=96). This is suspected to be related to synchronization with the kernel's internal chunk size of 256.

Usage

Use this heuristic when configuring LoRAX for latency-sensitive decode workloads with small to medium batch sizes (< 256). CUDA graphs are most beneficial for:

  • Small LLMs (~1B parameters) that are compute-bound
  • Decode-heavy workloads (not prefill-dominated)
  • Context lengths < 8192
  • Single LoRA rank per batch

Avoid CUDA graphs when memory is the bottleneck or when variable batch sizes cause excessive graph tracing.

The Insight (Rule of Thumb)

  • Action: Configure `LORAX_COMPILE_BATCH_SIZE` to match your expected peak batch size, and `LORAX_COMPILE_MAX_RANK` to match your largest LoRA adapter rank.
  • Value: Default `LORAX_COMPILE_BATCH_SIZE=32`, default `LORAX_COMPILE_MAX_RANK=64`. With these defaults, graphs are traced for batch sizes [1, 2, 3, 4, 8, 16, 32] x ranks [0, 8, 16, 32, 64], i.e. 7 x 5 = 35 graphs.
  • Trade-off: Each graph combination consumes GPU memory. Reducing `LORAX_COMPILE_BATCH_SIZE` saves memory but limits maximum CUDA-graph-accelerated batch size. Requests exceeding the max fall back to eager execution.
  • SGMV Workaround: Segment sizes are padded to the next power of 2 to prevent kernel hangs. This wastes some memory but is required for stability.
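To see how the two settings shape the graph grid, the sketch below recomputes the cached batch sizes and ranks. It uses the same filtering rules as the `graph.py` excerpt quoted below, but takes the settings as arguments instead of reading environment variables. `graph_grid` is a hypothetical helper.

```python
from typing import List, Tuple

# Constants mirroring the graph.py excerpt (hard-coded here, not read from env vars).
BATCH_SIZE_INCREMENT = 32
ALL_RANKS = [0, 8, 16, 32, 64, 128]


def graph_grid(
    compile_batch_size: int = 32, max_rank: int = 64, max_batch_size: int = 256
) -> Tuple[List[int], List[int]]:
    """Return (batch_sizes, ranks) that would be traced, using graph.py's filtering rules.

    Hypothetical helper: the arguments stand in for LORAX_COMPILE_BATCH_SIZE,
    LORAX_COMPILE_MAX_RANK and LORAX_COMPILE_MAX_BATCH_SIZE.
    """
    sizes = [1, 2, 3, 4, 8, 16] + [
        BATCH_SIZE_INCREMENT * (i + 1) for i in range(max_batch_size // BATCH_SIZE_INCREMENT)
    ]
    sizes = [b for b in sizes if b <= compile_batch_size]
    ranks = [r for r in ALL_RANKS if r <= max_rank]
    return sizes, ranks


for cbs, rank in [(32, 64), (128, 128), (16, 16)]:
    sizes, ranks = graph_grid(cbs, rank)
    print(f"batch<={cbs}, rank<={rank}: {len(sizes)} x {len(ranks)} = {len(sizes) * len(ranks)} graphs")
```

With the defaults this gives 7 x 5 = 35 graphs. Raising both limits to 128 gives 10 x 6 = 60, and lowering both to 16 gives 6 x 3 = 18. Every extra graph costs VRAM, which is the trade-off described above.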

Reasoning

Each CUDA graph records a fixed computation shape. CPU-side kernel launch overhead (typically 5-10 microseconds per kernel) becomes significant during autoregressive decoding where each token requires launching dozens of kernels. CUDA graphs eliminate this overhead by replaying the entire sequence in a single GPU operation.
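The overhead argument is simple arithmetic: kernels per decode step times the per-launch cost quoted above. In the sketch below, the kernel count of 50 (standing in for "dozens") and the 256-token generation length are illustrative assumptions, not LoRAX measurements.

```python
def launch_overhead_ms(kernels_per_step: int, us_per_launch: float, steps: int) -> float:
    """CPU-side launch time (ms) spent issuing kernels eagerly for `steps` decode steps."""
    return kernels_per_step * us_per_launch * steps / 1000.0


# Illustrative assumptions, not LoRAX measurements:
# 50 kernels per decode step and 256 generated tokens.
KERNELS_PER_STEP = 50
TOKENS = 256

for us in (5.0, 10.0):
    per_token = launch_overhead_ms(KERNELS_PER_STEP, us, 1)
    total = launch_overhead_ms(KERNELS_PER_STEP, us, TOKENS)
    print(f"{us:.0f} us/launch: {per_token:.2f} ms per token, {total:.0f} ms over {TOKENS} tokens")
```

At 5-10 µs per launch this comes to 0.25-0.5 ms per token, or 64-128 ms over 256 tokens. That is CPU-side time which a replayed graph largely avoids. How much of it shows up as end-to-end latency depends on whether the GPU kernels are short enough for launches, rather than GPU work, to dominate each step.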

The quantized batch size strategy balances two concerns:

  1. Memory: Each (batch_size, rank) combination requires a separate graph allocation, so tracing a graph for every batch size up to the maximum (N_batch x N_rank graphs in total) would exhaust memory.
  2. Padding waste: Larger increments waste more compute on padding tokens, but this is acceptable for large batches where per-token latency is less critical.
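Both concerns can be put into numbers. The sketch below compares two graph counts: one graph for every batch size versus one per quantized size. It then reports the worst-case fraction of padded slots in each bucket. It assumes `LORAX_COMPILE_MAX_BATCH_SIZE=256`, that `LORAX_COMPILE_BATCH_SIZE` is raised to 256 so all sizes are traced, that all six ranks are traced, and that each batch is padded to the smallest cached size at least as large as the batch.

```python
BATCH_SIZE_INCREMENT = 32
MAX_BATCH_SIZE = 256
N_RANKS = 6  # assumed: all of [0, 8, 16, 32, 64, 128] are traced

cached = [1, 2, 3, 4, 8, 16] + [
    BATCH_SIZE_INCREMENT * (i + 1) for i in range(MAX_BATCH_SIZE // BATCH_SIZE_INCREMENT)
]

every_size = MAX_BATCH_SIZE * N_RANKS  # one graph per batch size 1..256, per rank
quantized = len(cached) * N_RANKS
print(f"graphs: {every_size} (every size) vs {quantized} (quantized)")

# Worst case in each bucket: the batch is one larger than the previous cached size.
prev = 0
for size in cached:
    worst_batch = prev + 1
    waste = (size - worst_batch) / size
    print(f"bucket {size:>3}: worst batch {worst_batch:>3}, {waste:.0%} padded slots")
    prev = size
```

The graph count drops from 1,536 to 84. In the 32-wide buckets, the worst-case padded fraction falls from about 48% (a batch of 33 padded to 64) to about 12% (225 padded to 256). This is why coarse increments are tolerable at large batch sizes.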

Memory estimation uses the median of 3 samples (discarding the first to account for one-time initialization). Graphs are traced largest-first to establish memory pools that smaller graphs can reuse.
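Both ideas can be sketched under stated assumptions, and none of the names below come from LoRAX. First, `measure` is a hypothetical callable that returns the memory (in MB) consumed by one trial capture. It is read once as a discarded warm-up, followed by three readings whose median is kept; this is one reading of the description above. Second, the trace order simply sorts (batch_size, rank) pairs in descending order.

```python
import statistics
from itertools import product
from typing import Callable, List, Tuple


def estimate_memory(measure: Callable[[], float], samples: int = 3) -> float:
    """Hypothetical: median of `samples` readings taken after one discarded warm-up reading."""
    measure()  # first reading includes one-time initialization; discard it
    return statistics.median(measure() for _ in range(samples))


def trace_order(batch_sizes: List[int], ranks: List[int]) -> List[Tuple[int, int]]:
    """Hypothetical: sort (batch_size, rank) pairs descending so the largest graph is captured first."""
    return sorted(product(batch_sizes, ranks), reverse=True)


readings = iter([900.0, 210.0, 200.0, 205.0])  # fake MB readings for the demo
print(estimate_memory(lambda: next(readings)))  # 205.0
print(trace_order([1, 2, 4], [0, 8])[:3])  # [(4, 8), (4, 0), (2, 8)]
```

In PyTorch, this kind of pool reuse is usually arranged by passing a shared pool handle, for example one from `torch.cuda.graph_pool_handle()`, as the `pool` argument of each `torch.cuda.graph` capture. The excerpts here do not show whether LoRAX uses exactly that mechanism.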

Code evidence from `server/lorax_server/utils/graph.py:29-52`:

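```python
import os  # required for os.environ; imported at module level in graph.py

MAX_BATCH_SIZE = int(os.environ.get("LORAX_COMPILE_MAX_BATCH_SIZE", 256))
COMPILE_BATCH_SIZE = int(os.environ.get("LORAX_COMPILE_BATCH_SIZE", 32))
MAX_RANK = int(os.environ.get("LORAX_COMPILE_MAX_RANK", 64))

BATCH_SIZE_INCREMENT = 32
# Fine-grained sizes for small batches, then multiples of 32 up to MAX_BATCH_SIZE.
CACHED_BATCH_SIZES = [1, 2, 3, 4, 8, 16] + [
    BATCH_SIZE_INCREMENT * (i + 1)
    for i in range(MAX_BATCH_SIZE // BATCH_SIZE_INCREMENT)
]
# Only sizes up to LORAX_COMPILE_BATCH_SIZE are actually traced.
CACHED_BATCH_SIZES = [b for b in CACHED_BATCH_SIZES if b <= COMPILE_BATCH_SIZE]

# Only ranks up to LORAX_COMPILE_MAX_RANK are actually traced.
CACHED_MAX_RANKS = [0, 8, 16, 32, 64, 128]
CACHED_MAX_RANKS = [r for r in CACHED_MAX_RANKS if r <= MAX_RANK]
```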

SGMV power-of-2 workaround from `server/lorax_server/utils/graph.py:210-215`:

```python
# WARNING: for some reason the SGMV kernel can hang if we don't use a power of 2
# as the segment size. This is a workaround until we can figure out why.
# Specifically, this issue has been observed with batch_size=96.
# I suspect it is related to synchronization and the chunk size (256) used in the kernel.
segment_size = next_pow_2(batch_size)
```
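`next_pow_2` itself is not part of the quoted lines. A common way to write such a helper in Python uses `int.bit_length`. The version below is a hypothetical stand-in, not LoRAX's definition. It shows how batch_size=96 becomes a segment size of 128.

```python
def next_pow_2(x: int) -> int:
    """Smallest power of 2 >= x (hypothetical stand-in for LoRAX's helper)."""
    if x <= 1:
        return 1
    # (x - 1).bit_length() is the number of bits needed for x - 1,
    # so shifting 1 by it gives the first power of 2 that is >= x.
    return 1 << (x - 1).bit_length()


for bs in (1, 3, 32, 96, 129):
    print(bs, "->", next_pow_2(bs))
```

For batch_size=96 the segment size is padded to 128, leaving 32 unused slots. That is the memory the workaround trades for stability.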

Batch size quantization from `server/lorax_server/utils/graph.py:57-68`:

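```python
def get_cached_batch_size(batch_size: int) -> int:
    # Round a live batch size up to the cached graph size it will be padded to.
    if batch_size == 1:
        return 1
    if batch_size == 2:
        return 2
    if batch_size <= 4:
        return 4
    if batch_size <= 8:
        return 8
    if batch_size <= 16:
        return 16
    # Above 16, round up to the next multiple of BATCH_SIZE_INCREMENT (32).
    return (batch_size + BATCH_SIZE_INCREMENT - 1) // BATCH_SIZE_INCREMENT * BATCH_SIZE_INCREMENT
```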
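This sketch shows how a caller might combine the lookup above with the eager fallback described under the rule of thumb. `choose_execution` and its return convention are hypothetical. The rank handling is an assumption about how `CACHED_MAX_RANKS` is used, not code from `graph.py`: the batch's largest adapter rank is padded up to the next cached rank, with a fall back to eager execution if none fits.

```python
from bisect import bisect_left
from typing import Optional, Tuple

# Default grid from the graph.py excerpt (LORAX_COMPILE_BATCH_SIZE=32, LORAX_COMPILE_MAX_RANK=64).
BATCH_SIZE_INCREMENT = 32
CACHED_BATCH_SIZES = [1, 2, 3, 4, 8, 16, 32]
CACHED_MAX_RANKS = [0, 8, 16, 32, 64]


def get_cached_batch_size(batch_size: int) -> int:
    # Copied from the excerpt above so this sketch runs on its own.
    if batch_size == 1:
        return 1
    if batch_size == 2:
        return 2
    if batch_size <= 4:
        return 4
    if batch_size <= 8:
        return 8
    if batch_size <= 16:
        return 16
    return (batch_size + BATCH_SIZE_INCREMENT - 1) // BATCH_SIZE_INCREMENT * BATCH_SIZE_INCREMENT


def choose_execution(batch_size: int, max_rank: int) -> Tuple[str, Optional[int], Optional[int]]:
    """Hypothetical dispatcher: ("graph", padded_batch, padded_rank) or ("eager", None, None)."""
    if batch_size > CACHED_BATCH_SIZES[-1]:
        return ("eager", None, None)  # no traced graph is large enough
    # Assumption: the batch's largest adapter rank is padded up to the next cached rank.
    r = bisect_left(CACHED_MAX_RANKS, max_rank)
    if r == len(CACHED_MAX_RANKS):
        return ("eager", None, None)  # rank above LORAX_COMPILE_MAX_RANK
    return ("graph", get_cached_batch_size(batch_size), CACHED_MAX_RANKS[r])


print(choose_execution(3, 16))   # ('graph', 4, 16)
print(choose_execution(20, 10))  # ('graph', 32, 16)
print(choose_execution(40, 8))   # ('eager', None, None)
```

As quoted, `get_cached_batch_size(3)` returns 4. This lookup therefore never selects the size-3 graph, even though 3 appears in `CACHED_BATCH_SIZES`.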
