Implementation:Predibase Lorax BGMV Expand Kernel
| Knowledge Sources | |
|---|---|
| Domains | GPU_Kernels, LoRA |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Triton kernel implementing the Batched Gather Matrix-Vector multiply (BGMV) expand operation for applying LoRA B (up-projection) weights during multi-tenant LoRA adapter inference.
Description
The BGMV expand kernel computes a batched matrix-vector product between a low-rank intermediate representation and per-request LoRA B weight matrices, projecting from rank space back to hidden dimension space. Mathematically, for each request i in the batch, it computes output[i] += lora_B[idx[i]] @ input[i], where input[i] has shape (rank,), lora_B[idx[i]] has shape (hidden_size, rank), and idx[i] is the LoRA adapter index for the request. If idx[i] == -1, the request has no adapter and the kernel skips it, leaving output[i] unchanged.
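The per-request update can be illustrated with a dependency-free reference, reading the product as a matrix-vector product with each adapter's B matrix stored as (hidden_size, rank). This is a minimal sketch for checking semantics, not the Triton kernel; the function name `bgmv_expand_reference` and the nested-list representation are assumptions made for illustration.

```python
from typing import List

Matrix = List[List[float]]


def bgmv_expand_reference(
    inputs: Matrix,        # (batch_size, rank)
    lora_b: List[Matrix],  # (lora_num, hidden_size, rank)
    output: Matrix,        # (batch_size, hidden_size), updated in place
    indices: List[int],    # (batch_size,), -1 means "no adapter"
    add_inputs: bool = True,
) -> None:
    """Hypothetical pure-Python reference for the expand step."""
    for i, adapter in enumerate(indices):
        if adapter == -1:
            continue  # no LoRA for this request: the row is left untouched
        for n, weight_row in enumerate(lora_b[adapter]):
            # Dot product of one (rank,) weight row with the (rank,) input.
            acc = sum(w * x for w, x in zip(weight_row, inputs[i]))
            output[i][n] = (output[i][n] + acc) if add_inputs else acc


# Two requests, rank 2, hidden size 3; the second request has no adapter.
inputs = [[1.0, 2.0], [3.0, 4.0]]
lora_b = [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]]
output = [[10.0, 10.0, 10.0], [5.0, 5.0, 5.0]]
bgmv_expand_reference(inputs, lora_b, output, [0, -1])
print(output)  # [[11.0, 12.0, 13.0], [5.0, 5.0, 5.0]]
```

A reference like this is mainly useful in unit tests, where its result can be compared against the kernel output on small inputs.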
The kernel is based on the Punica paper (Chen et al., 2023) for multi-tenant LoRA serving. It uses a GroupGEMV (Group General Matrix-Vector) approach with a configurable SPLIT_N parameter that partitions the output dimension N across thread blocks to improve performance for large hidden sizes. The kernel supports mixed-precision computation with optional type casting when input tensors are float32 but LoRA weights are float16/bfloat16. An ADD_INPUTS flag controls whether results are accumulated into the output tensor or overwrite it.
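To make the SPLIT_N partitioning concrete, the sketch below lists which output columns each split would cover. It assumes the N columns are divided into slices of width ceil(N / SPLIT_N) and that each slice is walked in BLOCK_N-wide tiles; the helper name `split_n_tiles` is hypothetical, and the real kernel's indexing and masking may differ.

```python
from typing import List, Tuple


def split_n_tiles(n: int, split_n: int, block_n: int) -> List[List[Tuple[int, int]]]:
    """Return, for each split index, the [start, end) column tiles it covers."""
    split_len = -(-n // split_n)  # ceil(n / split_n)
    plan = []
    for pid in range(split_n):
        lo = pid * split_len
        hi = min(lo + split_len, n)  # clamp the last slice to N
        tiles = [(s, min(s + block_n, hi)) for s in range(lo, hi, block_n)]
        plan.append(tiles)
    return plan


plan = split_n_tiles(n=4096, split_n=4, block_n=256)
print(len(plan), len(plan[0]), plan[1][0])  # 4 4 (1024, 1280)
```

Under these assumptions, a larger SPLIT_N gives each program fewer tiles to iterate over, which is the intuition behind splitting N for large hidden sizes.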
Usage
This kernel is invoked during LoRA adapter inference in the decode phase (single-token generation) when applying the LoRA B matrix (expand/up-projection). It is called from the Punica kernel orchestrator to transform the low-rank intermediate activations back to the model's hidden dimension. The BGMV variant is used for batched single-token requests where each request maps to at most one token, as opposed to SGMV which handles variable-length sequences.
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/utils/ops/bgmv_expand.py
- Lines: 1-167
Signature
@triton.jit
def _bgmv_expand_kernel(
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    lora_indices,
    xm_stride,
    xk_stride,
    l0_stride,
    lora_k_stride,
    lora_n_stride,
    cm_stride,
    cn_stride,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SPLIT_N: tl.constexpr,
    EVEN_K: tl.constexpr,
    ADD_INPUTS: tl.constexpr,
    CAST_TYPE: tl.constexpr,
):

@torch.inference_mode()
def bgmv_expand(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    add_inputs: bool = True,
    override_config: Optional[Dict[str, int]] = None,
):
Import
from lorax_server.utils.ops.bgmv_expand import bgmv_expand
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| inputs | torch.Tensor | Yes | Input tensor of shape (batch_size, rank), dtype float16/bfloat16/float32. The low-rank intermediate activations from the LoRA A projection. |
| lora_b_weights | torch.Tensor | Yes | LoRA B weight matrices of shape (lora_num, hidden_size, rank) or (lora_num, 1, hidden_size, rank), dtype float16/bfloat16. The per-adapter up-projection weights. |
| output_tensor | torch.Tensor | Yes | Output tensor of shape (batch_size, hidden_size), modified in-place. Must be contiguous. |
| lora_indices_tensor | torch.Tensor | Yes | Tensor of shape (batch_size,) mapping each request to its LoRA adapter index. A value of -1 means no LoRA is applied. |
| add_inputs | bool | No | Whether to add results to existing output values (default True) or overwrite them. |
| override_config | Optional[Dict[str, int]] | No | Override for Triton grid configuration (BLOCK_N, SPLIT_N). Defaults to auto-selected config via get_lora_op_configs. |
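The shape and dtype constraints in the table above can be expressed as a small pre-flight check. The helper below is a hypothetical sketch that works on plain shape tuples and dtype names rather than tensors; it is not part of lorax_server, and the checks the wrapper actually performs may differ.

```python
from typing import Sequence, Tuple

ACTIVATION_DTYPES = {"float16", "bfloat16", "float32"}
WEIGHT_DTYPES = {"float16", "bfloat16"}


def check_bgmv_expand_args(
    inputs_shape: Sequence[int],
    inputs_dtype: str,
    weights_shape: Sequence[int],
    weights_dtype: str,
    output_shape: Sequence[int],
    indices_shape: Sequence[int],
) -> Tuple[int, int, int, bool]:
    """Validate the contract; return (batch_size, hidden_size, rank, needs_cast)."""
    if inputs_dtype not in ACTIVATION_DTYPES:
        raise TypeError(f"unsupported inputs dtype: {inputs_dtype}")
    if weights_dtype not in WEIGHT_DTYPES:
        raise TypeError(f"unsupported weights dtype: {weights_dtype}")
    weights_shape = tuple(weights_shape)
    if len(weights_shape) == 4:
        # (lora_num, 1, hidden_size, rank) -> (lora_num, hidden_size, rank)
        if weights_shape[1] != 1:
            raise ValueError("4-D weights must have size 1 in dimension 1")
        weights_shape = (weights_shape[0],) + weights_shape[2:]
    if len(weights_shape) != 3 or len(inputs_shape) != 2:
        raise ValueError("expected 2-D inputs and 3-D or 4-D weights")
    batch_size, rank = inputs_shape
    _, hidden_size, weight_rank = weights_shape
    if weight_rank != rank:
        raise ValueError("inputs and weights disagree on rank")
    if tuple(output_shape) != (batch_size, hidden_size):
        raise ValueError("output must be (batch_size, hidden_size)")
    if tuple(indices_shape) != (batch_size,):
        raise ValueError("indices must be (batch_size,)")
    # Weights are always half precision here, so float32 inputs imply a cast.
    needs_cast = inputs_dtype == "float32"
    return batch_size, hidden_size, rank, needs_cast


result = check_bgmv_expand_args(
    (8, 16), "float32", (4, 1, 4096, 16), "bfloat16", (8, 4096), (8,)
)
print(result)  # (8, 4096, 16, True)
```

The returned `needs_cast` flag mirrors the mixed-precision case described earlier: float32 inputs combined with float16/bfloat16 LoRA weights.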
Outputs
| Name | Type | Description |
|---|---|---|
| output_tensor | torch.Tensor | Modified in-place. Each row output_tensor[i] is updated with the result of inputs[i] multiplied by lora_b_weights[lora_indices[i]]. |
Usage Examples
# Called internally by LoRA kernel orchestrator (punica.py)
from lorax_server.utils.ops.bgmv_expand import bgmv_expand

# Apply LoRA B (expand) weights to low-rank activations
bgmv_expand(
    inputs=lora_a_output,            # (batch_size, rank)
    lora_b_weights=lora_b_weights,   # (num_loras, hidden_size, rank)
    output_tensor=output,            # (batch_size, hidden_size), modified in-place
    lora_indices_tensor=indices,     # (batch_size,), -1 for no LoRA
    add_inputs=True,
)