Implementation:Predibase Lorax Punica Ops Bindings
| Knowledge Sources | |
|---|---|
| Domains | LoRA, GPU_Kernels |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
This C++ file provides the PyTorch/pybind11 extension bindings for the Punica CUDA kernels, exposing batched LoRA operations (BGMV, SGMV), FlashInfer paged attention, and RMS normalization to Python.
Description
This extension module is the central binding layer for LoRAX's custom CUDA kernels that enable efficient multi-adapter batched inference. It exposes nine operations through pybind11, organized into four categories:
- FlashInfer paged attention: batch_prefill performs batched prefill attention over paged KV caches with variable-length query sequences; batch_decode performs batched single-token decode attention; init_kv writes key/value tensors into paged KV cache entries, using seqlen_indptr to locate each sequence's tokens; and append_kv appends one token's key/value pair per sequence to its existing KV cache pages.
- BGMV (Batched Gather Matrix-Vector): dispatch_bgmv (exposed to Python under the same name) performs batched gather matrix-vector multiplication to apply LoRA adapters. It packs the input and output feature dimensions into a single 32-bit key with the pack_u32 utility, switches on that key to select a pre-compiled kernel, and supports both fp16 and bf16.
- SGMV (Segmented Gather Matrix-Vector): dispatch_sgmv_cutlass (exposed as sgmv_cutlass) uses CUTLASS-based kernels for segmented gather matrix-vector operations over variable-length segments; dispatch_sgmv_shrink (exposed as sgmv_shrink) is an optimized variant for the shrinking (down-projection) side of LoRA, with the output dimension specialized at compile time. Only the CUTLASS variant has a sizing helper: sgmv_tmp_size, registered as sgmv_cutlass_tmp_size, returns the size of its temporary buffer in bytes.
- RMS normalization: dispatch_rms_norm (exposed as rms_norm) performs fused RMS normalization on a 2D input tensor with a 1D weight vector.
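As a reference for what the RMS normalization binding computes, here is a minimal pure-Python sketch. It assumes the standard RMSNorm definition (divide each row by the root-mean-square of its elements plus epsilon, then multiply elementwise by weight), ignores fp16/bf16 rounding and the kernel's internal accumulation precision, and uses a hypothetical function name.

```python
import math
from typing import List


def rms_norm_reference(rows: List[List[float]], weight: List[float],
                       epsilon: float) -> List[List[float]]:
    """Row-wise RMS norm: rows[i][j] / sqrt(mean(rows[i]**2) + epsilon) * weight[j].

    rows has shape [rows, columns] and weight has shape [columns]. Unlike the
    CUDA kernel, which writes into a caller-provided `output` tensor, this
    sketch returns a new nested list.
    """
    out = []
    for row in rows:
        if len(row) != len(weight):
            raise ValueError("weight must have one entry per column")
        mean_sq = sum(v * v for v in row) / len(row)
        inv_rms = 1.0 / math.sqrt(mean_sq + epsilon)
        out.append([v * inv_rms * w for v, w in zip(row, weight)])
    return out


# One row [3, 4]: its mean square is 12.5, so each value is divided by sqrt(12.5).
normed = rms_norm_reference([[3.0, 4.0]], [1.0, 2.0], 0.0)
```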
Each binding that takes tensors validates its arguments with the CHECK_INPUT, CHECK_DIM, CHECK_SHAPE, and CHECK_EQ macros before launching a kernel, and uses the DISPATCH_TORCH_DTYPE macro to dispatch to the fp16 or bf16 template specialization.
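The two-level dispatch described for BGMV (specialize on dtype, then pack the feature dimensions with pack_u32 and switch on the key) can be mimicked in Python as a sketch. It assumes pack_u32 places its first argument in the high 16 bits; the KERNELS table, its dimension pairs, and the kernel names are hypothetical placeholders, not the instantiations LoRAX actually compiles.

```python
from typing import Dict


def pack_u32(a: int, b: int) -> int:
    """Pack two 16-bit values into one 32-bit key, with `a` in the high half."""
    if not (0 <= a < (1 << 16) and 0 <= b < (1 << 16)):
        raise ValueError("both values must fit in 16 bits")
    return (a << 16) | b


# Hypothetical registry: keyed first by dtype (mirroring DISPATCH_TORCH_DTYPE),
# then by the packed (feat_in, feat_out) pair. The real list lives in
# bgmv_config.h and is not reproduced here.
KERNELS: Dict[str, Dict[int, str]] = {
    dtype: {
        pack_u32(4096, 16): f"bgmv_kernel<4096, 16, {dtype}>",
        pack_u32(16, 4096): f"bgmv_kernel<16, 4096, {dtype}>",
    }
    for dtype in ("fp16", "bf16")
}


def select_bgmv_kernel(dtype: str, h_in: int, h_out: int) -> str:
    """Return the kernel name for (dtype, h_in, h_out), or raise if none exists."""
    if dtype not in KERNELS:
        raise TypeError(f"unsupported dtype {dtype!r}; expected fp16 or bf16")
    try:
        return KERNELS[dtype][pack_u32(h_in, h_out)]
    except KeyError:
        raise NotImplementedError(
            f"no BGMV kernel compiled for h_in={h_in}, h_out={h_out}") from None
```

Packing both dimensions into one integer lets a single switch (here, a single dictionary lookup) cover every supported pair; in this sketch an unsupported pair raises rather than falling back silently.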
Usage
This extension is the workhorse of LoRAX's multi-adapter serving. During inference, LoRA adapter weights are applied to hidden states with the BGMV or SGMV kernels, so requests that use different adapters can share a single batch: BGMV selects an adapter for each token, while SGMV applies one adapter to each contiguous segment of tokens. The FlashInfer operations manage the paged KV cache for memory-efficient attention, and the RMS norm kernel provides a fused normalization step.
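To make per-token adapter selection concrete, the following pure-Python sketch spells out the BGMV semantics given in the I/O contract: y accumulates scale times the adapter projection of x, with the adapter chosen per row by the index tensor. The nested-list layout ([h_in][h_out] per adapter) and the function name are assumptions for illustration; the kernel reads weights through w_ptr, and its memory layout may differ.

```python
from typing import List

Matrix = List[List[float]]


def bgmv_reference(y: Matrix, x: Matrix, weights: List[Matrix],
                   indices: List[int], scale: float) -> None:
    """In place: y[i] += scale * (x[i] @ weights[indices[i]]) for every row i.

    y is [B][h_out], x is [B][h_in], and each weights[a] is a dense
    [h_in][h_out] matrix for adapter a (a layout chosen for this sketch only).
    """
    for i, adapter in enumerate(indices):
        w = weights[adapter]                     # gather: this row's adapter
        for j in range(len(y[i])):
            acc = sum(x[i][k] * w[k][j] for k in range(len(x[i])))
            y[i][j] += scale * acc               # accumulate, do not overwrite


# Two rows in one batch, each served by a different adapter.
y = [[0.0], [10.0]]
x = [[3.0, 4.0], [3.0, 4.0]]
weights = [[[1.0], [1.0]],                       # adapter 0: sum of inputs
           [[1.0], [-1.0]]]                      # adapter 1: difference of inputs
bgmv_reference(y, x, weights, [0, 1], 0.5)
```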
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/punica_kernels/punica_kernels/punica_ops.cc
- Lines: 1-459
Signature
// FlashInfer paged attention
void batch_prefill(torch::Tensor o, torch::Tensor q, torch::Tensor qo_indptr,
torch::Tensor kv_ptrs, torch::Tensor kv_indptr,
torch::Tensor last_page_offset, torch::Tensor tmpbuf,
int num_layers, int layer_idx, int num_kv_heads, int page_size);
void batch_decode(torch::Tensor o, torch::Tensor q, torch::Tensor kv_ptrs,
torch::Tensor kv_indptr, torch::Tensor last_page_offset,
torch::Tensor tmpbuf, int num_layers, int layer_idx,
int num_kv_heads, int page_size);
void init_kv(torch::Tensor kv_ptrs, torch::Tensor kv_indptr,
torch::Tensor last_page_offset, torch::Tensor k, torch::Tensor v,
torch::Tensor seqlen_indptr, int num_layers, int layer_idx,
int num_kv_heads, int page_size);
void append_kv(torch::Tensor kv_ptrs, torch::Tensor kv_indptr,
torch::Tensor last_page_offset, torch::Tensor k, torch::Tensor v,
int num_layers, int layer_idx, int num_kv_heads, int page_size);
// BGMV - Batched LoRA application
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr,
torch::Tensor indicies, int64_t layer_idx, float scale);
// SGMV - Segmented LoRA application
void dispatch_sgmv_cutlass(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr,
torch::Tensor s_start, torch::Tensor s_end,
torch::Tensor tmp, int layer_idx);
void dispatch_sgmv_shrink(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr,
torch::Tensor s_start, torch::Tensor s_end,
torch::Tensor tmp, int layer_idx);
// RMS Normalization
void dispatch_rms_norm(torch::Tensor output, torch::Tensor input,
torch::Tensor weight, float epsilon);
// pybind11 module registration
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("batch_prefill", &batch_prefill, "");
m.def("batch_decode", &batch_decode, "");
m.def("init_kv", &init_kv, "");
m.def("append_kv", &append_kv, "");
m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv");
m.def("sgmv_cutlass", &dispatch_sgmv_cutlass, "");
m.def("sgmv_cutlass_tmp_size", &sgmv_tmp_size, "");
m.def("sgmv_shrink", &dispatch_sgmv_shrink, "");
m.def("rms_norm", &dispatch_rms_norm, "");
}
Import
#include <c10/cuda/CUDAStream.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <torch/extension.h>
#include "bgmv/bgmv_config.h"
#include "flashinfer_adapter/flashinfer_config.h"
#include "rms_norm/rms_norm.h"
#include "sgmv/sgmv.h"
#include "sgmv_flashinfer/sgmv_config.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| o | torch::Tensor (half/bf16) [B, N, D] | Yes | Output tensor for attention results (batch_prefill/batch_decode) |
| q | torch::Tensor (half/bf16) [B, N, D] | Yes | Query tensor for attention operations |
| qo_indptr | torch::Tensor (int32) [B+1] | Yes | Query/output index pointers for variable-length prefill sequences |
| kv_ptrs | torch::Tensor (int64) [num_pages] | Yes | Pointers to paged KV cache blocks, each pointing to a [L, 2, N, P, D] tensor |
| kv_indptr | torch::Tensor (int32) [B+1] | Yes | KV page index pointers per batch element |
| last_page_offset | torch::Tensor (int32) [B] | Yes | Offset within the last page for each batch element |
| tmpbuf | torch::Tensor | Yes | Temporary buffer (at least 64 MB) |
| num_layers | int | Yes | Total number of model layers |
| layer_idx | int | Yes | Current layer index for KV cache addressing |
| num_kv_heads | int | Yes | Number of key/value attention heads |
| page_size | int | Yes | Number of tokens per KV cache page |
| k, v | torch::Tensor (half/bf16) | Yes | Key and value tensors for KV cache initialization/append |
| seqlen_indptr | torch::Tensor (int32) [B+1] | Yes | Sequence length index pointers (init_kv only) |
| y | torch::Tensor (half/bf16) [B, h_out] | Yes | Output tensor for BGMV/SGMV LoRA operations |
| x | torch::Tensor (half/bf16) [B, h_in] | Yes | Input hidden states for BGMV/SGMV |
| w_ptr | torch::Tensor (int64/ptr) | Yes | Pointers to per-adapter LoRA weight matrices |
| indicies | torch::Tensor (int64) [B] | Yes | Per-token adapter index for BGMV dispatch |
| s_start, s_end | torch::Tensor (int32) [num_problems] | Yes | Segment start/end indices for SGMV operations |
| tmp | torch::Tensor (uint8) | Yes | Temporary workspace for SGMV (sized via sgmv_cutlass_tmp_size, or 8 MB for sgmv_shrink) |
| scale | float | Yes | Scaling factor applied to BGMV output |
| output | torch::Tensor (half/bf16) [rows, columns] | Yes | Output tensor for RMS normalization |
| input | torch::Tensor (half/bf16) [rows, columns] | Yes | Input tensor for RMS normalization |
| weight | torch::Tensor (half/bf16) [columns] | Yes | RMS normalization weight vector |
| epsilon | float | Yes | RMS normalization epsilon for numerical stability |
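The [B+1] index-pointer inputs above use a CSR-like convention: entries b and b+1 delimit the items belonging to batch element b. A hedged sketch of how a caller might build them from per-sequence lengths follows. lengths_to_indptr and paged_kv_metadata are hypothetical helpers, and the sketch assumes last_page_offset counts the valid tokens in each sequence's last page (1..page_size), as FlashInfer's last_page_len does; if LoRAX uses a 0-based offset instead, subtract one.

```python
from typing import List, Tuple


def lengths_to_indptr(lengths: List[int]) -> List[int]:
    """Prefix sum with a leading zero, so indptr[b]:indptr[b+1] spans item b.

    Usable for qo_indptr (query lengths) or seqlen_indptr (prompt lengths).
    """
    indptr = [0]
    for n in lengths:
        indptr.append(indptr[-1] + n)
    return indptr


def paged_kv_metadata(seq_lens: List[int], page_size: int) -> Tuple[List[int], List[int]]:
    """Return (kv_indptr, last_page_offset) for sequences stored in fixed-size pages.

    Assumption: last_page_offset[b] is the number of valid tokens in the final
    page of sequence b, in the range 1..page_size.
    """
    if page_size <= 0:
        raise ValueError("page_size must be positive")
    pages_per_seq = []
    last_page_offset = []
    for n in seq_lens:
        if n <= 0:
            raise ValueError("each sequence needs at least one token")
        pages_per_seq.append((n + page_size - 1) // page_size)  # ceil(n / page_size)
        last_page_offset.append((n - 1) % page_size + 1)
    return lengths_to_indptr(pages_per_seq), last_page_offset


# page_size 16: 5 tokens (1 page), 16 tokens (1 full page), 17 tokens (2 pages).
kv_indptr, last_page_offset = paged_kv_metadata([5, 16, 17], 16)
```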
Outputs
| Name | Type | Description |
|---|---|---|
| o (mutated) | torch::Tensor (half/bf16) | Attention output written in place by batch_prefill/batch_decode |
| y (mutated) | torch::Tensor (half/bf16) | LoRA output accumulated in place: BGMV computes y += scale * x @ W_adapter; SGMV accumulates x @ W_adapter and takes no scale argument |
| output (mutated) | torch::Tensor (half/bf16) | RMS-normalized result written in place |
| sgmv_cutlass_tmp_size (return value) | int | Required temporary buffer size in bytes for SGMV CUTLASS operations |
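For SGMV the per-row adapter index is replaced by segments. The sketch below assumes half-open [s_start, s_end) ranges with one weight matrix per segment (consistent with s_start and s_end being sized [num_problems]). It shows how a caller might derive segments from adapter ids that are already grouped, plus a reference for the accumulation listed above. Both function names are hypothetical, and because the SGMV bindings take no scale argument, none is applied here.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def segments_from_adapter_ids(adapter_ids: List[int]) -> Tuple[List[int], List[int], List[int]]:
    """Collapse runs of equal adapter ids into half-open [s_start, s_end) segments.

    Assumes rows are already ordered so each adapter's rows are contiguous;
    returns (s_start, s_end, adapter_of_segment).
    """
    s_start, s_end, seg_adapter = [], [], []
    for row, a in enumerate(adapter_ids):
        if seg_adapter and seg_adapter[-1] == a:
            s_end[-1] = row + 1                  # extend the current run
        else:
            s_start.append(row)                  # open a new segment
            s_end.append(row + 1)
            seg_adapter.append(a)
    return s_start, s_end, seg_adapter


def sgmv_reference(y: Matrix, x: Matrix, seg_weights: List[Matrix],
                   s_start: List[int], s_end: List[int]) -> None:
    """In place: y[r] += x[r] @ seg_weights[p] for every row r in segment p (no scale)."""
    for p, (lo, hi) in enumerate(zip(s_start, s_end)):
        w = seg_weights[p]                       # one weight matrix per segment
        for r in range(lo, hi):
            for j in range(len(y[r])):
                y[r][j] += sum(x[r][k] * w[k][j] for k in range(len(x[r])))


s_start, s_end, seg_adapter = segments_from_adapter_ids([2, 2, 0, 0, 0, 1])
adapter_weights = {0: [[3.0]], 1: [[5.0]], 2: [[2.0]]}   # toy 1x1 adapters
y = [[0.0] for _ in range(6)]
sgmv_reference(y, [[1.0]] * 6, [adapter_weights[a] for a in seg_adapter], s_start, s_end)
```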
Usage Examples
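# From Python, via the compiled extension. The bindings are registered without
# py::arg names, so every argument must be passed positionally.
import torch
import punica_kernels
# Batched decode attention with paged KV cache
punica_kernels.batch_decode(o, q, kv_ptrs, kv_indptr, last_page_offset,
                            tmpbuf, num_layers, layer_idx, num_kv_heads, page_size)
# Apply LoRA adapters via BGMV (one adapter index per token); scale is the last positional argument
punica_kernels.dispatch_bgmv(y, x, w_ptr, indices, layer_idx, 1.0)
# Apply LoRA adapters via SGMV (variable-length segments).
# The workspace is sized by the number of problems, i.e. one per segment.
tmp_size = punica_kernels.sgmv_cutlass_tmp_size(s_start.size(0))
tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device)
punica_kernels.sgmv_cutlass(y, x, w_ptr, s_start, s_end, tmp, layer_idx)
# Fused RMS normalization; epsilon is the last positional argument
punica_kernels.rms_norm(output, input, weight, 1e-6)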
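The sgmv_shrink binding covers only LoRA's down-projection. As a hedged sketch, assuming the usual low-rank factorization in which an adapter contributes scale * (x @ A) @ B with A of shape [h_in, r] and B of shape [r, h_out], the two stages compose as below. This page does not state which binding LoRAX uses for each stage or where it applies scale, so this sketch applies scale at the end; the helper names are hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Dense [m][k] @ [k][n] product on nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def lora_delta(x: Matrix, lora_a: Matrix, lora_b: Matrix, scale: float) -> Matrix:
    """scale * (x @ A) @ B: shrink from h_in to rank r, then expand to h_out."""
    v = matmul(x, lora_a)        # shrink: [B][h_in] -> [B][r]
    d = matmul(v, lora_b)        # expand: [B][r] -> [B][h_out]
    return [[scale * val for val in row] for row in d]


# h_in = 2, rank r = 1, h_out = 3.
delta = lora_delta([[1.0, 2.0]], [[1.0], [1.0]], [[1.0, -1.0, 2.0]], 0.5)
```

The intermediate v has only r columns; that small output dimension is the one the shrink variant specializes on at compile time.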