Implementation:Predibase Lorax BGMV Shrink Kernel

From Leeroopedia
Revision as of 16:20, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/Predibase_Lorax_BGMV_Shrink_Kernel.md)


Knowledge Sources
Domains GPU_Kernels, LoRA
Last Updated 2026-02-08 00:00 GMT

Overview

Triton kernel implementing the Batched Gather Matrix-Vector multiply (BGMV) shrink operation for applying LoRA A (down-projection) weights during multi-tenant LoRA adapter inference.

Description

The BGMV shrink kernel computes a batched matrix-vector product between hidden-dimension activations and per-request LoRA A weight matrices, projecting from the hidden dimension down to the LoRA rank. For each request i in the batch, it computes output[i] = scaling * (lora_A[idx[i]] @ input[i]), where idx[i] is the request's LoRA adapter index, lora_A[idx[i]] has shape (rank, hidden_size), and input[i] has shape (hidden_size,). Requests with idx[i] == -1 are skipped.
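
The per-request formula can be written out as a slow, pure-Python reference. This is only a sketch of the semantics described above, not the Triton kernel: the function name `bgmv_shrink_reference` and the nested-list data layout are illustrative assumptions, and float32 accumulation is not modeled.

```python
from typing import List, Sequence

Matrix = List[List[float]]  # one adapter's LoRA A weights, shape (rank, hidden_size)


def bgmv_shrink_reference(
    inputs: Sequence[Sequence[float]],
    lora_a: Sequence[Matrix],
    indices: Sequence[int],
    output: List[List[float]],
    scaling: float = 1.0,
) -> None:
    """Write scaling * (lora_a[idx] @ inputs[i]) into output[i], in place."""
    for i, (x, idx) in enumerate(zip(inputs, indices)):
        if idx == -1:
            # No adapter for this request: leave the output row untouched.
            continue
        weights = lora_a[idx]  # gather the adapter for request i
        output[i] = [
            scaling * sum(w * v for w, v in zip(row, x)) for row in weights
        ]


inputs = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # (batch_size=2, hidden_size=3)
lora_a = [
    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],  # adapter 0, rank 2
    [[1.0, 1.0, 1.0], [0.0, 0.0, 2.0]],  # adapter 1, rank 2
]
output = [[0.0, 0.0], [0.0, 0.0]]
bgmv_shrink_reference(inputs, lora_a, [1, -1], output, scaling=0.5)
print(output)  # [[3.0, 3.0], [0.0, 0.0]]
```

Request 0 uses adapter 1, while request 1 carries index -1 and its output row keeps its initial zeros.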

This kernel is based on the Punica paper (Chen et al., 2023) on multi-tenant LoRA serving. It uses a GroupGEMV approach with a configurable SPLIT_K parameter that partitions the input (hidden) dimension K across thread blocks. When SPLIT_K > 1, each block accumulates its partial result into the output with tl.atomic_add, which avoids races between blocks writing the same output row. The kernel accumulates in float32 and applies the scaling factor before storing results. The output dimension N (the rank) is rounded up to the next power of 2 to obtain the block size (BLOCK_N).
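
The split-K idea can be simulated sequentially in plain Python. The sketch below is an illustration under stated assumptions, not the kernel itself: it assumes each of the SPLIT_K programs takes every SPLIT_K-th tile of width BLOCK_K along K, it replaces `tl.atomic_add` with an ordinary `+=`, and the helper names are hypothetical. It also shows why the output has to start at zero when partial sums are added into it.

```python
from typing import List, Sequence


def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n (for n >= 1), as used to size BLOCK_N."""
    return 1 << (n - 1).bit_length()


def split_k_gemv(
    x: Sequence[float],
    weights: Sequence[Sequence[float]],  # shape (N, K)
    scaling: float,
    block_k: int,
    split_k: int,
) -> List[float]:
    n, k = len(weights), len(x)
    out = [0.0] * n  # must start at zero: every split adds into it
    for pid_sk in range(split_k):  # one iteration per simulated program
        acc = [0.0] * n  # per-program accumulator
        # Assumed layout: program pid_sk handles every split_k-th tile of K.
        for start in range(pid_sk * block_k, k, block_k * split_k):
            for col in range(start, min(start + block_k, k)):
                for row in range(n):
                    acc[row] += weights[row][col] * x[col]
        for row in range(n):
            out[row] += scaling * acc[row]  # stands in for tl.atomic_add
    return out


x = [float(v) for v in range(1, 9)]  # K = 8
weights = [[1.0] * 8, [float(c % 2) for c in range(8)]]  # N = 2
single = split_k_gemv(x, weights, scaling=0.5, block_k=2, split_k=1)
split = split_k_gemv(x, weights, scaling=0.5, block_k=2, split_k=2)
print(single, split)  # [18.0, 10.0] [18.0, 10.0]
print(next_power_of_2(24))  # 32
```

Because scaling is linear, applying it to each partial sum before the add gives the same total as scaling the full dot product once.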

Usage

This kernel is invoked during LoRA adapter inference in the decode phase (single-token generation) when applying the LoRA A matrix (shrink/down-projection). It is the first step of the LoRA computation: projecting the model's hidden activations into the low-rank space. The result is then fed into the BGMV expand kernel for the LoRA B up-projection. The BGMV variant handles batched single-token requests.
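
For a single request, the shrink-then-expand sequence might look like the following pure-Python sketch. It assumes the standard LoRA formulation (delta = B (scaling * A x)); the helper names `matvec` and `lora_delta` are hypothetical, and how the expand kernel merges the result into the base layer's output is outside what this page describes.

```python
from typing import List, Sequence


def matvec(matrix: Sequence[Sequence[float]], vec: Sequence[float]) -> List[float]:
    """Plain matrix-vector product."""
    return [sum(m * v for m, v in zip(row, vec)) for row in matrix]


def lora_delta(
    x: Sequence[float],
    lora_a: Sequence[Sequence[float]],  # (rank, hidden_size)
    lora_b: Sequence[Sequence[float]],  # (out_size, rank)
    scaling: float,
) -> List[float]:
    # Step 1 (shrink): hidden_size -> rank, with scaling applied here.
    low_rank = [scaling * v for v in matvec(lora_a, x)]
    # Step 2 (expand): rank -> out_size.
    return matvec(lora_b, low_rank)


x = [1.0, 2.0, 3.0, 4.0]
lora_a = [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]]  # rank 2
lora_b = [[1.0, 1.0], [2.0, 0.0], [0.0, 3.0]]  # out_size 3
delta = lora_delta(x, lora_a, lora_b, scaling=2.0)
print(delta)  # [10.0, 4.0, 24.0]
```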

Code Reference

Source Location

  • Repository: Predibase_Lorax
  • File: server/lorax_server/utils/ops/bgmv_shrink.py
  • Lines: 1-149

Signature

@triton.jit
def _bgmv_shrink_kernel(
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    lora_indices,
    scaling,
    xm_stride,
    xk_stride,
    l0_stride,
    lora_k_stride,
    lora_n_stride,
    cm_stride,
    cn_stride,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
):

@torch.inference_mode()
def bgmv_shrink(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    scaling: float = 1.0,
    override_config: Optional[Dict[str, int]] = None,
):

Import

from lorax_server.utils.ops.bgmv_shrink import bgmv_shrink

I/O Contract

Inputs

| Name | Type | Required | Description |
| --- | --- | --- | --- |
| inputs | torch.Tensor | Yes | Input tensor of shape (batch_size, hidden_size), dtype float16/bfloat16. The hidden activations to be projected into low-rank space. Must match lora_a_weights dtype. |
| lora_a_weights | torch.Tensor | Yes | LoRA A weight matrices of shape (lora_num, rank, hidden_size) or (lora_num, 1, rank, hidden_size), dtype float16/bfloat16. The per-adapter down-projection weights. |
| output_tensor | torch.Tensor | Yes | Output tensor of shape (batch_size, rank), modified in-place. Must be contiguous. |
| lora_indices_tensor | torch.Tensor | Yes | Tensor of shape (batch_size,) mapping each request to its LoRA adapter index. A value of -1 means no LoRA is applied. |
| scaling | float | No | Scaling factor applied to the result (default 1.0). Typically the LoRA alpha/rank scaling. |
| override_config | Optional[Dict[str, int]] | No | Override for Triton grid configuration (BLOCK_K, SPLIT_K). Defaults to auto-selected config via get_lora_op_configs. |
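
The shape constraints listed above can be checked before a launch. The helper below is a hypothetical sketch that works on plain shape tuples (for example `tuple(tensor.shape)`). It is not part of lorax_server, and the real wrapper may enforce these constraints differently; dtype and contiguity checks are left out.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def check_bgmv_shrink_shapes(
    inputs: Shape, lora_a: Shape, output: Shape, indices: Shape
) -> Tuple[int, int, int]:
    """Validate shapes against the I/O contract; return (batch, rank, hidden)."""
    if len(lora_a) == 4:
        # (lora_num, 1, rank, hidden_size): drop the singleton dimension.
        if lora_a[1] != 1:
            raise ValueError("4-D lora_a_weights must have size 1 in dim 1")
        lora_a = (lora_a[0], lora_a[2], lora_a[3])
    if len(lora_a) != 3:
        raise ValueError("lora_a_weights must be 3-D or 4-D")
    _, rank, hidden = lora_a
    if len(inputs) != 2 or inputs[1] != hidden:
        raise ValueError("inputs must be (batch_size, hidden_size)")
    batch = inputs[0]
    if output != (batch, rank):
        raise ValueError("output_tensor must be (batch_size, rank)")
    if indices != (batch,):
        raise ValueError("lora_indices_tensor must be (batch_size,)")
    return batch, rank, hidden


dims = check_bgmv_shrink_shapes((4, 4096), (8, 1, 16, 4096), (4, 16), (4,))
print(dims)  # (4, 16, 4096)
```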

Outputs

| Name | Type | Description |
| --- | --- | --- |
| output_tensor | torch.Tensor | Modified in-place. Each row output_tensor[i] contains scaling * (lora_a_weights[lora_indices[i]] @ inputs[i]), a vector of length rank that is the low-rank projection of inputs[i]. |

Usage Examples

# Called internally by LoRA kernel orchestrator (punica.py)
from lorax_server.utils.ops.bgmv_shrink import bgmv_shrink

# Apply LoRA A (shrink) weights to project hidden activations to rank space
bgmv_shrink(
    inputs=hidden_states,           # (batch_size, hidden_size)
    lora_a_weights=lora_a_weights,  # (num_loras, rank, hidden_size)
    output_tensor=lora_a_output,    # (batch_size, rank), modified in-place
    lora_indices_tensor=indices,    # (batch_size,), -1 for no LoRA
    scaling=alpha / rank,
)
