Implementation:Predibase Lorax BGMV Shrink Kernel
| Knowledge Sources | |
|---|---|
| Domains | GPU_Kernels, LoRA |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Triton kernel implementing the Batched Gather Matrix-Vector multiply (BGMV) shrink operation for applying LoRA A (down-projection) weights during multi-tenant LoRA adapter inference.
Description
The BGMV shrink kernel computes a batched matrix-vector product between hidden-dimension activations and per-request LoRA A weight matrices, projecting from the hidden dimension down to the LoRA rank. For each request i in the batch, it computes output[i] = scaling * (lora_A[idx[i]] @ input[i]), where idx[i] is the request's LoRA adapter index, lora_A[idx[i]] has shape (rank, hidden_size), and input[i] has shape (hidden_size,). Requests with idx[i] == -1 have no adapter, and the kernel skips them.
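The per-request formula above can be sketched in plain Python. This is a reference for the semantics only, not the Triton kernel: the name bgmv_shrink_reference is hypothetical, nested lists stand in for tensors, and leaving a skipped row untouched is this sketch's reading of "skips computation".

```python
from typing import List, Sequence

Vector = List[float]
Matrix = List[List[float]]  # one adapter's A matrix, shape (rank, hidden_size)


def bgmv_shrink_reference(
    inputs: Sequence[Vector],
    lora_a: Sequence[Matrix],
    indices: Sequence[int],
    output: List[Vector],
    scaling: float = 1.0,
) -> None:
    """Hypothetical reference: writes scaling * (A[idx] @ x) into each output row."""
    for i, (x, idx) in enumerate(zip(inputs, indices)):
        if idx == -1:
            # No adapter for this request: leave the output row untouched.
            continue
        a = lora_a[idx]  # gather this request's adapter, shape (rank, hidden_size)
        output[i] = [scaling * sum(w * v for w, v in zip(row, x)) for row in a]


# Two adapters with rank 2 and hidden_size 3.
lora_a = [
    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
    [[1.0, 1.0, 1.0], [0.0, 0.0, 2.0]],
]
inputs = [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]
indices = [0, 1, -1]  # the third request has no adapter
output = [[0.0, 0.0] for _ in inputs]

bgmv_shrink_reference(inputs, lora_a, indices, output, scaling=0.5)
print(output)  # [[0.5, 1.0], [3.0, 3.0], [0.0, 0.0]]
```

The "gather" in BGMV is the lora_a[idx] lookup: every row of the batch may select a different adapter, which is what distinguishes this operation from a single shared matrix multiply.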
This kernel is based on the Punica paper (Chen et al., 2023) for multi-tenant LoRA serving. It uses a GroupGEMV approach with a configurable SPLIT_K parameter that partitions the input (hidden) dimension K across thread blocks. When SPLIT_K > 1, each split produces a partial sum, and the partial sums are combined into the output with tl.atomic_add so that thread blocks writing to the same output element do not race. The kernel accumulates in float32 precision and applies the scaling factor before storing results. BLOCK_N, the block size along the output dimension, is N (the rank) rounded up to the next power of 2.
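The split-K idea can be illustrated with a small CPU simulation. This is a sketch under assumptions, not the kernel's code: the helper names are hypothetical, the strided assignment of K-blocks to splits is one plausible partitioning, and a plain += stands in for the atomic add. Since an atomic add accumulates into whatever the output already holds, the sketch starts from zero; whether the real caller zero-initialises the buffer is not stated in the source.

```python
from typing import List


def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n (for n >= 1); mirrors how BLOCK_N is derived from N."""
    return 1 << (n - 1).bit_length()


def split_k_dot(row: List[float], x: List[float], block_k: int, split_k: int) -> float:
    """Dot product with K partitioned across `split_k` simulated thread blocks."""
    k = len(x)
    out = 0.0  # stands in for a zero-initialised output element
    for pid_sk in range(split_k):
        partial = 0.0
        # Assumed layout: split pid_sk handles K-blocks pid_sk, pid_sk + split_k, ...
        for start in range(pid_sk * block_k, k, block_k * split_k):
            stop = min(start + block_k, k)  # plays the role of a bounds mask
            partial += sum(w * v for w, v in zip(row[start:stop], x[start:stop]))
        out += partial  # on the GPU this would be an atomic add
    return out


row = [float(i) for i in range(1, 11)]  # K = 10, not a multiple of block_k
x = [2.0] * 10

full = sum(w * v for w, v in zip(row, x))
print(full, split_k_dot(row, x, block_k=4, split_k=2))  # 110.0 110.0
print(next_power_of_2(12), next_power_of_2(16))         # 16 16
```

Splitting K gives the GPU more independent work when the batch and rank are small, at the cost of the atomic accumulation step.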
Usage
This kernel is invoked during the decode phase of LoRA adapter inference, where each request in the batch contributes a single token. It performs the first step of the LoRA computation: applying the LoRA A matrix (shrink/down-projection) to project the model's hidden activations into the low-rank space. The result is then fed into the BGMV expand kernel for the LoRA B up-projection.
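How the shrink step feeds the expand step can be sketched for a single request as follows. The helper names and the (out_features, rank) layout of the B matrix are assumptions made for illustration; the expand kernel's actual interface is not described here.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, x: Vector) -> Vector:
    """Plain matrix-vector product: one dot product per row of m."""
    return [sum(w * v for w, v in zip(row, x)) for row in m]


def lora_delta(x: Vector, lora_a: Matrix, lora_b: Matrix, scaling: float) -> Vector:
    """Hypothetical single-request LoRA update: expand(shrink(x))."""
    # Shrink: hidden_size -> rank, with the scaling applied in this step.
    low_rank = [scaling * v for v in matvec(lora_a, x)]
    # Expand: rank -> out_features.
    return matvec(lora_b, low_rank)


x = [1.0, 2.0, 3.0]                            # hidden_size = 3
lora_a = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]    # (rank=2, hidden_size=3)
lora_b = [[1.0, 1.0], [2.0, 0.0], [0.0, 3.0]]  # (out_features=3, rank=2), assumed layout

print(lora_delta(x, lora_a, lora_b, scaling=2.0))  # [6.0, 4.0, 12.0]
```

The intermediate vector has only rank elements, which is why the two small projections are cheaper than materialising the full-size product of B and A.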
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/utils/ops/bgmv_shrink.py
- Lines: 1-149
Signature
@triton.jit
def _bgmv_shrink_kernel(
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    lora_indices,
    scaling,
    xm_stride,
    xk_stride,
    l0_stride,
    lora_k_stride,
    lora_n_stride,
    cm_stride,
    cn_stride,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
):

@torch.inference_mode()
def bgmv_shrink(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    scaling: float = 1.0,
    override_config: Optional[Dict[str, int]] = None,
):
Import
from lorax_server.utils.ops.bgmv_shrink import bgmv_shrink
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| inputs | torch.Tensor | Yes | Input tensor of shape (batch_size, hidden_size), dtype float16/bfloat16. The hidden activations to be projected into low-rank space. Must match lora_a_weights dtype. |
| lora_a_weights | torch.Tensor | Yes | LoRA A weight matrices of shape (lora_num, rank, hidden_size) or (lora_num, 1, rank, hidden_size), dtype float16/bfloat16. The per-adapter down-projection weights. |
| output_tensor | torch.Tensor | Yes | Output tensor of shape (batch_size, rank), modified in-place. Must be contiguous. |
| lora_indices_tensor | torch.Tensor | Yes | Tensor of shape (batch_size,) mapping each request to its LoRA adapter index. A value of -1 means no LoRA is applied. |
| scaling | float | No | Scaling factor applied to the result (default 1.0). Typically the LoRA alpha/rank scaling. |
| override_config | Optional[Dict[str, int]] | No | Override for Triton grid configuration (BLOCK_K, SPLIT_K). Defaults to auto-selected config via get_lora_op_configs. |
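The shape constraints in the table can be summarised as a small checker. This is an illustrative sketch, not code from the repository: check_shrink_shapes is a hypothetical helper that works on shape tuples (as returned by tensor.shape), and it treats the 4-D weight layout as the 3-D one with the singleton axis removed.

```python
from typing import Sequence, Tuple


def check_shrink_shapes(
    inputs_shape: Sequence[int],
    lora_a_shape: Sequence[int],
    output_shape: Sequence[int],
    indices_shape: Sequence[int],
) -> Tuple[int, int, int]:
    """Hypothetical validator; returns (batch_size, rank, hidden_size) or raises."""
    lora_a_shape = tuple(lora_a_shape)
    if len(lora_a_shape) == 4:
        # (lora_num, 1, rank, hidden_size) -> (lora_num, rank, hidden_size)
        if lora_a_shape[1] != 1:
            raise ValueError("4-D lora_a_weights must have size 1 on axis 1")
        lora_a_shape = (lora_a_shape[0],) + lora_a_shape[2:]
    if len(lora_a_shape) != 3:
        raise ValueError("lora_a_weights must be 3-D or 4-D")
    if len(inputs_shape) != 2:
        raise ValueError("inputs must be (batch_size, hidden_size)")

    batch_size, hidden_size = inputs_shape
    _, rank, lora_hidden = lora_a_shape
    if hidden_size != lora_hidden:
        raise ValueError("inputs and lora_a_weights disagree on hidden_size")
    if tuple(output_shape) != (batch_size, rank):
        raise ValueError("output_tensor must be (batch_size, rank)")
    if tuple(indices_shape) != (batch_size,):
        raise ValueError("lora_indices_tensor must be (batch_size,)")
    return batch_size, rank, hidden_size


# Example shapes: 4 requests, 8 adapters of rank 16, hidden size 4096.
print(check_shrink_shapes((4, 4096), (8, 1, 16, 4096), (4, 16), (4,)))
```

Dtype and contiguity requirements from the table are not covered by this shape-only check.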
Outputs
| Name | Type | Description |
|---|---|---|
| output_tensor | torch.Tensor | Modified in-place. Each row output_tensor[i] holds scaling * (lora_a_weights[lora_indices_tensor[i]] @ inputs[i]), the rank-dimensional low-rank projection of inputs[i]. Rows whose index is -1 are skipped. |
Usage Examples
# Called internally by LoRA kernel orchestrator (punica.py)
from lorax_server.utils.ops.bgmv_shrink import bgmv_shrink

# Apply LoRA A (shrink) weights to project hidden activations to rank space
bgmv_shrink(
    inputs=hidden_states,            # (batch_size, hidden_size)
    lora_a_weights=lora_a_weights,   # (num_loras, rank, hidden_size)
    output_tensor=lora_a_output,     # (batch_size, rank), modified in-place
    lora_indices_tensor=indices,     # (batch_size,), -1 for no LoRA
    scaling=alpha / rank,
)