Implementation:Predibase Lorax BGMV Expand Kernel

From Leeroopedia


Knowledge Sources
Domains GPU_Kernels, LoRA
Last Updated 2026-02-08 00:00 GMT

Overview

Triton kernel implementing the Batched Gather Matrix-Vector multiply (BGMV) expand operation for applying LoRA B (up-projection) weights during multi-tenant LoRA adapter inference.

Description

The BGMV expand kernel computes a batched matrix-vector product between a low-rank intermediate representation and per-request LoRA B weight matrices, projecting from rank space back to hidden dimension space. Mathematically, for each request i in the batch, it computes output[i] += lora_B[idx[i]] @ input[i], where lora_B[idx[i]] has shape (hidden_size, rank), input[i] has shape (rank,), and idx[i] is the LoRA adapter index for the request. If idx[i] == -1, the kernel skips the computation for that request.
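
The per-request update described above can be sketched in plain Python. This is a minimal reference for the semantics only, not the Triton kernel: the function name bgmv_expand_reference and the nested-list layout are hypothetical, and leaving skipped rows untouched is an assumption based on "skips computation".

```python
from typing import List, Sequence


def bgmv_expand_reference(
    inputs: Sequence[Sequence[float]],            # (batch_size, rank)
    lora_b: Sequence[Sequence[Sequence[float]]],  # (lora_num, hidden_size, rank)
    output: List[List[float]],                    # (batch_size, hidden_size), mutated in place
    lora_indices: Sequence[int],                  # (batch_size,), -1 means "no LoRA"
    add_inputs: bool = True,
) -> None:
    """Hypothetical pure-Python reference for the BGMV expand semantics."""
    for i, adapter in enumerate(lora_indices):
        if adapter == -1:
            continue  # assumption: rows without an adapter are left untouched
        for n, weight_row in enumerate(lora_b[adapter]):
            # Dot product of one (rank,) weight row with the (rank,) input row.
            acc = sum(w * x for w, x in zip(weight_row, inputs[i]))
            output[i][n] = output[i][n] + acc if add_inputs else acc


# One adapter with hidden_size=3 and rank=2; the second request has no adapter.
x = [[1.0, 2.0], [3.0, 4.0]]
b = [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]]
out = [[10.0, 10.0, 10.0], [10.0, 10.0, 10.0]]
bgmv_expand_reference(x, b, out, [0, -1], add_inputs=True)
print(out)  # [[11.0, 12.0, 13.0], [10.0, 10.0, 10.0]]
```

With add_inputs=False the first row would instead be overwritten with [1.0, 2.0, 3.0].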

The kernel is based on the Punica paper (Chen et al., 2023) for multi-tenant LoRA serving. It uses a GroupGEMV (Group General Matrix-Vector) approach with a configurable SPLIT_N parameter that partitions the output dimension N across thread blocks to improve performance for large hidden sizes. The kernel supports mixed-precision computation with optional type casting when input tensors are float32 but LoRA weights are float16/bfloat16. An ADD_INPUTS flag controls whether results are accumulated into the output tensor or overwrite it.
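
To illustrate how a SPLIT_N-style partition divides the output dimension, the sketch below lists which output columns each of the SPLIT_N programs would cover, in BLOCK_N-sized tiles. It assumes ceiling division of N by SPLIT_N and clips tiles at N; the helper names cdiv and split_n_tiles are hypothetical, the sizes are illustrative rather than tuned defaults, and the real kernel's masking may differ.

```python
from typing import List, Tuple


def cdiv(a: int, b: int) -> int:
    """Ceiling division for positive integers."""
    return (a + b - 1) // b


def split_n_tiles(n: int, split_n: int, block_n: int) -> List[List[Tuple[int, int]]]:
    """For each of the split_n programs, list the [start, end) column tiles it covers.

    Assumption: each program owns a contiguous slice of cdiv(n, split_n) columns.
    """
    split_len = cdiv(n, split_n)
    plan = []
    for pid in range(split_n):
        lo = pid * split_len
        hi = min(lo + split_len, n)  # clip the last slice at N
        plan.append([(s, min(s + block_n, hi)) for s in range(lo, hi, block_n)])
    return plan


# Illustrative sizes only: hidden size 4096 split over 8 programs, 256 columns per tile.
plan = split_n_tiles(n=4096, split_n=8, block_n=256)
print(plan[0])   # [(0, 256), (256, 512)]
print(plan[-1])  # [(3584, 3840), (3840, 4096)]
```

Each program handles a shorter slice of the output row, so a large hidden size can be spread over more thread blocks instead of being serialized in one.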

Usage

This kernel is invoked during LoRA adapter inference in the decode phase (single-token generation) when applying the LoRA B matrix (expand/up-projection). It is called from the Punica kernel orchestrator to transform the low-rank intermediate activations back to the model's hidden dimension. The BGMV variant is used for batched single-token requests where each request maps to at most one token, as opposed to SGMV which handles variable-length sequences.

Code Reference

Source Location

  • Repository: Predibase_Lorax
  • File: server/lorax_server/utils/ops/bgmv_expand.py
  • Lines: 1-167

Signature

@triton.jit
def _bgmv_expand_kernel(
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    lora_indices,
    xm_stride,
    xk_stride,
    l0_stride,
    lora_k_stride,
    lora_n_stride,
    cm_stride,
    cn_stride,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SPLIT_N: tl.constexpr,
    EVEN_K: tl.constexpr,
    ADD_INPUTS: tl.constexpr,
    CAST_TYPE: tl.constexpr,
):

@torch.inference_mode()
def bgmv_expand(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    add_inputs: bool = True,
    override_config: Optional[Dict[str, int]] = None,
):

Import

from lorax_server.utils.ops.bgmv_expand import bgmv_expand

I/O Contract

Inputs

Name | Type | Required | Description
inputs | torch.Tensor | Yes | Input tensor of shape (batch_size, rank), dtype float16/bfloat16/float32. The low-rank intermediate activations from the LoRA A projection.
lora_b_weights | torch.Tensor | Yes | LoRA B weight matrices of shape (lora_num, hidden_size, rank) or (lora_num, 1, hidden_size, rank), dtype float16/bfloat16. The per-adapter up-projection weights.
output_tensor | torch.Tensor | Yes | Output tensor of shape (batch_size, hidden_size), modified in-place. Must be contiguous.
lora_indices_tensor | torch.Tensor | Yes | Tensor of shape (batch_size,) mapping each request to its LoRA adapter index. A value of -1 means no LoRA is applied.
add_inputs | bool | No | Whether to add results to existing output values (default True) or overwrite them.
override_config | Optional[Dict[str, int]] | No | Override for Triton grid configuration (BLOCK_N, SPLIT_N). Defaults to auto-selected config via get_lora_op_configs.

Outputs

Name | Type | Description
output_tensor | torch.Tensor | Modified in-place. Each row output_tensor[i] is updated with lora_b_weights[lora_indices_tensor[i]] @ inputs[i], either added to the existing values or overwriting them depending on add_inputs.
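
The shape requirements in the I/O contract can be checked before launching the kernel. The helper below is a hypothetical sketch that works on plain shape tuples (for example tuple(tensor.shape)); it is not part of lorax_server, and it assumes the 4-D weight layout is accepted only when its second dimension is 1, as the table states.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def check_bgmv_expand_shapes(
    inputs: Shape,
    lora_b_weights: Shape,
    output: Shape,
    lora_indices: Shape,
) -> Tuple[int, int, int]:
    """Validate the documented shapes and return (batch_size, rank, hidden_size)."""
    if len(lora_b_weights) == 4:
        if lora_b_weights[1] != 1:
            raise ValueError("4-D lora_b_weights must have shape (lora_num, 1, hidden_size, rank)")
        # Drop the singleton dimension: (lora_num, hidden_size, rank).
        lora_b_weights = (lora_b_weights[0],) + lora_b_weights[2:]
    if len(lora_b_weights) != 3:
        raise ValueError("lora_b_weights must be 3-D or 4-D")
    if len(inputs) != 2:
        raise ValueError("inputs must have shape (batch_size, rank)")

    batch_size, rank = inputs
    _, hidden_size, weight_rank = lora_b_weights
    if weight_rank != rank:
        raise ValueError("rank of inputs and lora_b_weights must match")
    if output != (batch_size, hidden_size):
        raise ValueError("output_tensor must have shape (batch_size, hidden_size)")
    if lora_indices != (batch_size,):
        raise ValueError("lora_indices_tensor must have shape (batch_size,)")
    return batch_size, rank, hidden_size


# Illustrative sizes: 4 requests, rank 16, hidden size 4096, 8 loaded adapters.
dims = check_bgmv_expand_shapes((4, 16), (8, 1, 4096, 16), (4, 4096), (4,))
print(dims)  # (4, 16, 4096)
```

Dtype and contiguity checks are left out here because they need real tensors rather than shape tuples.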

Usage Examples

# Called internally by LoRA kernel orchestrator (punica.py)
from lorax_server.utils.ops.bgmv_expand import bgmv_expand

# Apply LoRA B (expand) weights to low-rank activations
bgmv_expand(
    inputs=lora_a_output,           # (batch_size, rank)
    lora_b_weights=lora_b_weights,  # (num_loras, hidden_size, rank)
    output_tensor=output,           # (batch_size, hidden_size), modified in-place
    lora_indices_tensor=indices,    # (batch_size,), -1 for no LoRA
    add_inputs=True,
)
