Implementation:Predibase Lorax SGMV Expand Kernel
| Knowledge Sources | |
|---|---|
| Domains | GPU_Kernels, LoRA |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Triton kernel implementing the Segmented Gather Matrix-Vector multiply (SGMV) expand operation for applying LoRA B (up-projection) weights to variable-length sequences during multi-tenant LoRA adapter inference.
Description
The SGMV expand kernel computes a segmented batched matrix multiplication between low-rank intermediate activations and per-request LoRA B weight matrices, projecting from rank space back to hidden dimension space. Unlike the BGMV variant, which handles one token per request, SGMV handles variable-length sequences where each batch entry may contain multiple tokens (as in the prefill phase). Mathematically, for each batch b and each token m in that batch's segment, it computes output[start_b + m] = input[start_b + m] @ lora_B[idx[b]]^T, where lora_B[idx[b]] has shape (hidden_size, rank). When add_inputs is set, the result is accumulated (+=) into the existing output row instead of overwriting it.
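The per-segment computation described above can be written as a plain loop. The following is a minimal NumPy sketch of the semantics only, not the Triton kernel. The function name `sgmv_expand_reference` is hypothetical, and it assumes that rows belonging to a request with adapter index -1 are left untouched.

```python
import numpy as np

def sgmv_expand_reference(inputs, lora_b, output, seq_starts, seq_lens,
                          lora_indices, add_inputs=False):
    """Loop-based reference for the segmented expand (illustrative only)."""
    for b in range(len(seq_lens)):
        idx = int(lora_indices[b])
        if idx == -1:
            continue  # assumption: no adapter, so this segment is skipped
        start = int(seq_starts[b])
        end = start + int(seq_lens[b])
        # (seq_len, rank) @ (rank, hidden_size) -> (seq_len, hidden_size)
        update = inputs[start:end] @ lora_b[idx].T
        if add_inputs:
            output[start:end] += update
        else:
            output[start:end] = update
    return output

# Two requests of lengths 4 and 6; only the first one uses an adapter.
rng = np.random.default_rng(0)
rank, hidden = 4, 8
seq_lens = np.array([4, 6])
seq_starts = np.array([0, 4])
lora_indices = np.array([1, -1])
inputs = rng.standard_normal((10, rank)).astype(np.float32)
lora_b = rng.standard_normal((2, hidden, rank)).astype(np.float32)
output = np.zeros((10, hidden), dtype=np.float32)

sgmv_expand_reference(inputs, lora_b, output, seq_starts, seq_lens, lora_indices)
```

A reference like this is mainly useful for checking kernel output on small inputs; it makes no attempt at tiling or mixed-precision handling.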
The kernel is based on a GroupGEMM (Group General Matrix Multiply) approach from the Punica paper (Chen et al., 2023). It uses 2D tiling with fixed block sizes (BLOCK_M=32, BLOCK_N=32, BLOCK_K=16) and a 2D grid: axis 0 covers the M (sequence) and N (hidden) tile combinations, while axis 1 iterates over batches. The kernel uses tl.dot for tile-level matrix multiplication instead of the element-wise multiply-and-reduce used in BGMV. It supports mixed-precision casting and optional additive accumulation into the output. The @libentry() decorator is applied to reduce Triton kernel launch overhead.
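The 2D launch grid described above can be illustrated in plain Python. This is a sketch under stated assumptions: the helper names `launch_grid` and `tile_coords` are hypothetical, and the row-major split of the axis-0 program id into an M tile and an N tile is an assumed convention rather than something confirmed by this page.

```python
BLOCK_M, BLOCK_N = 32, 32  # fixed tile sizes quoted in the description

def cdiv(a, b):
    """Ceiling division, the same arithmetic as triton.cdiv."""
    return (a + b - 1) // b

def launch_grid(max_seq_length, hidden_size, batches):
    """Axis 0 enumerates (M tile, N tile) pairs; axis 1 enumerates batches."""
    tiles_m = cdiv(max_seq_length, BLOCK_M)
    tiles_n = cdiv(hidden_size, BLOCK_N)
    return (tiles_m * tiles_n, batches)

def tile_coords(pid, hidden_size):
    """Assumed row-major decomposition of an axis-0 program id."""
    tiles_n = cdiv(hidden_size, BLOCK_N)
    return pid // tiles_n, pid % tiles_n

grid = launch_grid(max_seq_length=100, hidden_size=4096, batches=3)
```

Because axis 0 is sized from max_seq_length, programs whose M tile lies beyond a shorter sequence's length have no rows to process for that batch entry.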
Usage
This kernel is invoked during LoRA adapter inference in the prefill phase when applying the LoRA B matrix (expand/up-projection) to variable-length input sequences. The SGMV variant handles multiple tokens per request, making it suitable for the initial prompt processing stage. It is called from the Punica kernel orchestrator after SGMV shrink has projected activations into rank space.
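The shrink-then-expand ordering can be checked numerically for a single sequence. The sketch below uses NumPy and assumes a LoRA A layout of (rank, hidden_size), which is not stated on this page; it also omits any LoRA scaling factor. It shows that the two-step low-rank path matches applying the dense product of B and A directly.

```python
import numpy as np

rng = np.random.default_rng(0)
tokens, hidden, rank = 5, 16, 4

x = rng.standard_normal((tokens, hidden))
lora_a = rng.standard_normal((rank, hidden))   # assumed layout for LoRA A
lora_b = rng.standard_normal((hidden, rank))   # (hidden_size, rank), as in the I/O contract
base_out = np.zeros((tokens, hidden))          # stands in for the base layer output

low_rank = x @ lora_a.T                 # shrink: hidden space -> rank space
out = base_out + low_rank @ lora_b.T    # expand: rank space -> hidden space, accumulated

dense_delta = x @ (lora_b @ lora_a).T   # same update via the dense (hidden, hidden) matrix
```

The low-rank path only ever materializes a (tokens, rank) intermediate, which is why the shrink and expand steps are kept as separate kernels.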
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/utils/ops/sgmv_expand.py
- Lines: 1-192
Signature
@libentry()
@triton.jit
def _sgmv_expand_kernel(
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    b_seq_start_loc,
    seq_lens,
    lora_indices,
    xm_stride,
    xk_stride,
    l0_stride,
    lora_k_stride,
    lora_n_stride,
    cm_stride,
    cn_stride,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    ADD_INPUTS: tl.constexpr,
    CAST_TYPE: tl.constexpr,
):

@torch.inference_mode()
def sgmv_expand(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    add_inputs: bool = False,
):
Import
from lorax_server.utils.ops.sgmv_expand import sgmv_expand
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| inputs | torch.Tensor | Yes | Input tensor of shape (total_tokens, rank), dtype float16/bfloat16/float32. The low-rank intermediate activations from the LoRA A projection, packed across all sequences. |
| lora_b_weights | torch.Tensor | Yes | LoRA B weight matrices of shape (lora_num, hidden_size, rank) or (lora_num, 1, hidden_size, rank), dtype float16/bfloat16. The per-adapter up-projection weights. |
| output_tensor | torch.Tensor | Yes | Output tensor of shape (total_tokens, hidden_size), modified in-place. Must be contiguous. |
| b_seq_start_loc | torch.Tensor | Yes | Tensor of shape (batch_size,) containing cumulative sequence start positions. E.g., for sequence lengths [4, 6], this would be [0, 4]. |
| seq_len_tensor | torch.Tensor | Yes | Tensor of shape (batch_size,) containing the sequence length for each batch entry. |
| lora_indices_tensor | torch.Tensor | Yes | Tensor of shape (batch_size,) mapping each batch entry to its LoRA adapter index. A value of -1 means no LoRA is applied. |
| batches | int | Yes | Number of batch entries (sequences). |
| max_seq_length | int | Yes | Maximum sequence length in the batch, used to size the Triton grid. |
| add_inputs | bool | No | Whether to add results to existing output values (default False) or overwrite them. |
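The segment metadata arguments (b_seq_start_loc, batches, max_seq_length) are all derivable from the per-request sequence lengths. The helper below is a hypothetical plain-Python sketch of that derivation; in practice the lists would be converted to tensors on the same device as the inputs before calling sgmv_expand.

```python
def build_segment_metadata(seq_lens):
    """Return (start offsets, batch count, max sequence length) for packed sequences."""
    starts = []
    total = 0
    for n in seq_lens:
        starts.append(total)  # exclusive prefix sum: offset of this segment's first row
        total += n
    max_len = max(seq_lens) if seq_lens else 0
    return starts, len(seq_lens), max_len

starts, batches, max_seq_length = build_segment_metadata([4, 6])
```

For sequence lengths [4, 6] this yields start offsets [0, 4], matching the example in the table, with 2 batch entries and a maximum sequence length of 6.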
Outputs
| Name | Type | Description |
|---|---|---|
| output_tensor | torch.Tensor | Modified in-place. For each batch b and token m within that batch's segment, the corresponding output row receives the product of the input row and the transpose of lora_b_weights[lora_indices[b]]. The result is added to the existing row when add_inputs is True and overwrites it otherwise. |
Usage Examples
# Called internally by the LoRA kernel orchestrator (punica.py)
from lorax_server.utils.ops.sgmv_expand import sgmv_expand

# Apply LoRA B (expand) weights during prefill with variable-length sequences
sgmv_expand(
    inputs=lora_a_output,            # (total_tokens, rank)
    lora_b_weights=lora_b_weights,   # (num_loras, hidden_size, rank)
    output_tensor=output,            # (total_tokens, hidden_size), modified in-place
    b_seq_start_loc=seq_starts,      # (batch_size,)
    seq_len_tensor=seq_lens,         # (batch_size,)
    lora_indices_tensor=indices,     # (batch_size,)
    batches=batch_size,
    max_seq_length=max_seq_len,
    add_inputs=True,
)