Implementation:Predibase Lorax SGMV Shrink Kernel

From Leeroopedia


Knowledge Sources
Domains GPU_Kernels, LoRA
Last Updated 2026-02-08 00:00 GMT

Overview

Triton kernel implementing the Segmented Gather Matrix-Vector multiply (SGMV) shrink operation for applying LoRA A (down-projection) weights to variable-length sequences during multi-tenant LoRA adapter inference.

Description

The SGMV shrink kernel computes a segmented batched matrix multiplication between hidden-dimension activations and per-request LoRA A weight matrices, projecting each token of a variable-length sequence from the hidden dimension down to the LoRA rank. For each batch entry b and each token m in that entry's segment, it computes output[start_b + m] = scaling * input[start_b + m] @ lora_A[idx[b]]^T, where start_b is the segment's start offset in the packed token dimension, idx[b] is the LoRA adapter index, and lora_A[idx[b]] has shape (rank, hidden_size). Batch entries with idx[b] == -1 are skipped.
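
The per-segment formula above can be sketched as a dependency-free reference, written in plain Python so it runs without a GPU or PyTorch. The function name `sgmv_shrink_reference` and the list-of-lists layout are illustrative assumptions, not part of the Lorax source; the sketch only mirrors the semantics described here, not the kernel's tiling.

```python
def sgmv_shrink_reference(inputs, lora_a, b_seq_start_loc, seq_lens,
                          lora_indices, scaling):
    """Hypothetical reference for the SGMV shrink semantics.

    inputs: total_tokens rows, each of length hidden_size.
    lora_a: one entry per adapter, each holding `rank` rows of length hidden_size.
    Returns a (total_tokens, rank) list of lists.
    """
    rank = len(lora_a[0])
    # Assumption: the output starts at zero; rows of skipped entries keep that value.
    output = [[0.0] * rank for _ in inputs]
    for b, (start, length) in enumerate(zip(b_seq_start_loc, seq_lens)):
        adapter = lora_indices[b]
        if adapter == -1:
            continue  # no LoRA for this batch entry
        weights = lora_a[adapter]  # shape (rank, hidden_size)
        for m in range(length):
            x = inputs[start + m]
            # One output element per rank row: scaling * <x, A[r, :]>
            output[start + m] = [
                scaling * sum(xk * ak for xk, ak in zip(x, row))
                for row in weights
            ]
    return output


# Two sequences of lengths 2 and 1; hidden_size=3, rank=2, one adapter.
inputs = [[1.0, 0.0, 2.0], [0.0, 1.0, 0.0], [1.0, 1.0, 1.0]]
lora_a = [[[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]]]
out = sgmv_shrink_reference(inputs, lora_a, [0, 2], [2, 1], [0, -1], scaling=0.5)
```

In this toy case the first sequence uses adapter 0 and the second has index -1, so its single row is left at its initial value.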

The kernel uses GroupGEMM combined with SPLIT-K parallelism, as described in the source: "The GEMM of Multi-LoRA can be considered as GroupGEMM. Additionally, introducing SPLIT-K can improve performance." It launches on a 3D grid: axis 0 covers the combined M (sequence) and N (rank) tile indices, axis 1 covers the SPLIT_K partitions, and axis 2 iterates over batch entries. The block sizes are fixed (BLOCK_M=32, BLOCK_N=16, BLOCK_K=32, SPLIT_K=8). When SPLIT_K > 1, the partial results from each K-split are combined with tl.atomic_add. The scaling factor is applied to each result before it is written out. The @libentry() decorator reduces Triton launch overhead. The design is based on the Punica paper (Chen et al., 2023).
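
The grid layout and the split-K accumulation described above can be illustrated with a small host-side sketch. It uses the block sizes quoted in this section; the helper names (`cdiv`, `launch_grid`, `decode_tile`, `split_k_dot`) are hypothetical, and both the row-major packing of M and N tiles on axis 0 and the interleaved assignment of K chunks to splits are assumptions made for illustration rather than confirmed details of the kernel.

```python
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K = 32, 16, 32, 8  # values quoted in the text


def cdiv(a, b):
    """Ceiling division for non-negative integers."""
    return -(-a // b)


def launch_grid(max_seq_length, rank, batches):
    """3D grid: (M tiles * N tiles, K splits, batch entries)."""
    return (cdiv(max_seq_length, BLOCK_M) * cdiv(rank, BLOCK_N), SPLIT_K, batches)


def decode_tile(pid, rank):
    """Assumed unpacking of axis 0 into (pid_m, pid_n), N tiles varying fastest."""
    n_tiles = cdiv(rank, BLOCK_N)
    return pid // n_tiles, pid % n_tiles


def split_k_dot(x, w, split_k=SPLIT_K, block_k=BLOCK_K):
    """Dot product accumulated split-K style.

    Each split sums its own K chunks into a partial result; the partials are
    then added into one total, which is the role tl.atomic_add plays on the GPU.
    """
    total = 0.0
    for pid_sk in range(split_k):
        partial = 0.0
        # Assumed interleaving: split pid_sk takes chunks pid_sk, pid_sk + split_k, ...
        for k0 in range(pid_sk * block_k, len(x), block_k * split_k):
            partial += sum(a * b for a, b in zip(x[k0:k0 + block_k], w[k0:k0 + block_k]))
        total += partial
    return total


grid = launch_grid(max_seq_length=100, rank=64, batches=3)
```

Because atomic accumulation adds into whatever the output buffer already holds, a caller would presumably need to zero that buffer before launch; this is an inference from the use of tl.atomic_add, not something stated in the source.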

Usage

This kernel is invoked during LoRA adapter inference in the prefill phase when applying the LoRA A matrix (shrink/down-projection) to variable-length input sequences. It is the first step of the LoRA computation during prefill: projecting the model's hidden activations into the low-rank space. The result is then fed into the SGMV expand kernel for the LoRA B up-projection.
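
As a worked equation, the two-step prefill computation described above can be written as follows. The symbols are introduced here for illustration: $x_t$ is the hidden activation of token $t$, $A_i$ is the LoRA A matrix of adapter $i$ with shape (rank, hidden_size), and $s$ is the scaling factor. The shape of $B_i$ and the assumption that the expand step applies no further scaling are not stated on this page.

```latex
% Shrink (this kernel): project token t of a request using adapter i into rank space
v_t = s \, x_t A_i^{\top}, \qquad x_t \in \mathbb{R}^{1 \times d},\; A_i \in \mathbb{R}^{r \times d},\; v_t \in \mathbb{R}^{1 \times r}

% Expand (SGMV expand kernel, assumed shape B_i \in \mathbb{R}^{d_{out} \times r})
\Delta h_t = v_t B_i^{\top} = s \, x_t A_i^{\top} B_i^{\top}
```

The intermediate $v_t$ has only $r$ entries per token, which is why the shrink output tensor has shape (total_tokens, rank).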

Code Reference

Source Location

  • Repository: Predibase_Lorax
  • File: server/lorax_server/utils/ops/sgmv_shrink.py
  • Lines: 1-190

Signature

@libentry()
@triton.jit
def _sgmv_shrink_kernel(
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    b_seq_start_loc,
    seq_lens,
    lora_indices,
    scaling,
    xm_stride,
    xk_stride,
    l0_stride,
    lora_k_stride,
    lora_n_stride,
    cm_stride,
    cn_stride,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
):

@torch.inference_mode()
def sgmv_shrink(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    scaling: float,
):

Import

from lorax_server.utils.ops.sgmv_shrink import sgmv_shrink

I/O Contract

Inputs

Name | Type | Required | Description
inputs | torch.Tensor | Yes | Input tensor of shape (total_tokens, hidden_size), dtype float16 or bfloat16. Holds the hidden activations of all sequences packed along the token dimension. Its dtype must match that of lora_a_weights.
lora_a_weights | torch.Tensor | Yes | LoRA A weight matrices of shape (lora_num, rank, hidden_size) or (lora_num, 1, rank, hidden_size), dtype float16 or bfloat16. The per-adapter down-projection weights.
output_tensor | torch.Tensor | Yes | Output tensor of shape (total_tokens, rank), modified in-place. Must be contiguous.
b_seq_start_loc | torch.Tensor | Yes | Tensor of shape (batch_size,) containing the start offset of each sequence in the packed token dimension, i.e. the exclusive cumulative sum of the sequence lengths. For sequence lengths [4, 6], this is [0, 4].
seq_len_tensor | torch.Tensor | Yes | Tensor of shape (batch_size,) containing the sequence length of each batch entry.
lora_indices_tensor | torch.Tensor | Yes | Tensor of shape (batch_size,) mapping each batch entry to its LoRA adapter index. A value of -1 means no LoRA is applied.
batches | int | Yes | Number of batch entries (sequences).
max_seq_length | int | Yes | Maximum sequence length in the batch, used to size the Triton grid.
scaling | float | Yes | Scaling factor applied to the result, typically the LoRA alpha / rank.
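
The segment arguments in the table above are all derivable from the per-request sequence lengths and adapter indices. The sketch below shows one way to do so with the standard library only; `build_segment_metadata` is a hypothetical helper, not a Lorax function, and it returns plain Python lists that a caller would still need to convert to tensors on the right device.

```python
from itertools import accumulate


def build_segment_metadata(seq_lens, adapter_ids):
    """Derive the segment arguments from per-request lengths and adapter ids.

    Hypothetical helper for illustration: the keys mirror the parameter names
    in the I/O contract, but the values are plain Python lists and ints.
    """
    if len(seq_lens) != len(adapter_ids):
        raise ValueError("seq_lens and adapter_ids must have one entry per request")
    # Exclusive cumulative sum: [4, 6] -> [0, 4]
    start_locs = [0, *accumulate(seq_lens)][:-1]
    return {
        "b_seq_start_loc": start_locs,
        "seq_len_tensor": list(seq_lens),
        "lora_indices_tensor": list(adapter_ids),  # -1 marks "no LoRA"
        "batches": len(seq_lens),
        "max_seq_length": max(seq_lens, default=0),
        "total_tokens": sum(seq_lens),  # first dimension of inputs and output_tensor
    }


meta = build_segment_metadata(seq_lens=[4, 6], adapter_ids=[0, -1])
```

Here `total_tokens` is not an argument of sgmv_shrink itself; it is included because it fixes the first dimension of both inputs and output_tensor.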

Outputs

Name | Type | Description
output_tensor | torch.Tensor | Modified in-place. For each batch entry b and each token m within that entry's segment, the corresponding output row holds the low-rank projection scaling * (input row @ lora_a_weights[lora_indices[b]]^T), a vector of length rank.

Usage Examples

# Called internally by LoRA kernel orchestrator (punica.py)
from lorax_server.utils.ops.sgmv_shrink import sgmv_shrink

# Apply LoRA A (shrink) weights during prefill with variable-length sequences
sgmv_shrink(
    inputs=hidden_states,           # (total_tokens, hidden_size)
    lora_a_weights=lora_a_weights,  # (num_loras, rank, hidden_size)
    output_tensor=lora_a_output,    # (total_tokens, rank), modified in-place
    b_seq_start_loc=seq_starts,     # (batch_size,)
    seq_len_tensor=seq_lens,        # (batch_size,)
    lora_indices_tensor=indices,    # (batch_size,)
    batches=batch_size,
    max_seq_length=max_seq_len,
    scaling=alpha / rank,
)
