

Implementation:Predibase Lorax Punica Ops Bindings

From Leeroopedia


Knowledge Sources
Domains LoRA, GPU_Kernels
Last Updated 2026-02-08 00:00 GMT

Overview

This C++ file provides the PyTorch/pybind11 extension bindings for the Punica CUDA kernels, exposing batched LoRA operations (BGMV, SGMV), FlashInfer paged attention, and RMS normalization to Python.

Description

This extension module is the central binding layer for LoRAX's custom CUDA kernels that enable efficient multi-adapter batched inference. It exposes nine operations through pybind11, organized into four categories:

FlashInfer paged attention: batch_prefill performs batched prefill attention over paged KV caches with variable-length query sequences; batch_decode performs batched single-token decode attention; init_kv initializes paged KV cache entries from key/value tensors with sequence-length indexing; and append_kv appends single-token key/value pairs into existing paged KV cache pages.
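The paged KV cache is addressed in CSR style: kv_indptr[b]..kv_indptr[b+1] selects the pages of sequence b (qo_indptr slices prefill query rows the same way), and last_page_offset says how full the final page is. The pure-Python model below is a sketch of that indexing under stated assumptions. It assumes kv_ptrs[kv_indptr[b]:kv_indptr[b+1]] lists sequence b's pages in order, that last_page_offset[b] counts the valid tokens in the last page, and that each page uses the [L, 2, N, P, D] layout from the I/O contract. The helper names (seq_lengths, locate_token, page_offset) are hypothetical; the real kernels perform this arithmetic on the GPU.

```python
from typing import List, Tuple


def seq_lengths(kv_indptr: List[int], last_page_offset: List[int],
                page_size: int) -> List[int]:
    """Tokens cached per sequence.

    Assumes last_page_offset[b] is the number of valid tokens in sequence b's
    last page (1..page_size).
    """
    lengths = []
    for b in range(len(kv_indptr) - 1):
        num_pages = kv_indptr[b + 1] - kv_indptr[b]  # CSR: pages owned by sequence b
        lengths.append((num_pages - 1) * page_size + last_page_offset[b])
    return lengths


def locate_token(b: int, t: int, kv_indptr: List[int],
                 page_size: int) -> Tuple[int, int]:
    """Return (slot in kv_ptrs, position within page) for token t of sequence b."""
    return kv_indptr[b] + t // page_size, t % page_size


def page_offset(layer: int, kv: int, head: int, pos: int, d: int,
                num_kv_heads: int, page_size: int, head_dim: int) -> int:
    """Flat element offset inside one page laid out as [L, 2, N, P, D].

    kv is 0 for keys and 1 for values.
    """
    return (((layer * 2 + kv) * num_kv_heads + head) * page_size + pos) * head_dim + d


# Two sequences: 2 pages and 3 pages of 16 tokens each.
kv_indptr = [0, 2, 5]
last_page_offset = [3, 16]
print(seq_lengths(kv_indptr, last_page_offset, 16))  # [19, 48]
print(locate_token(1, 40, kv_indptr, 16))            # (4, 8)
```

Appending a token (append_kv) then amounts to writing at the position after the last valid slot of the last page, or opening a new page when that page is full.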

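BGMV (Batched Gather Matrix-Vector): dispatch_bgmv performs batched gather-matrix-vector multiplication for LoRA adapter application, gathering each token's adapter weights by index and dispatching to pre-compiled kernels based on the input/output feature dimensions. A pack_u32 utility packs each (input, output) dimension pair into a single 32-bit value so that one switch statement can select the matching kernel; both fp16 and bf16 are supported.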
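A pure-Python reference for what BGMV computes, paired with a pack_u32-style shape check, makes the gather and the dispatch concrete. This is a sketch, not the kernel: the list of compiled shapes, the [h_in][h_out] weight layout, the bit layout of pack_u32 (first dimension in the high 16 bits), and the function names are assumptions. The real kernel also selects a layer via layer_idx, which this sketch omits.

```python
from typing import List

Matrix = List[List[float]]

# Hypothetical (h_in, h_out) pairs standing in for the pre-compiled kernels.
COMPILED_SHAPES = {(2, 3), (16, 4096), (4096, 16)}


def pack_u32(a: int, b: int) -> int:
    """Pack two 16-bit dimensions into one 32-bit switch key (a high, b low)."""
    return ((a & 0xFFFF) << 16) | (b & 0xFFFF)


COMPILED_KEYS = {pack_u32(h_in, h_out) for h_in, h_out in COMPILED_SHAPES}


def bgmv_reference(y: Matrix, x: Matrix, weights: List[Matrix],
                   indices: List[int], scale: float) -> None:
    """In place: y[i] += scale * x[i] @ weights[indices[i]] for every token i."""
    for i, adapter in enumerate(indices):
        w = weights[adapter]  # gather this token's adapter; assumed layout [h_in][h_out]
        for j in range(len(y[i])):
            y[i][j] += scale * sum(x[i][k] * w[k][j] for k in range(len(x[i])))


def dispatch_bgmv_sketch(y: Matrix, x: Matrix, weights: List[Matrix],
                         indices: List[int], scale: float) -> None:
    """Reject shapes with no compiled kernel, then run the reference BGMV."""
    h_in, h_out = len(x[0]), len(y[0])
    if pack_u32(h_in, h_out) not in COMPILED_KEYS:
        raise ValueError(f"no BGMV kernel for h_in={h_in}, h_out={h_out}")
    bgmv_reference(y, x, weights, indices, scale)


x = [[1.0, 2.0], [3.0, 4.0]]
weights = [[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],   # adapter 0
           [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]]   # adapter 1
y = [[0.0] * 3 for _ in range(2)]
dispatch_bgmv_sketch(y, x, weights, indices=[1, 0], scale=0.5)
print(y)  # [[1.5, 1.5, 1.5], [1.5, 2.0, 0.0]]
```

Switching on one packed integer keeps the C++ dispatch to a single switch over compile-time shape pairs instead of nested switches on each dimension.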

SGMV (Segmented Gather Matrix-Vector): dispatch_sgmv_cutlass uses CUTLASS-based kernels for segmented gather-matrix-vector operations with variable segment lengths; dispatch_sgmv_shrink provides an optimized variant for the shrinking (down-projection) side of LoRA with compile-time output dimension specialization. Only the CUTLASS path has a workspace-size query: sgmv_tmp_size, bound to Python as sgmv_cutlass_tmp_size. The shrink kernel instead takes a caller-allocated workspace (8 MB in the I/O contract).
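SGMV applies one weight matrix per contiguous segment of rows rather than one per token. The pure-Python reference below is a sketch assuming row-major [h_in][h_out] weights and one weight per segment (as w_ptr supplies one pointer per problem). It also assumes the LoRA delta is computed as a shrink pass into rank r followed by an expand pass, with both passes modeled by the same reference function; which kernel LoRAX uses for the expand side is not stated on this page. Names are hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def sgmv_reference(y: Matrix, x: Matrix, weights: List[Matrix],
                   s_start: List[int], s_end: List[int]) -> None:
    """In place: y[s_start[i]:s_end[i]] += x[s_start[i]:s_end[i]] @ weights[i]."""
    for w, lo, hi in zip(weights, s_start, s_end):
        for t in range(lo, hi):  # every row in the segment shares weight w
            for j in range(len(y[t])):
                y[t][j] += sum(x[t][k] * w[k][j] for k in range(len(x[t])))


def lora_sgmv(x: Matrix, lora_a: List[Matrix], lora_b: List[Matrix],
              s_start: List[int], s_end: List[int], h_out: int) -> Matrix:
    """LoRA delta as shrink (h_in -> r) followed by expand (r -> h_out)."""
    rank = len(lora_a[0][0])
    v = [[0.0] * rank for _ in x]                  # intermediate, one row per token
    sgmv_reference(v, x, lora_a, s_start, s_end)   # shrink / down-projection
    y = [[0.0] * h_out for _ in x]
    sgmv_reference(y, v, lora_b, s_start, s_end)   # expand / up-projection
    return y


x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
lora_a = [[[1.0], [1.0]], [[1.0], [0.0]]]   # per-segment A: [h_in][r], r = 1
lora_b = [[[1.0, 2.0]], [[10.0, 0.0]]]      # per-segment B: [r][h_out]
print(lora_sgmv(x, lora_a, lora_b, s_start=[0, 2], s_end=[2, 3], h_out=2))
# [[3.0, 6.0], [7.0, 14.0], [50.0, 0.0]]
```

Because r is small, the shrink output has very few columns, which is why a shrink-specific kernel with a compile-time output dimension can pay off.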

RMS normalization: dispatch_rms_norm performs fused RMS normalization on 2D input tensors with a 1D weight vector.
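For reference, the standard RMSNorm definition, assumed here to be what the fused kernel computes, divides each row by the root mean square of its elements (plus epsilon) and scales by the weight vector. The sketch returns a new list instead of writing into output in place, and does not model the kernel's accumulation precision; the function name is hypothetical.

```python
import math
from typing import List


def rms_norm_reference(inp: List[List[float]], weight: List[float],
                       epsilon: float) -> List[List[float]]:
    """out[i][j] = inp[i][j] / sqrt(mean_k(inp[i][k]^2) + epsilon) * weight[j]."""
    out = []
    for row in inp:
        mean_sq = sum(v * v for v in row) / len(row)   # mean of squares over columns
        inv_rms = 1.0 / math.sqrt(mean_sq + epsilon)
        out.append([v * inv_rms * w for v, w in zip(row, weight)])
    return out


print(rms_norm_reference([[1.0, 1.0, 1.0, 1.0]], [2.0, 2.0, 2.0, 2.0], 0.0))
# [[2.0, 2.0, 2.0, 2.0]]
```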

Every tensor-taking function validates its arguments with the CHECK_INPUT, CHECK_DIM, CHECK_SHAPE, and CHECK_EQ macros, then uses the DISPATCH_TORCH_DTYPE macro to select the fp16 or bf16 template specialization.
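The validate-then-dispatch pattern can be modeled in a few lines of Python. This is a hypothetical sketch: the macros' exact checks are not reproduced here, shapes are plain tuples, and dtypes are strings. The checks in rms_norm_entry follow the 2D input / 1D weight contract for RMS normalization and are assumptions, as are the returned specialization labels.

```python
from typing import Callable, Dict, Tuple

Shape = Tuple[int, ...]


def check_dim(name: str, shape: Shape, ndim: int) -> None:
    """Analog of CHECK_DIM: require an exact number of dimensions."""
    if len(shape) != ndim:
        raise ValueError(f"{name} must be {ndim}-D, got {len(shape)}-D")


def check_shape(name_a: str, a: Shape, name_b: str, b: Shape) -> None:
    """Analog of CHECK_SHAPE: require two tensors to have identical shapes."""
    if a != b:
        raise ValueError(f"{name_a} {a} does not match {name_b} {b}")


def check_eq(what: str, a: int, b: int) -> None:
    """Analog of CHECK_EQ: require two sizes to match."""
    if a != b:
        raise ValueError(f"{what}: {a} != {b}")


def dispatch_dtype(dtype: str, impls: Dict[str, Callable[[], str]]) -> str:
    """Analog of DISPATCH_TORCH_DTYPE: pick a specialization or reject the dtype."""
    if dtype not in impls:
        raise TypeError(f"unsupported dtype: {dtype}")
    return impls[dtype]()


def rms_norm_entry(out: Shape, inp: Shape, weight: Shape, dtype: str) -> str:
    check_dim("input", inp, 2)
    check_dim("weight", weight, 1)
    check_shape("output", out, "input", inp)
    check_eq("weight size vs. input columns", weight[0], inp[1])
    return dispatch_dtype(dtype, {
        "float16": lambda: "rms_norm<half>",
        "bfloat16": lambda: "rms_norm<bf16>",
    })


print(rms_norm_entry((8, 4096), (8, 4096), (4096,), "bfloat16"))  # rms_norm<bf16>
```

Rejecting bad shapes and dtypes on the host side turns what would otherwise be silent out-of-bounds reads on the GPU into an immediate Python exception.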

Usage

This extension is the workhorse of LoRAX's multi-adapter serving. During inference, the LoRA adapter weights are applied to hidden states via BGMV or SGMV kernels, allowing multiple different LoRA adapters to be served simultaneously in a single batch. The FlashInfer operations manage the paged KV cache for efficient memory use, and the RMS norm kernel provides a fused normalization step.
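Serving several adapters in one batch requires telling the kernels which adapter each row of x uses: per token for BGMV (the indicies tensor) and per contiguous segment for SGMV (s_start/s_end plus one weight pointer per segment). The hypothetical helper below derives both from per-request adapter ids and token counts. The adapter ids stand in for whatever index selects the weight pointers; how LoRAX actually builds this metadata is not shown on this page.

```python
from typing import List, Tuple


def build_lora_metadata(adapter_ids: List[int], num_tokens: List[int]
                        ) -> Tuple[List[int], List[int], List[int], List[int]]:
    """Return (BGMV per-token indices, segment adapters, s_start, s_end).

    Consecutive requests that share an adapter are merged into one segment;
    non-adjacent requests with the same adapter get separate segments.
    """
    indices: List[int] = []
    seg_adapters: List[int] = []
    s_start: List[int] = []
    s_end: List[int] = []
    pos = 0
    for adapter, n in zip(adapter_ids, num_tokens):
        indices.extend([adapter] * n)          # BGMV: one adapter index per token
        if seg_adapters and seg_adapters[-1] == adapter:
            s_end[-1] = pos + n                # grow the current segment
        else:
            seg_adapters.append(adapter)       # open a new segment
            s_start.append(pos)
            s_end.append(pos + n)
        pos += n
    return indices, seg_adapters, s_start, s_end


print(build_lora_metadata([0, 0, 2, 1], [3, 2, 1, 4]))
# ([0, 0, 0, 0, 0, 2, 1, 1, 1, 1], [0, 2, 1], [0, 5, 6], [5, 6, 10])
```

Sorting or grouping requests by adapter before building the batch keeps the number of SGMV segments small.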

Code Reference

Source Location

  • Repository: Predibase_Lorax
  • File: server/punica_kernels/punica_kernels/punica_ops.cc
  • Lines: 1-459

Signature

// FlashInfer paged attention
void batch_prefill(torch::Tensor o, torch::Tensor q, torch::Tensor qo_indptr,
                   torch::Tensor kv_ptrs, torch::Tensor kv_indptr,
                   torch::Tensor last_page_offset, torch::Tensor tmpbuf,
                   int num_layers, int layer_idx, int num_kv_heads, int page_size);

void batch_decode(torch::Tensor o, torch::Tensor q, torch::Tensor kv_ptrs,
                  torch::Tensor kv_indptr, torch::Tensor last_page_offset,
                  torch::Tensor tmpbuf, int num_layers, int layer_idx,
                  int num_kv_heads, int page_size);

void init_kv(torch::Tensor kv_ptrs, torch::Tensor kv_indptr,
             torch::Tensor last_page_offset, torch::Tensor k, torch::Tensor v,
             torch::Tensor seqlen_indptr, int num_layers, int layer_idx,
             int num_kv_heads, int page_size);

void append_kv(torch::Tensor kv_ptrs, torch::Tensor kv_indptr,
               torch::Tensor last_page_offset, torch::Tensor k, torch::Tensor v,
               int num_layers, int layer_idx, int num_kv_heads, int page_size);

// BGMV - Batched LoRA application
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr,
                   torch::Tensor indicies, int64_t layer_idx, float scale);

// SGMV - Segmented LoRA application
void dispatch_sgmv_cutlass(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr,
                           torch::Tensor s_start, torch::Tensor s_end,
                           torch::Tensor tmp, int layer_idx);

void dispatch_sgmv_shrink(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr,
                          torch::Tensor s_start, torch::Tensor s_end,
                          torch::Tensor tmp, int layer_idx);

// RMS Normalization
void dispatch_rms_norm(torch::Tensor output, torch::Tensor input,
                       torch::Tensor weight, float epsilon);

// pybind11 module registration
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("batch_prefill", &batch_prefill, "");
    m.def("batch_decode", &batch_decode, "");
    m.def("init_kv", &init_kv, "");
    m.def("append_kv", &append_kv, "");
    m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv");
    m.def("sgmv_cutlass", &dispatch_sgmv_cutlass, "");
    m.def("sgmv_cutlass_tmp_size", &sgmv_tmp_size, "");
    m.def("sgmv_shrink", &dispatch_sgmv_shrink, "");
    m.def("rms_norm", &dispatch_rms_norm, "");
}

Import

#include <c10/cuda/CUDAStream.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <torch/extension.h>

#include "bgmv/bgmv_config.h"
#include "flashinfer_adapter/flashinfer_config.h"
#include "rms_norm/rms_norm.h"
#include "sgmv/sgmv.h"
#include "sgmv_flashinfer/sgmv_config.h"

I/O Contract

Inputs

Name Type Required Description
o torch::Tensor (half/bf16) [B, N, D] Yes Output tensor for attention results (batch_prefill/batch_decode)
q torch::Tensor (half/bf16) [B, N, D] Yes Query tensor for attention operations
qo_indptr torch::Tensor (int32) [B+1] Yes Query/output index pointers for variable-length prefill sequences
kv_ptrs torch::Tensor (int64) [num_pages] Yes Pointers to paged KV cache blocks, each pointing to a [L, 2, N, P, D] tensor (L = num_layers, 2 = key/value, N = num_kv_heads, P = page_size, D = head dimension)
kv_indptr torch::Tensor (int32) [B+1] Yes KV page index pointers per batch element
last_page_offset torch::Tensor (int32) [B] Yes Offset within the last page for each batch element
tmpbuf torch::Tensor Yes Temporary buffer (at least 64 MB)
num_layers int Yes Total number of model layers
layer_idx int Yes Current layer index for KV cache addressing
num_kv_heads int Yes Number of key/value attention heads
page_size int Yes Number of tokens per KV cache page
k, v torch::Tensor (half/bf16) Yes Key and value tensors for KV cache initialization/append
seqlen_indptr torch::Tensor (int32) [B+1] Yes Sequence length index pointers (init_kv only)
y torch::Tensor (half/bf16) [B, h_out] Yes Output tensor for BGMV/SGMV LoRA operations
x torch::Tensor (half/bf16) [B, h_in] Yes Input hidden states for BGMV/SGMV
w_ptr torch::Tensor (int64/ptr) Yes Pointers to per-adapter LoRA weight matrices
indicies torch::Tensor (int64) [B] Yes Per-token adapter index for BGMV dispatch
s_start, s_end torch::Tensor (int32) [num_problems] Yes Segment start/end indices for SGMV operations
tmp torch::Tensor (uint8) Yes Temporary workspace for SGMV (sized via sgmv_cutlass_tmp_size for sgmv_cutlass, or 8 MB for sgmv_shrink)
scale float Yes Scaling factor applied to BGMV output
output torch::Tensor (half/bf16) [rows, columns] Yes Output tensor for RMS normalization
input torch::Tensor (half/bf16) [rows, columns] Yes Input tensor for RMS normalization
weight torch::Tensor (half/bf16) [columns] Yes RMS normalization weight vector
epsilon float Yes RMS normalization epsilon for numerical stability

Outputs

Name Type Description
o (mutated) torch::Tensor (half/bf16) Attention output written in-place by batch_prefill/batch_decode
y (mutated) torch::Tensor (half/bf16) LoRA output accumulated in-place (BGMV: y += scale * x @ W_adapter; SGMV, which takes no scale: y += x @ W_segment)
output (mutated) torch::Tensor (half/bf16) RMS-normalized result written in-place
sgmv_cutlass_tmp_size int Required temporary buffer size in bytes for SGMV CUTLASS operations

Usage Examples

# From Python via the compiled extension. The bindings are registered without
# py::arg names, so every argument must be passed positionally.
import torch
import punica_kernels

# Batched decode attention with paged KV cache
punica_kernels.batch_decode(o, q, kv_ptrs, kv_indptr, last_page_offset,
                            tmpbuf, num_layers, layer_idx, num_kv_heads, page_size)

# Apply LoRA adapters via BGMV (one adapter index per token): y += scale * x @ W
scale = 1.0
punica_kernels.dispatch_bgmv(y, x, w_ptr, indices, layer_idx, scale)

# Apply LoRA adapters via SGMV (variable-length segments)
num_problems = s_start.size(0)  # one problem per segment
tmp_size = punica_kernels.sgmv_cutlass_tmp_size(num_problems)  # bytes
tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device)
punica_kernels.sgmv_cutlass(y, x, w_ptr, s_start, s_end, tmp, layer_idx)

# Fused RMS normalization
epsilon = 1e-6
punica_kernels.rms_norm(output, input, weight, epsilon)
