Heuristic: Predibase LoRAX GPU Sampling Optimization
| Knowledge Sources | |
|---|---|
| Domains | Optimization, LLMs |
| Last Updated | 2026-02-08 02:30 GMT |
Overview
Collection of GPU-side token sampling optimizations: Gumbel-max trick to avoid GPU-CPU synchronization, CUDA graph-traced warping for temperature/top-k/top-p, and no-op warper elimination for common parameter values.
Description
Token sampling (selecting the next token from a probability distribution) is a critical latency bottleneck in autoregressive generation because naive implementations require GPU-to-CPU data transfer for `torch.multinomial`. LoRAX uses three key optimizations:
- Gumbel-Max Trick: Instead of calling `torch.multinomial` (which syncs GPU to CPU), sample from an exponential distribution, divide probabilities by the samples, and take argmax. This is mathematically equivalent to multinomial sampling but stays entirely on GPU.
- CUDA Graph-Traced Warping: The `StaticWarper` class traces temperature/top-k/top-p/typical-p warping into a CUDA graph on first use. Subsequent calls just copy new logits in and replay the graph, eliminating CPU-side branching.
- No-op Warper Elimination: Warpers are skipped for identity values (temperature=1.0, top_k=0, top_p=1.0, typical_p=1.0). The `has_warpers` check prevents allocating unnecessary processor objects.
Additionally, temperature=0 is treated as greedy decoding regardless of the `do_sample` flag, with a warning logged to the user.
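The temperature=0 override can be sketched in isolation. This is a hedged, minimal sketch: `resolve_sampling` is an illustrative name, not the actual LoRAX function, but the decision logic mirrors the behavior described above.

```python
import warnings

def resolve_sampling(do_sample: bool, temperature: float) -> bool:
    # temperature == 0 forces greedy decoding regardless of do_sample,
    # with a warning so the user knows their flag was overridden.
    if do_sample and temperature == 0:
        warnings.warn("Temperature is set to 0, token sampling will be disabled")
        return False
    return do_sample

assert resolve_sampling(True, 0) is False    # greedy despite do_sample=True
assert resolve_sampling(True, 0.7) is True   # normal sampling path
assert resolve_sampling(False, 0.7) is False # greedy stays greedy
```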
Usage
These optimizations apply automatically during token generation in every inference request. No user configuration needed. However, understanding the behavior helps when debugging sampling results:
- Temperature=0 always produces greedy (deterministic) output, even if `do_sample=True`
- Temperature=1.0 with top_k=0 and top_p=1.0 produces unwarped sampling
- CUDA graph warping is cached per unique (temperature, top_k, top_p, typical_p) tuple (LRU cache of size 10)
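The per-configuration caching described in the last bullet can be sketched with `functools.lru_cache`. This is an illustrative stand-in: the factory name and return value are hypothetical, and the real cached object is a traced-graph warper, which is elided here; only the keying and eviction behavior are shown.

```python
from functools import lru_cache

@lru_cache(maxsize=10)
def get_static_warper(temperature, top_k, top_p, typical_p):
    # Stand-in for building a CUDA-graph-traced warper; one cached
    # "graph" object per unique parameter tuple, at most 10 retained.
    return object()

a = get_static_warper(0.7, 50, 0.9, 1.0)
b = get_static_warper(0.7, 50, 0.9, 1.0)
c = get_static_warper(1.0, 0, 1.0, 1.0)
assert a is b      # identical tuple -> cached graph reused (no re-trace)
assert a is not c  # different tuple -> separate cache entry
```

An 11th unique tuple would evict the least-recently-used entry, forcing a re-trace if that configuration is requested again.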
The Insight (Rule of Thumb)
- Action: Let LoRAX handle sampling internally. Avoid setting redundant parameters (temperature=1.0 is a no-op, top_k=0 means no top-k filtering).
- Value: Gumbel-max trick eliminates one GPU-CPU sync per token (saves ~10-50 microseconds per token on high-latency PCIe connections).
- Trade-off: CUDA graph warping uses a fixed tensor size. The LRU cache (size=10) means at most 10 unique warper configurations are cached simultaneously; the 11th evicts the least-recently-used graph.
- Temperature=0: Always deterministic. LoRAX logs a warning if `do_sample=True` with `temperature=0`.
Reasoning
GPU-CPU synchronization is one of the most expensive operations in GPU computing (~5-50 microseconds depending on PCIe generation). During autoregressive generation, this sync happens once per token when using `torch.multinomial`, so a 512-token generation pays 512 syncs, adding roughly 2.5-25ms of pure sync overhead.
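The overhead figure above is a straightforward product of token count and per-sync cost:

```python
# Back-of-envelope check: total sync overhead for a 512-token generation
# at the low (5 us) and high (50 us) ends of the per-sync cost range.
tokens = 512
overhead_ms = {us: tokens * us / 1000 for us in (5, 50)}
print(overhead_ms)  # {5: 2.56, 50: 25.6}
```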
The Gumbel-max trick works because: if X_i ~ Exponential(1) i.i.d., then argmax(p_i / X_i) = argmax(log(p_i) - log(X_i)) follows the Categorical distribution with probabilities p_i. Since -log(X_i) follows a standard Gumbel distribution, this is the classic Gumbel-max construction from extreme value theory: perturb each log-probability with independent Gumbel noise and take the argmax.
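This identity can be verified empirically with a torch-free sketch: dividing each p_i by an independent Exp(1) draw and taking the argmax should sample index i with probability p_i.

```python
import random
from collections import Counter

def gumbel_max_sample(probs, rng):
    # argmax of p_i / X_i with X_i ~ Exp(1); equivalent to adding Gumbel
    # noise to log-probabilities and taking the argmax.
    return max(range(len(probs)), key=lambda i: probs[i] / rng.expovariate(1.0))

rng = random.Random(42)
probs = [0.5, 0.3, 0.2]
n = 100_000
counts = Counter(gumbel_max_sample(probs, rng) for _ in range(n))
freqs = [counts[i] / n for i in range(len(probs))]
print(freqs)  # each entry should be within ~0.01 of [0.5, 0.3, 0.2]
```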
CUDA graph warping amortizes the cost of Python-side warper dispatch. After first call, the warpers run as a single GPU operation with no Python overhead.
Code evidence for Gumbel-max from `server/lorax_server/utils/logits_process.py:541-544` (approximate):
```python
probs = torch.nn.functional.softmax(logits, -1)
# Avoid GPU<->CPU sync done by torch.multinomial
q = torch.empty_like(probs).exponential_(1, generator=self.generator)
return probs.div_(q).argmax()
```
CUDA graph warping from `server/lorax_server/utils/logits_process.py:52-70`:
```python
def __call__(self, scores):
    if torch.cuda.is_available():
        if self.cuda_graph is None:
            self.static_scores = scores
            self.cuda_graph = torch.cuda.CUDAGraph()

            # mempool is a shared CUDA graph memory pool defined at module level
            with torch.cuda.graph(self.cuda_graph, pool=mempool):
                local_scores = self.static_scores
                for warper in self.warpers:
                    local_scores = warper(None, local_scores)

                self.static_warped_scores = local_scores
                self.static_next_logprob = torch.log_softmax(self.static_warped_scores, -1)

        self.static_scores.copy_(scores)
        self.cuda_graph.replay()

        return self.static_warped_scores, self.static_next_logprob
```
Temperature=0 handling from `server/lorax_server/utils/tokens.py:88-92`:
```python
# do not sample if temperature is 0, even if do_sample flag is set True
if sampling and temperature == 0:
    sampling = False
    warnings.warn("Temperature is set to 0, token sampling will be disabled")
```
No-op warper elimination from `server/lorax_server/utils/logits_process.py:37-45`:
```python
if temperature is not None and temperature != 1.0 and temperature != 0:
    self.warpers.append(TemperatureLogitsWarper(temperature))
if top_k is not None and top_k != 0:
    self.warpers.append(TopKLogitsWarper(top_k=top_k))
if top_p is not None and top_p < 1.0:
    self.warpers.append(TopPLogitsWarper(top_p=top_p))
```