Implementation:Predibase Lorax SGMV Shrink Kernel
| Knowledge Sources | |
|---|---|
| Domains | GPU_Kernels, LoRA |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Triton kernel implementing the Segmented Gather Matrix-Vector multiply (SGMV) shrink operation for applying LoRA A (down-projection) weights to variable-length sequences during multi-tenant LoRA adapter inference.
Description
The SGMV shrink kernel computes a segmented, batched matrix multiplication between the packed hidden-dimension activations and per-request LoRA A weight matrices. It projects each token from hidden-dimension space down to rank space. For each batch entry b with segment start start_b, and each token m in that segment, it computes output[start_b + m] = scaling * input[start_b + m] @ lora_A[idx[b]]^T. Here idx[b] is the LoRA adapter index for batch entry b, and lora_A[idx[b]] has shape (rank, hidden_size). Segments with idx[b] == -1 are skipped, and their output rows are left unchanged.
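The sketch below is a minimal pure-Python reference for these semantics. It is illustrative only and is not LoRAX code. `sgmv_shrink_reference` is a hypothetical name, and matrices are nested lists rather than GPU tensors. The output buffer is assumed to start zeroed, so plain assignment gives the same result as the kernel's accumulation.

```python
# Hypothetical pure-Python reference for the SGMV shrink math (not the LoRAX API).
from typing import List

Matrix = List[List[float]]


def sgmv_shrink_reference(
    inputs: Matrix,           # (total_tokens, hidden_size), packed sequences
    lora_a: List[Matrix],     # (num_loras, rank, hidden_size)
    output: Matrix,           # (total_tokens, rank), assumed zeroed, updated in place
    seq_starts: List[int],    # start row of each segment (b_seq_start_loc)
    seq_lens: List[int],      # number of tokens in each segment
    lora_indices: List[int],  # adapter index per segment, -1 means "no LoRA"
    scaling: float,
) -> None:
    for start, length, idx in zip(seq_starts, seq_lens, lora_indices):
        if idx == -1:
            continue  # skipped segment: its output rows are left untouched
        a = lora_a[idx]  # (rank, hidden_size) matrix for this segment's adapter
        for row in range(start, start + length):
            x = inputs[row]
            # output[row][r] = scaling * sum_k x[k] * A[r][k], i.e. x @ A^T
            output[row] = [scaling * sum(xk * ak for xk, ak in zip(x, a_r)) for a_r in a]
```

In the example below, two segments share one packed buffer. Rows 0 to 1 use adapter 1, and row 2 has index -1, so it keeps its zeros.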
The kernel treats the multi-LoRA GEMM as a GroupGEMM and adds SPLIT-K parallelism. As the source comment puts it: "The GEMM of Multi-LoRA can be considered as GroupGEMM. Additionally, introducing SPLIT-K can improve performance." It launches a 3D grid with three axes:
- Axis 0 enumerates the combined M (sequence) and N (rank) tiles.
- Axis 1 enumerates the SPLIT_K partitions of the K (hidden) dimension.
- Axis 2 enumerates batch entries.

Block sizes are fixed at BLOCK_M=32, BLOCK_N=16, BLOCK_K=32, and SPLIT_K=8. Each program computes a partial product over its share of K and applies the scaling factor before writing. When SPLIT_K > 1, as with the fixed value of 8, the partial results from the K-splits are combined with tl.atomic_add. The @libentry() decorator reduces Triton launch overhead. The design is based on the Punica paper (Chen et al., 2023).
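The sketch below models the launch geometry and the SPLIT-K reduction in plain Python. The block sizes and grid axes come from the text above. Two details are assumptions, modeled on common Triton split-K kernels and not read from the LoRAX source:
- Axis 0 is decomposed as (pid // n_tiles, pid % n_tiles).
- Each split visits K chunks in interleaved order.

`launch_grid`, `decompose_axis0`, and `split_k_dot` are hypothetical helpers.

```python
# Hypothetical model of the kernel's grid and SPLIT-K accumulation (not LoRAX code).
from typing import List, Tuple

BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K = 32, 16, 32, 8  # fixed sizes from the text


def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b  # ceiling division, like triton.cdiv


def launch_grid(max_seq_length: int, rank: int, batches: int) -> Tuple[int, int, int]:
    # axis 0: M tiles x N tiles, axis 1: K splits, axis 2: batch entries
    return (cdiv(max_seq_length, BLOCK_M) * cdiv(rank, BLOCK_N), SPLIT_K, batches)


def decompose_axis0(pid: int, rank: int) -> Tuple[int, int]:
    # Assumed mapping of a flat axis-0 program id to (M tile, N tile).
    n_tiles = cdiv(rank, BLOCK_N)
    return pid // n_tiles, pid % n_tiles


def split_k_dot(x: List[float], w: List[float]) -> float:
    """Emulate SPLIT_K programs that each reduce part of K, then accumulate."""
    k = len(x)
    acc = 0.0  # stands in for the zero-initialized output element
    for pid_sk in range(SPLIT_K):
        partial = 0.0
        # Assumed order: split pid_sk handles chunks pid_sk, pid_sk + SPLIT_K, ...
        for chunk in range(pid_sk, cdiv(k, BLOCK_K), SPLIT_K):
            lo, hi = chunk * BLOCK_K, min((chunk + 1) * BLOCK_K, k)
            partial += sum(x[i] * w[i] for i in range(lo, hi))
        acc += partial  # stands in for tl.atomic_add
    return acc
```

Every K chunk is visited by exactly one split, so the accumulated result equals the full dot product. The order of the additions is the only thing that changes, which is why floating-point results on the GPU can differ slightly between runs.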
Usage
This kernel is invoked during LoRA adapter inference in the prefill phase when applying the LoRA A matrix (shrink/down-projection) to variable-length input sequences. It is the first step of the LoRA computation during prefill: projecting the model's hidden activations into the low-rank space. The result is then fed into the SGMV expand kernel for the LoRA B up-projection.
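The sketch below shows the shrink-then-expand data flow for a single token. It checks that the result matches applying the merged update scaling * B @ A directly, while the intermediate vector has only rank entries. `shrink` mirrors the math described for this kernel. `expand` is a hypothetical stand-in for the SGMV expand kernel, assumed to compute y += t @ B^T with B of shape (out_features, rank). Neither function is part of the LoRAX API.

```python
# Hypothetical single-token model of the LoRA prefill path (not the LoRAX API).
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec_t(x: Vector, m: Matrix) -> Vector:
    # Returns x @ m^T, where m has shape (rows, len(x)).
    return [sum(xi * mi for xi, mi in zip(x, row)) for row in m]


def shrink(x: Vector, lora_a: Matrix, scaling: float) -> Vector:
    # (hidden,) -> (rank,): scaling * x @ A^T, as described for SGMV shrink
    return [scaling * v for v in matvec_t(x, lora_a)]


def expand(t: Vector, lora_b: Matrix, y: Vector) -> Vector:
    # (rank,) -> (out,): assumed expand semantics, y + t @ B^T
    return [yi + di for yi, di in zip(y, matvec_t(t, lora_b))]


x = [1.0, 2.0, 3.0]        # hidden_size = 3
A = [[1.0, 0.0, 1.0]]      # rank = 1
B = [[1.0], [3.0]]         # out_features = 2
t = shrink(x, A, 0.5)      # low-rank intermediate: [2.0]
y = expand(t, B, [10.0, 20.0])
```

Shrinking first keeps the intermediate at rank width, which is much smaller than hidden_size for typical adapters. This is why the LoRA path is split into two kernels rather than materializing B @ A.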
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/utils/ops/sgmv_shrink.py
- Lines: 1-190
Signature
```python
@libentry()
@triton.jit
def _sgmv_shrink_kernel(
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    b_seq_start_loc,
    seq_lens,
    lora_indices,
    scaling,
    xm_stride,
    xk_stride,
    l0_stride,
    lora_k_stride,
    lora_n_stride,
    cm_stride,
    cn_stride,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
):
    ...  # kernel body omitted


@torch.inference_mode()
def sgmv_shrink(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    scaling: float,
):
    ...  # wrapper body omitted
```
Import
```python
from lorax_server.utils.ops.sgmv_shrink import sgmv_shrink
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| inputs | torch.Tensor | Yes | Input tensor of shape (total_tokens, hidden_size), dtype float16/bfloat16. The hidden activations packed across all sequences. Must match lora_a_weights dtype. |
| lora_a_weights | torch.Tensor | Yes | LoRA A weight matrices of shape (lora_num, rank, hidden_size) or (lora_num, 1, rank, hidden_size), dtype float16/bfloat16. The per-adapter down-projection weights. |
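| output_tensor | torch.Tensor | Yes | Output tensor of shape (total_tokens, rank), modified in-place. Must be contiguous. Because the kernel accumulates partial results with tl.atomic_add, it should be zero-initialized before the call. |
| b_seq_start_loc | torch.Tensor | Yes | Tensor of shape (batch_size,) containing the start offset of each sequence in the packed token dimension, i.e. the exclusive prefix sum of the sequence lengths. E.g., for sequence lengths [4, 6], this would be [0, 4]. |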
| seq_len_tensor | torch.Tensor | Yes | Tensor of shape (batch_size,) containing the sequence length for each batch entry. |
| lora_indices_tensor | torch.Tensor | Yes | Tensor of shape (batch_size,) mapping each batch entry to its LoRA adapter index. A value of -1 means no LoRA is applied. |
| batches | int | Yes | Number of batch entries (sequences). |
| max_seq_length | int | Yes | Maximum sequence length in the batch, used to size the Triton grid. |
| scaling | float | Yes | Scaling factor applied to the result. Typically the LoRA alpha/rank scaling. |
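The sketch below shows how the metadata and shape constraints in this table fit together. `build_sgmv_metadata`, `normalize_lora_a_shape`, and `check_shapes` are hypothetical helpers written for illustration. They use plain Python lists and shape tuples instead of tensors, and they are not functions exported by LoRAX.

```python
# Hypothetical helpers mirroring the input contract above (not the LoRAX API).
from itertools import accumulate
from typing import List, Tuple


def build_sgmv_metadata(seq_lens: List[int]) -> Tuple[List[int], int, int]:
    """Return (b_seq_start_loc, max_seq_length, batches) for packed sequences."""
    if not seq_lens:
        return [], 0, 0
    # Exclusive prefix sum: each sequence starts where the previous one ends.
    starts = [0] + list(accumulate(seq_lens))[:-1]
    return starts, max(seq_lens), len(seq_lens)


def normalize_lora_a_shape(shape: Tuple[int, ...]) -> Tuple[int, int, int]:
    # Accept (lora_num, rank, hidden) or (lora_num, 1, rank, hidden).
    if len(shape) == 4:
        if shape[1] != 1:
            raise ValueError("4-D LoRA A weights need a singleton second dim")
        shape = (shape[0], shape[2], shape[3])
    if len(shape) != 3:
        raise ValueError(f"unexpected LoRA A shape {shape}")
    return shape


def check_shapes(inputs_shape, lora_shape, output_shape) -> int:
    """Validate the shapes listed in the table and return the LoRA rank."""
    _, rank, hidden = normalize_lora_a_shape(tuple(lora_shape))
    total_tokens, in_hidden = inputs_shape
    if in_hidden != hidden:
        raise ValueError("inputs hidden_size must match lora_a_weights")
    if tuple(output_shape) != (total_tokens, rank):
        raise ValueError("output_tensor must have shape (total_tokens, rank)")
    return rank
```

For sequence lengths [4, 6], `build_sgmv_metadata` yields start offsets [0, 4], max_seq_length 6, and batches 2, which matches the example in the table.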
Outputs
| Name | Type | Description |
|---|---|---|
| output_tensor | torch.Tensor | Modified in-place. For each batch entry b and each token m in its segment, row b_seq_start_loc[b] + m holds scaling * inputs[row] @ lora_a_weights[lora_indices[b]]^T, a length-rank low-rank projection. Rows of segments with lora_indices[b] == -1 are left unchanged. |
Usage Examples
```python
# Called internally by LoRA kernel orchestrator (punica.py)
from lorax_server.utils.ops.sgmv_shrink import sgmv_shrink

# Apply LoRA A (shrink) weights during prefill with variable-length sequences.
# lora_a_output should be zero-initialized: the kernel accumulates with tl.atomic_add.
sgmv_shrink(
    inputs=hidden_states,            # (total_tokens, hidden_size)
    lora_a_weights=lora_a_weights,   # (num_loras, rank, hidden_size)
    output_tensor=lora_a_output,     # (total_tokens, rank), modified in-place
    b_seq_start_loc=seq_starts,      # (batch_size,)
    seq_len_tensor=seq_lens,         # (batch_size,)
    lora_indices_tensor=indices,     # (batch_size,)
    batches=batch_size,
    max_seq_length=max_seq_len,
    scaling=alpha / rank,
)
```