

Implementation:Predibase Lorax BGMV Expand Slice Kernel

From Leeroopedia


Knowledge Sources
Domains: GPU_Kernels, LoRA
Last Updated: 2026-02-08 00:00 GMT

Overview

Triton kernel implementing the Batched Gather Matrix-Vector multiply (BGMV) expand operation with slice offset support, enabling LoRA B weight application to a specific slice of the output tensor during multi-tenant LoRA adapter inference.

Description

The BGMV expand slice kernel extends the standard BGMV expand operation with a slice_offset parameter that controls where in the output tensor the results are written. For each request i in the batch, with lora_B[idx[i]] of shape (slice_size, rank), it computes output[i][slice_offset:slice_offset+slice_size] += lora_B[idx[i]] @ input[i], where idx[i] is the request's LoRA adapter index. Rows with idx[i] == -1 are skipped, and when add_inputs is False the slice is overwritten instead of accumulated.
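The update rule can be checked against a plain reference. The following is a minimal NumPy sketch of the per-row semantics, not the Triton kernel itself; the function name bgmv_expand_slice_ref is hypothetical, and NumPy arrays stand in for the torch tensors.

```python
import numpy as np


def bgmv_expand_slice_ref(inputs, lora_b, output, indices,
                          slice_offset, slice_size, add_inputs=True):
    """NumPy reference for the sliced BGMV expand update (hypothetical helper).

    inputs:  (batch_size, rank)
    lora_b:  (num_loras, slice_size, rank)
    output:  (batch_size, hidden_size), updated in place
    indices: (batch_size,), where -1 means "no LoRA for this row"
    """
    if lora_b.shape[-2] != slice_size:
        raise ValueError("slice_size must equal lora_b.shape[-2]")
    cols = slice(slice_offset, slice_offset + slice_size)
    for i, idx in enumerate(indices):
        if idx == -1:
            continue  # no adapter: this row is left untouched
        delta = lora_b[idx] @ inputs[i]  # (slice_size, rank) @ (rank,) -> (slice_size,)
        if add_inputs:
            output[i, cols] += delta     # accumulate into the slice
        else:
            output[i, cols] = delta      # overwrite the slice
    return output


rng = np.random.default_rng(0)
x = rng.standard_normal((3, 4))          # batch_size=3, rank=4
b = rng.standard_normal((2, 5, 4))       # 2 adapters, slice_size=5
out = np.zeros((3, 12))                  # hidden_size=12
bgmv_expand_slice_ref(x, b, out, np.array([1, -1, 0]), slice_offset=5, slice_size=5)
```

Only columns 5 through 9 of rows 0 and 2 change; row 1 (index -1) and every column outside the slice stay zero.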

This sliced variant is needed for models where a single fused linear layer produces multiple logical outputs (e.g., QKV projections, or gate/up projections in MLP layers). Each LoRA B matrix passed to the kernel covers only one slice of the full output, and the kernel writes its result at the matching offset within the combined output tensor. Like the base BGMV expand kernel, it uses GroupGEMV with the output dimension split across SPLIT_N parallel programs, supports mixed-precision type casting (CAST_TYPE), and is based on the Punica paper (Chen et al., 2023).
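The SPLIT_N partitioning can be pictured with a loop-level emulation. This is a hypothetical NumPy sketch, not the Triton source: it assumes a launch grid of (SPLIT_N, batch_size), that each program owns a contiguous chunk of ceil(N / SPLIT_N) output columns processed in BLOCK_N-wide tiles, and that the whole rank dimension K fits in one tile. Tiles are clipped to N here for safety.

```python
import numpy as np


def cdiv(a, b):
    """Ceiling division, like triton.cdiv."""
    return -(-a // b)


def emulate_split_n(inputs, lora_b, output, indices, slice_offset,
                    block_n, split_n, add_inputs=True):
    """Hypothetical loop-level emulation of a SPLIT_N GroupGEMV launch."""
    n_total = lora_b.shape[-2]               # N == slice_size
    split_len = cdiv(n_total, split_n)       # output columns owned by one program
    for pid_sn in range(split_n):            # assumed grid axis 0
        for row in range(inputs.shape[0]):   # assumed grid axis 1: one request each
            idx = indices[row]
            if idx == -1:
                continue
            start = pid_sn * split_len
            for tile in range(0, split_len, block_n):
                lo = start + tile
                # clip the tile to this program's chunk and to N
                hi = min(lo + block_n, start + split_len, n_total)
                if lo >= hi:
                    break
                acc = lora_b[idx, lo:hi, :] @ inputs[row]  # partial GEMV, shape (hi - lo,)
                cols = slice(slice_offset + lo, slice_offset + hi)
                if add_inputs:
                    output[row, cols] += acc
                else:
                    output[row, cols] = acc
    return output


rng = np.random.default_rng(1)
x = rng.standard_normal((2, 8))          # batch_size=2, rank K=8
b = rng.standard_normal((2, 10, 8))      # N=10 is not a multiple of SPLIT_N=3
out = np.zeros((2, 16))
emulate_split_n(x, b, out, np.array([0, 1]), slice_offset=3, block_n=3, split_n=3)
```

Because every output column belongs to exactly one (program, tile) pair, the programs never write the same element, which is what lets them run in parallel without atomics.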

Usage

This kernel is invoked during LoRA adapter inference in the decode phase when applying LoRA B matrices to sliced output tensors. It is used for fused linear layers where multiple weight matrices (e.g., query, key, value) share a single output buffer, and each LoRA adapter's contribution must be written to the correct slice offset within that buffer.
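For a fused QKV buffer, each projection's slice_offset is the running sum of the preceding slice widths. The sketch below is hypothetical: it uses made-up sizes with K and V narrower than Q (as in grouped-query attention), assumes each projection has its own LoRA A output and stacked B matrix, and replaces the kernel with a simplified NumPy stand-in named apply_slice.

```python
import numpy as np
from itertools import accumulate


def apply_slice(inputs, lora_b, output, indices, slice_offset):
    """Simplified stand-in for bgmv_expand_slice(..., add_inputs=True)."""
    size = lora_b.shape[-2]
    for i, idx in enumerate(indices):
        if idx != -1:
            output[i, slice_offset:slice_offset + size] += lora_b[idx] @ inputs[i]
    return output


batch, rank, num_loras = 3, 4, 2
slice_sizes = [8, 4, 4]                        # hypothetical q_size, k_size, v_size
offsets = [0, *accumulate(slice_sizes)][:-1]   # [0, 8, 12]

rng = np.random.default_rng(2)
# One LoRA A output and one stacked B matrix per projection (Q, K, V)
lora_a_outs = [rng.standard_normal((batch, rank)) for _ in slice_sizes]
lora_bs = [rng.standard_normal((num_loras, s, rank)) for s in slice_sizes]
indices = np.array([0, -1, 1])                 # request 1 uses no adapter

qkv = np.zeros((batch, sum(slice_sizes)))      # fused output buffer
for x, b, off in zip(lora_a_outs, lora_bs, offsets):
    apply_slice(x, b, qkv, indices, off)
```

Each row of qkv ends up as the concatenation of the Q, K, and V LoRA deltas for that request's adapter.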

Code Reference

Source Location

  • Repository: Predibase_Lorax
  • File: server/lorax_server/utils/ops/bgmv_expand_slice.py
  • Lines: 1-179

Signature

from typing import Dict, Optional

import torch
import triton
import triton.language as tl


@triton.jit
def _bgmv_expand_slice_kernel(
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    lora_indices,
    xm_stride,
    xk_stride,
    l0_stride,
    lora_k_stride,
    lora_n_stride,
    cm_stride,
    cn_stride,
    slice_offset,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SPLIT_N: tl.constexpr,
    EVEN_K: tl.constexpr,
    ADD_INPUTS: tl.constexpr,
    CAST_TYPE: tl.constexpr,
):
    ...  # kernel body omitted


@torch.inference_mode()
def bgmv_expand_slice(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = True,
    override_config: Optional[Dict[str, int]] = None,
):
    ...  # wrapper body omitted
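The stride parameters let the kernel address rows and columns in flat memory. A minimal NumPy sketch of that addressing for the output tensor follows, assuming the usual base + index × stride scheme with element strides (what torch's Tensor.stride() returns) and that slice_offset is simply added to the column index; the helper names elem_strides and out_offset are hypothetical.

```python
import numpy as np


def elem_strides(a):
    """Strides in elements, as torch.Tensor.stride() reports (NumPy uses bytes)."""
    return tuple(s // a.itemsize for s in a.strides)


batch_size, hidden = 3, 10
out = np.arange(batch_size * hidden, dtype=np.float32).reshape(batch_size, hidden)
cm_stride, cn_stride = elem_strides(out)   # (10, 1) for a contiguous row-major buffer
flat = out.reshape(-1)                     # same memory viewed as a 1-D "pointer"


def out_offset(row, n, slice_offset):
    """Flat offset of output[row, slice_offset + n] from the stride arguments."""
    return row * cm_stride + (slice_offset + n) * cn_stride


# Element n=1 of the slice starting at column 6 in row 2 is output[2, 7]
value = flat[out_offset(2, 1, slice_offset=6)]
```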

Import

from lorax_server.utils.ops.bgmv_expand_slice import bgmv_expand_slice

I/O Contract

Inputs

  • inputs (torch.Tensor, required): Input tensor of shape (batch_size, rank), dtype float16/bfloat16/float32. The low-rank intermediate activations from the LoRA A projection.
  • lora_b_weights (torch.Tensor, required): LoRA B weight matrices of shape (lora_num, slice_size, rank) or (lora_num, 1, slice_size, rank), dtype float16/bfloat16. The per-adapter up-projection weights for this slice.
  • output_tensor (torch.Tensor, required): Output tensor of shape (batch_size, full_hidden_size), modified in place within the designated slice. Must be contiguous.
  • lora_indices_tensor (torch.Tensor, required): Tensor of shape (batch_size,) mapping each request to its LoRA adapter index. A value of -1 means no LoRA is applied.
  • slice_offset (int, required): Starting column in the output tensor where results are written.
  • slice_size (int, required): Width of the output slice. Must equal lora_b_weights.size(-2).
  • add_inputs (bool, optional): Whether to add results to the existing output values (default True) or overwrite them.
  • override_config (Optional[Dict[str, int]], optional): Override for the Triton launch configuration (for example BLOCK_N and SPLIT_N). Defaults to a configuration auto-selected via get_lora_op_configs.
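The input constraints listed above can be checked up front. The following is a hypothetical NumPy sketch of such pre-flight checks; check_expand_slice_args is not part of LoRAX, and the bounds check on slice_offset + slice_size is an assumption implied by, but not stated in, the contract.

```python
import numpy as np


def _require(cond, msg):
    if not cond:
        raise ValueError(msg)


def check_expand_slice_args(inputs, lora_b_weights, output_tensor,
                            lora_indices_tensor, slice_offset, slice_size):
    """Hypothetical pre-flight checks; returns 3-D (lora_num, slice_size, rank) weights."""
    if lora_b_weights.ndim == 4:  # (lora_num, 1, slice_size, rank): drop the size-1 axis
        _require(lora_b_weights.shape[1] == 1, "4-D lora_b_weights must have size 1 at dim 1")
        lora_b_weights = lora_b_weights[:, 0]
    _require(lora_b_weights.ndim == 3, "lora_b_weights must be 3-D or 4-D")
    _require(slice_size == lora_b_weights.shape[-2],
             "slice_size must equal lora_b_weights.size(-2)")
    _require(inputs.shape[1] == lora_b_weights.shape[-1],
             "rank mismatch between inputs and lora_b_weights")
    _require(lora_indices_tensor.shape == (inputs.shape[0],),
             "need one adapter index per request")
    _require(bool(np.all((lora_indices_tensor >= -1)
                         & (lora_indices_tensor < lora_b_weights.shape[0]))),
             "indices must be -1 or a valid adapter id")
    _require(output_tensor.shape[0] == inputs.shape[0], "batch size mismatch")
    _require(output_tensor.flags["C_CONTIGUOUS"], "output_tensor must be contiguous")
    # Assumption: the slice must fit inside the output's column range
    _require(0 <= slice_offset and slice_offset + slice_size <= output_tensor.shape[1],
             "slice does not fit in output_tensor")
    return lora_b_weights


w = check_expand_slice_args(np.zeros((2, 4)), np.zeros((3, 1, 6, 4)),
                            np.zeros((2, 20)), np.array([0, -1]),
                            slice_offset=6, slice_size=6)
```

Raising ValueError rather than using bare assert keeps the checks active under python -O.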

Outputs

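  • output_tensor (torch.Tensor): Modified in place. For each row i with lora_indices_tensor[i] != -1, output_tensor[i][slice_offset:slice_offset+slice_size] receives lora_b_weights[lora_indices_tensor[i]] @ inputs[i], added to the existing values if add_inputs is True and overwriting them otherwise. Rows with index -1 and all columns outside the slice are left unchanged.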

Usage Examples

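# Called internally by the LoRA kernel orchestrator (punica.py); standalone use shown here
import torch
from lorax_server.utils.ops.bgmv_expand_slice import bgmv_expand_slice

batch_size, rank, num_loras = 4, 16, 2
q_size, kv_size = 4096, 1024                    # illustrative sizes
device, dtype = "cuda", torch.float16

lora_a_output = torch.randn(batch_size, rank, device=device, dtype=dtype)
lora_b_q = torch.randn(num_loras, q_size, rank, device=device, dtype=dtype)
qkv_output = torch.zeros(batch_size, q_size + 2 * kv_size, device=device, dtype=dtype)
indices = torch.tensor([0, 1, -1, 0], device=device)  # -1: no LoRA for that request

# Apply LoRA B weights to the Q slice of the fused QKV output
bgmv_expand_slice(
    inputs=lora_a_output,           # (batch_size, rank)
    lora_b_weights=lora_b_q,        # (num_loras, q_size, rank)
    output_tensor=qkv_output,       # (batch_size, q_size + 2*kv_size), modified in place
    lora_indices_tensor=indices,    # (batch_size,)
    slice_offset=0,                 # Q portion starts at column 0
    slice_size=q_size,
    add_inputs=True,
)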
