Implementation:Predibase Lorax BGMV Expand Slice Kernel
| Knowledge Sources | |
|---|---|
| Domains | GPU_Kernels, LoRA |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Triton kernel implementing the Batched Gather Matrix-Vector multiply (BGMV) expand operation with slice offset support, enabling LoRA B weight application to a specific slice of the output tensor during multi-tenant LoRA adapter inference.
Description
The BGMV expand slice kernel extends the standard BGMV expand operation with a slice_offset parameter that controls where in the output tensor the results are written. For each request i in the batch, it computes output[i][slice_offset:slice_offset+slice_size] += lora_B[idx[i]] @ input[i], where lora_B[idx[i]] has shape (slice_size, rank) and idx[i] is the LoRA adapter index of request i. Rows with idx[i] == -1 are skipped. When add_inputs is False, the slice is overwritten (=) instead of accumulated into (+=).
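The update rule can be pinned down with a dependency-free reference in plain Python. This is a sketch for checking semantics, not the Triton implementation. The function name bgmv_expand_slice_reference is hypothetical, nested lists stand in for tensors, and lora_b_weights is assumed to be in its 3-D (num_loras, slice_size, rank) layout.

```python
from typing import List

Matrix = List[List[float]]


def bgmv_expand_slice_reference(
    inputs: Matrix,                # (batch_size, rank)
    lora_b_weights: List[Matrix],  # (num_loras, slice_size, rank)
    output: Matrix,                # (batch_size, total_width), updated in place
    lora_indices: List[int],       # (batch_size,); -1 means "no LoRA"
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = True,
) -> None:
    for i, idx in enumerate(lora_indices):
        if idx == -1:
            continue  # request i has no adapter: its row is left untouched
        w = lora_b_weights[idx]  # (slice_size, rank) for this request's adapter
        assert len(w) == slice_size
        for n in range(slice_size):
            # Row n of B dotted with the rank-sized input vector
            acc = sum(w[n][k] * inputs[i][k] for k in range(len(inputs[i])))
            col = slice_offset + n  # shift into the slice of the fused output
            output[i][col] = output[i][col] + acc if add_inputs else acc


# One adapter whose B is the 2x2 identity; request 1 has no adapter.
x = [[1.0, 2.0], [3.0, 4.0]]
B = [[[1.0, 0.0], [0.0, 1.0]]]
out = [[10.0] * 4, [10.0] * 4]
bgmv_expand_slice_reference(x, B, out, [0, -1], slice_offset=1, slice_size=2)
print(out)  # [[10.0, 11.0, 12.0, 10.0], [10.0, 10.0, 10.0, 10.0]]
```

Only columns 1 and 2 of row 0 change; row 1 is skipped because its index is -1.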
This sliced variant is needed for models in which a single fused linear layer produces several logical outputs (e.g., the Q, K and V projections, or the gate and up projections of an MLP). Each slice has its own LoRA B matrix that covers only that slice's output columns, and this kernel writes the result at the slice's offset within the combined output tensor. Like the base BGMV expand kernel, it uses a GroupGEMV scheme in which SPLIT_N splits the output dimension across parallel program instances, supports mixed precision (float32 activations are cast to the float16/bfloat16 weight type when CAST_TYPE is set), and is based on the Punica paper (Chen et al., 2023).
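The SPLIT_N work split can be sketched as a plain-Python enumeration of which output columns each program instance writes. This assumes the launch layout of the upstream vLLM kernel with the same signature: a grid of (SPLIT_N, batch_size) programs, each handling one batch row and a contiguous block of cdiv(N, SPLIT_N) slice columns, walked in BLOCK_N-wide tiles. The helper split_n_tiles is hypothetical and assumes N is divisible by SPLIT_N. Column indices are relative to the slice; the kernel adds slice_offset when addressing the output.

```python
from typing import List, Tuple


def split_n_tiles(n: int, split_n: int, block_n: int) -> List[List[Tuple[int, int]]]:
    """For one batch row, return the [start, end) column tiles of the slice
    that each of the SPLIT_N programs writes, in BLOCK_N-wide steps."""
    if n % split_n != 0:
        raise ValueError("this sketch assumes N is divisible by SPLIT_N")
    split_len = n // split_n  # equals cdiv(N, SPLIT_N) when N divides evenly
    programs = []
    for pid in range(split_n):
        base = pid * split_len  # first slice column owned by program pid
        tiles = [
            (base + start, base + min(start + block_n, split_len))
            for start in range(0, split_len, block_n)
        ]
        programs.append(tiles)
    return programs


# N = 8 slice columns, SPLIT_N = 2, BLOCK_N = 3
print(split_n_tiles(8, 2, 3))  # [[(0, 3), (3, 4)], [(4, 7), (7, 8)]]
```

A larger SPLIT_N gives more independent programs per row, which helps keep the GPU busy when the batch is small but the slice is wide.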
Usage
This kernel is invoked during LoRA adapter inference in the decode phase when applying LoRA B matrices to sliced output tensors. It is used for fused linear layers where multiple weight matrices (e.g., query, key, value) share a single output buffer, and each LoRA adapter's contribution must be written to the correct slice offset within that buffer.
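For a fused layer, each logical output's slice_offset is the sum of the sizes of the slices packed before it. The hypothetical helper below computes (slice_offset, slice_size) pairs, assuming the slices are packed contiguously in Q, K, V order; the head counts are illustrative only.

```python
from itertools import accumulate
from typing import List, Tuple


def slice_layout(sizes: List[int]) -> List[Tuple[int, int]]:
    """Return (slice_offset, slice_size) for each logical output packed
    side by side in one fused output buffer."""
    offsets = [0] + list(accumulate(sizes))[:-1]  # exclusive prefix sums
    return list(zip(offsets, sizes))


# Hypothetical GQA attention: 32 query heads, 8 KV heads, head_dim 128
head_dim = 128
q_size, kv_size = 32 * head_dim, 8 * head_dim
layout = slice_layout([q_size, kv_size, kv_size])
print(layout)  # [(0, 4096), (4096, 1024), (5120, 1024)]
```

Each pair would be passed as slice_offset and slice_size in one bgmv_expand_slice call, together with that slice's LoRA B weights.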
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/utils/ops/bgmv_expand_slice.py
- Lines: 1-179
Signature
```python
@triton.jit
def _bgmv_expand_slice_kernel(
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    lora_indices,
    xm_stride,
    xk_stride,
    l0_stride,
    lora_k_stride,
    lora_n_stride,
    cm_stride,
    cn_stride,
    slice_offset,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SPLIT_N: tl.constexpr,
    EVEN_K: tl.constexpr,
    ADD_INPUTS: tl.constexpr,
    CAST_TYPE: tl.constexpr,
):
    ...


@torch.inference_mode()
def bgmv_expand_slice(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = True,
    override_config: Optional[Dict[str, int]] = None,
):
    ...
```
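The constexpr flags in the kernel signature are derived from tensor shapes and dtypes before launch. The sketch below shows one plausible derivation, assuming the wrapper follows the upstream vLLM wrapper: the whole rank dimension K is loaded as a single BLOCK_K-wide tile, EVEN_K lets the kernel skip load masking when that tile is exactly full, and CAST_TYPE is set when float32 activations meet float16/bfloat16 weights. BLOCK_N and SPLIT_N are not derived here; they come from get_lora_op_configs or override_config. The function names and the string encoding of dtypes are hypothetical.

```python
def next_power_of_2(x: int) -> int:
    """Smallest power of two >= x, for x >= 1."""
    return 1 << (x - 1).bit_length()


def derive_constexprs(rank: int, input_dtype: str, weight_dtype: str, add_inputs: bool) -> dict:
    # The whole rank dimension (K) is loaded as one tile of BLOCK_K lanes.
    block_k = next_power_of_2(rank)
    return {
        "BLOCK_K": block_k,
        # Full tile: loads along K need no mask.
        "EVEN_K": rank % block_k == 0,
        "ADD_INPUTS": add_inputs,
        # float32 activations are cast to the half-precision weight dtype.
        "CAST_TYPE": input_dtype == "float32" and weight_dtype in ("float16", "bfloat16"),
    }


print(derive_constexprs(16, "float16", "float16", True))
# {'BLOCK_K': 16, 'EVEN_K': True, 'ADD_INPUTS': True, 'CAST_TYPE': False}
print(derive_constexprs(24, "float32", "bfloat16", False))
# {'BLOCK_K': 32, 'EVEN_K': False, 'ADD_INPUTS': False, 'CAST_TYPE': True}
```

Because BLOCK_K is the next power of two at or above K, EVEN_K is true exactly when the LoRA rank is itself a power of two.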
Import
```python
from lorax_server.utils.ops.bgmv_expand_slice import bgmv_expand_slice
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| inputs | torch.Tensor | Yes | Input tensor of shape (batch_size, rank), dtype float16/bfloat16/float32. The low-rank intermediate activations from the LoRA A projection. |
| lora_b_weights | torch.Tensor | Yes | LoRA B weight matrices of shape (lora_num, slice_size, rank) or (lora_num, 1, slice_size, rank), dtype float16/bfloat16. The per-adapter up-projection weights for the specific slice. |
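| output_tensor | torch.Tensor | Yes | Output tensor of shape (batch_size, full_hidden_size), where full_hidden_size is the combined width of all slices of the fused layer. Modified in place, but only in columns slice_offset through slice_offset + slice_size - 1, so slice_offset + slice_size must not exceed full_hidden_size. Must be contiguous. |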
| lora_indices_tensor | torch.Tensor | Yes | Tensor of shape (batch_size,) mapping each request to its LoRA adapter index. A value of -1 means no LoRA is applied. |
| slice_offset | int | Yes | Starting column offset in the output tensor where results are written. |
| slice_size | int | Yes | Size of the output slice. Must equal lora_b_weights.size(-2). |
| add_inputs | bool | No | Whether to add results to existing output values (default True) or overwrite them. |
| override_config | Optional[Dict[str, int]] | No | Override for Triton grid configuration. Defaults to auto-selected config via get_lora_op_configs. |
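The shape rules in the table can be expressed as a single validation function. The helper below is hypothetical and works on plain shape tuples; the real wrapper may enforce only some of these conditions. Treating the 4-D layout by dropping its singleton second axis is an assumption about how that layout is interpreted.

```python
from typing import Tuple


def check_expand_slice_shapes(
    inputs_shape: Tuple[int, ...],
    lora_b_shape: Tuple[int, ...],
    output_shape: Tuple[int, ...],
    num_indices: int,
    slice_offset: int,
    slice_size: int,
) -> Tuple[int, int, int]:
    """Validate shapes per the I/O contract; return (num_loras, slice_size, rank)."""
    if len(lora_b_shape) == 4:
        # (lora_num, 1, slice_size, rank): drop the singleton axis
        if lora_b_shape[1] != 1:
            raise ValueError("4-D lora_b_weights must have size 1 in dim 1")
        lora_b_shape = (lora_b_shape[0],) + tuple(lora_b_shape[2:])
    if len(lora_b_shape) != 3:
        raise ValueError("lora_b_weights must be 3-D or 4-D")
    num_loras, n, rank = lora_b_shape
    if inputs_shape[1] != rank:
        raise ValueError("inputs rank does not match lora_b_weights rank")
    if n != slice_size:
        raise ValueError("slice_size must equal lora_b_weights.size(-2)")
    if not (inputs_shape[0] == output_shape[0] == num_indices):
        raise ValueError("batch dimensions disagree")
    if slice_offset < 0 or slice_offset + slice_size > output_shape[1]:
        raise ValueError("slice lies outside output_tensor")
    return num_loras, n, rank


# Q slice of a (4, 6144) fused QKV output, rank-16 adapters in 4-D layout
print(check_expand_slice_shapes((4, 16), (3, 1, 4096, 16), (4, 6144), 4, 0, 4096))
# (3, 4096, 16)
```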
Outputs
| Name | Type | Description |
|---|---|---|
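| output_tensor | torch.Tensor | Modified in place. For each row i with lora_indices_tensor[i] != -1, output_tensor[i][slice_offset:slice_offset+slice_size] is incremented by lora_b_weights[lora_indices_tensor[i]] @ inputs[i] (or overwritten with it when add_inputs=False). Rows whose index is -1, and all columns outside the slice, are left unchanged. |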
Usage Examples
```python
# Called internally by the LoRA kernel orchestrator (punica.py)
from lorax_server.utils.ops.bgmv_expand_slice import bgmv_expand_slice

# Apply LoRA B weights to the Q slice of a fused QKV output
bgmv_expand_slice(
    inputs=lora_a_output,          # (batch_size, rank)
    lora_b_weights=lora_b_q,       # (num_loras, q_size, rank)
    output_tensor=qkv_output,      # (batch_size, q_size + 2 * kv_size), modified in-place
    lora_indices_tensor=indices,   # (batch_size,)
    slice_offset=0,                # Q occupies the first q_size columns
    slice_size=q_size,
    add_inputs=True,
)
```