Implementation:Predibase Lorax Triton LibEntry
| Knowledge Sources | |
|---|---|
| Domains | GPU_Kernels, LoRA |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
A Triton kernel launch optimization decorator that reduces runtime overhead for small GPU kernels by caching compiled kernels and simplifying parameter binding, originally from the FlagGems project.
Description
The LibEntry class implements triton.KernelInterface and provides an optimized kernel launch path that bypasses most of Triton's default per-call runtime overhead. On every invocation, the standard Triton JITFunction.run method performs several CPU-side steps: parameter binding with Python's inspect module, KernelArg type wrapping, and cache key calculation. For small kernels (common with smaller models), this CPU-side work can become a significant bottleneck relative to the actual GPU computation time.
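To illustrate the parameter-binding cost described above, the sketch below contrasts binding arguments through `inspect.signature(...).bind` on every call with inspecting the signature once and then partitioning positional arguments by index, which is the general approach LibEntry takes. It is plain Python with no Triton dependency; `fake_kernel`, `CONSTEXPR_NAMES`, `slow_partition`, and `FastPartitioner` are hypothetical names, and only positional arguments are handled.

```python
import inspect


# Hypothetical kernel signature: the last two parameters play the role of
# tl.constexpr arguments (names are illustrative only).
def fake_kernel(input_ptr, out_ptr, n, BLOCK_M, BLOCK_N):
    pass


CONSTEXPR_NAMES = {"BLOCK_M", "BLOCK_N"}


def slow_partition(fn, *args):
    # Per-call binding through inspect, similar in spirit to the work a
    # generic launcher repeats on every invocation.
    bound = inspect.signature(fn).bind(*args)
    runtime, const = [], []
    for name, value in bound.arguments.items():
        (const if name in CONSTEXPR_NAMES else runtime).append(value)
    return runtime, const


class FastPartitioner:
    """Inspects the signature once, then partitions by position per call."""

    def __init__(self, fn):
        names = list(inspect.signature(fn).parameters)
        self.const_idx = {i for i, n in enumerate(names) if n in CONSTEXPR_NAMES}

    def __call__(self, *args):
        runtime = [a for i, a in enumerate(args) if i not in self.const_idx]
        const = [a for i, a in enumerate(args) if i in self.const_idx]
        return runtime, const


fast = FastPartitioner(fake_kernel)
args = (1000, 2000, 64, 16, 32)
print(slow_partition(fake_kernel, *args))  # ([1000, 2000, 64], [16, 32])
print(fast(*args))                         # same result, no per-call inspect
```

Both paths classify the arguments identically; the second simply moves the signature inspection out of the hot path.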
LibEntry addresses this overhead by maintaining its own kernel_cache dictionary. On the first invocation with a given cache key, it compiles and launches the kernel through Triton's normal pipeline, which handles Autotuner, Heuristics, and constexpr resolution, and records the resolved constexpr values alongside the compiled kernel. On subsequent calls with the same key, it retrieves the compiled kernel from the cache and launches it directly, skipping Triton's per-call parameter processing. The cache key is built from tensor dtypes, pointer alignment (whether data_ptr() is divisible by 16), the values of specialized non-tensor arguments, constexpr values, and, for integer arguments marked do-not-specialize, their value range (i32, i64, or u64).
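The cache-key and lookup logic can be sketched without Triton or a GPU. This is a simplified model of the key rules described above (tensor dtype and 16-byte pointer alignment, constexpr values, i32/i64/u64 ranges for integers), not the LoRAX code: `FakeTensor`, `int_range`, `make_key`, `CachedLauncher`, and the `compile_fn` callback are hypothetical, tensors are represented only by `dtype` and `data_ptr()`, and arguments are assumed to be pre-sorted into specialized, do-not-specialize, and constexpr groups. Keying specialized non-tensor arguments by their exact value follows the upstream FlagGems/vLLM version as we understand it.

```python
from dataclasses import dataclass

DIVISIBILITY = 16  # pointer-alignment granularity used in the key


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical tensor stand-in exposing only dtype and data_ptr()."""
    dtype: str
    ptr: int

    def data_ptr(self) -> int:
        return self.ptr


def int_range(v: int) -> str:
    # Bucket a do-not-specialize integer by the range it needs.
    if -(2**31) <= v <= 2**31 - 1:
        return "i32"
    if 2**63 <= v <= 2**64 - 1:
        return "u64"
    return "i64"


def make_key(spec_args, dns_args, const_args) -> tuple:
    # Specialized args: tensors contribute (dtype, 16-byte aligned?);
    # other values contribute (type, exact value).
    spec = [
        (a.dtype, a.data_ptr() % DIVISIBILITY == 0) if hasattr(a, "data_ptr")
        else (type(a), a)
        for a in spec_args
    ]
    # Do-not-specialize args: tensors contribute dtype, ints their range.
    dns = [
        a.dtype if hasattr(a, "data_ptr")
        else int_range(a) if isinstance(a, int)
        else type(a)
        for a in dns_args
    ]
    # Constexpr values are part of the key as-is.
    return tuple(spec + dns + list(const_args))


class CachedLauncher:
    """Compile once per key, then reuse the cached kernel on later calls."""

    def __init__(self, compile_fn):
        self.compile_fn = compile_fn  # hypothetical: key -> callable kernel
        self.kernel_cache = {}

    def run(self, spec_args, dns_args, const_args):
        key = make_key(spec_args, dns_args, const_args)
        if key not in self.kernel_cache:      # miss: go through compilation
            self.kernel_cache[key] = self.compile_fn(key)
        kernel = self.kernel_cache[key]       # fetch the cached kernel
        return kernel(*spec_args, *dns_args)  # launch with runtime args only
```

Two calls whose tensors share a dtype and 16-byte alignment and whose integers fall in the same range reuse one compiled kernel; a misaligned pointer or an integer that needs 64 bits produces a new key and therefore a new compilation.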
The libentry() factory function returns a decorator that wraps a Triton JIT function with the LibEntry optimization. As noted in the source code, this optimization can be removed when Triton is upgraded to version 3.0.0, which addresses the runtime overhead natively.
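The factory shape explains why the decorator is written `@libentry()` with parentheses: calling `libentry()` returns the decorator that does the wrapping. A minimal sketch, in which `WrappedKernel` is a hypothetical stand-in for LibEntry and `needs_libentry` is a hypothetical helper reflecting the source note that the optimization is unnecessary from Triton 3.0.0 onward:

```python
class WrappedKernel:
    """Hypothetical stand-in for LibEntry: just remembers the wrapped fn."""

    def __init__(self, fn):
        self.fn = fn


def libentry():
    # Factory: calling libentry() returns the actual decorator, which is
    # why usage is @libentry() rather than @libentry.
    def decorator(fn):
        return WrappedKernel(fn)

    return decorator


def needs_libentry(triton_version: str) -> bool:
    # Hypothetical helper: only pre-3.0 Triton versions need the wrapper,
    # per the note in the source code.
    major = int(triton_version.split(".")[0])
    return major < 3


@libentry()
def kernel():
    pass
```

After decoration, `kernel` is a `WrappedKernel` whose `fn` attribute holds the original function.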
Usage
This decorator is applied to the SGMV Triton kernels (sgmv_expand, sgmv_expand_slice, sgmv_shrink) to reduce their kernel launch overhead. It is placed above the @triton.jit decorator, so it wraps the resulting JIT function. The optimization is transparent to callers: the decorated kernel is launched the same way as a standard Triton kernel.
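The optimization stays transparent because LibEntry, like a JIT function, is a `triton.KernelInterface`: a launch written as `kernel[grid](*args)` is turned into a call to `run(..., grid=grid)`, roughly as Triton 2.x does it. The pure-Python mimic below (every class and function is a hypothetical stand-in) also shows the decorator order: decorators apply bottom-up, so the JIT decorator runs first and the outer decorator receives the JIT object.

```python
class KernelInterfaceSketch:
    """Simplified mimic of how fn[grid](...) becomes fn.run(..., grid=grid)."""

    def __getitem__(self, grid):
        # Return a proxy that forwards to run() with the grid bound.
        return lambda *args, **kwargs: self.run(*args, grid=grid, warmup=False, **kwargs)


class FakeJIT(KernelInterfaceSketch):
    def __init__(self, fn):
        self.fn = fn

    def run(self, *args, grid, warmup, **kwargs):
        return ("jit", grid, args)


class FakeLibEntry(KernelInterfaceSketch):
    def __init__(self, fn):
        self.fn = fn  # the object produced by the decorator below it

    def run(self, *args, grid, warmup, **kwargs):
        # A real wrapper would consult its kernel cache here; this one delegates.
        return ("libentry",) + self.fn.run(*args, grid=grid, warmup=warmup, **kwargs)


def fake_jit(fn):
    return FakeJIT(fn)


def fake_libentry():
    return FakeLibEntry


@fake_libentry()  # applied second, so it wraps the JIT object
@fake_jit         # applied first, directly to the Python function
def my_kernel(x_ptr, n):
    pass


print(my_kernel[(4,)](0, 128))  # ('libentry', 'jit', (4,), (0, 128))
```

The caller's launch syntax is identical with or without the outer wrapper; only the object that handles `run` changes.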
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/utils/ops/libentry.py
- Lines: 1-168
Signature
```python
class LibEntry(triton.KernelInterface):
    def __init__(self, fn): ...
    def key(self, spec_args, dns_args, const_args): ...
    def run(self, *args, **kwargs): ...


def libentry():
    """
    Decorator for triton library entries.
    Reduces Triton runtime overhead for small kernels by caching
    compiled kernels and simplifying parameter binding.
    """
```
Import
```python
from lorax_server.utils.ops.libentry import libentry
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| fn | triton.runtime.JITFunction | Yes | The Triton JIT-compiled kernel function to wrap, possibly already wrapped by an Autotuner or Heuristics decorator. Passed implicitly when libentry() is used as a decorator on a @triton.jit function. |
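Because `fn` may reach LibEntry already wrapped by an Autotuner or Heuristics decorator, the wrapper needs to locate the innermost JIT function to read its parameter metadata. As we read the upstream code, it does this by following the `.fn` attribute chain; the sketch below reproduces that loop with hypothetical stub classes in place of the Triton runtime types.

```python
class JITFunctionStub:
    """Hypothetical stand-in for triton.runtime.JITFunction."""

    def __init__(self, name):
        self.name = name


class AutotunerStub:
    """Hypothetical stand-in for triton.runtime.Autotuner (wraps .fn)."""

    def __init__(self, fn):
        self.fn = fn


class HeuristicsStub:
    """Hypothetical stand-in for triton.runtime.Heuristics (wraps .fn)."""

    def __init__(self, fn):
        self.fn = fn


def find_jit_function(fn):
    # Follow the .fn chain until the innermost JIT function is reached,
    # so its parameter metadata can be read once at construction time.
    while not isinstance(fn, JITFunctionStub):
        fn = fn.fn
    return fn


jit = JITFunctionStub("kernel")
wrapped = AutotunerStub(HeuristicsStub(jit))
```

Unwrapping once in the constructor means the per-call path never has to walk the wrapper chain again.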
Outputs
| Name | Type | Description |
|---|---|---|
| LibEntry | triton.KernelInterface | A wrapped kernel that implements the same interface as the original Triton JIT function but with optimized launch caching. |
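On a cache hit, LibEntry launches the cached kernel itself, so to keep the same interface it has to interpret the `grid` argument the way a standard launch would: a callable grid is evaluated against a metadata dict of arguments and constexpr values, and the result is padded to three dimensions. `normalize_grid` below is a hypothetical helper sketching that step, based on our reading of the upstream implementation.

```python
def normalize_grid(grid, meta):
    # A callable grid is evaluated against the launch metadata
    # (argument values plus resolved constexpr values).
    if callable(grid):
        grid = grid(meta)
    # Pad tuples and lists with 1s, then keep exactly three dimensions.
    if isinstance(grid, tuple):
        grid = grid + (1, 1)
    elif isinstance(grid, list):
        grid = grid + [1, 1]
    return grid[0:3]


# Example: one program per BLOCK_M rows (names are illustrative).
grid_fn = lambda meta: (meta["M"] // meta["BLOCK_M"],)
```

With `meta = {"M": 256, "BLOCK_M": 64}`, `normalize_grid(grid_fn, meta)` yields a three-dimensional grid with four programs along the first axis.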
Usage Examples
```python
# Applied as a decorator above @triton.jit on SGMV kernels
import triton
import triton.language as tl

from lorax_server.utils.ops.libentry import libentry


@libentry()
@triton.jit
def _sgmv_expand_kernel(
    input_ptr,
    lora_ptr,
    out_ptr,
    # ... kernel parameters
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # kernel body
    ...
```