Implementation:Predibase Lorax SGMV Expand Kernel

From Leeroopedia
Revision as of 16:21, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/Predibase_Lorax_SGMV_Expand_Kernel.md)


Knowledge Sources
Domains GPU_Kernels, LoRA
Last Updated 2026-02-08 00:00 GMT

Overview

Triton kernel implementing the Segmented Gather Matrix-Vector multiply (SGMV) expand operation for applying LoRA B (up-projection) weights to variable-length sequences during multi-tenant LoRA adapter inference.

Description

The SGMV expand kernel computes a segmented batched matrix multiplication between low-rank intermediate activations and per-request LoRA B weight matrices, projecting from rank space back to hidden-dimension space. Unlike the BGMV variant, which handles single-token requests, SGMV handles variable-length sequences where each batch entry may contain multiple tokens (as in the prefill phase). Mathematically, for each batch entry b and each token m in that entry's segment, it computes output[start_b + m] = input[start_b + m] @ lora_B[idx[b]]^T, where lora_B[idx[b]] has shape (hidden_size, rank). When add_inputs is enabled, the result is accumulated into the existing output row (+=) instead of overwriting it.
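The segment semantics described above can be sketched as a pure-Python reference. This is an illustration only, not the Triton kernel: the function name `sgmv_expand_reference` is hypothetical, the transposed product follows from the documented weight shape (hidden_size, rank), and leaving rows untouched for a LoRA index of -1 is an assumption based on the I/O contract's statement that -1 means no LoRA is applied.

```python
def sgmv_expand_reference(inputs, lora_b, output, seq_starts, seq_lens,
                          lora_indices, add_inputs=False):
    """Pure-Python reference for the segmented expand (hypothetical helper).

    inputs:  list of total_tokens rows, each of length rank
    lora_b:  list of adapters, each hidden_size rows of length rank
    output:  list of total_tokens rows, each of length hidden_size (mutated)
    """
    for b, lora_idx in enumerate(lora_indices):
        if lora_idx == -1:
            # Assumption: sequences without an adapter are skipped entirely.
            continue
        weight = lora_b[lora_idx]          # (hidden_size, rank)
        start = seq_starts[b]
        for m in range(seq_lens[b]):
            x = inputs[start + m]          # (rank,)
            for n, w_row in enumerate(weight):
                # Dot product over the rank dimension: x @ weight[n]
                acc = sum(x_k * w_k for x_k, w_k in zip(x, w_row))
                if add_inputs:
                    output[start + m][n] += acc
                else:
                    output[start + m][n] = acc
    return output


# Two sequences (lengths 2 and 1); the second has no adapter.
x = [[1, 2], [3, 4], [5, 6]]                 # (total_tokens=3, rank=2)
b_weights = [[[1, 0], [0, 1], [1, 1]]]       # (num_loras=1, hidden=3, rank=2)
out = [[0, 0, 0] for _ in range(3)]
sgmv_expand_reference(x, b_weights, out, [0, 2], [2, 1], [0, -1])
```

A reference like this is mainly useful for checking a fast kernel's output on small inputs; it makes no attempt to model tiling or mixed precision.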

The kernel is based on the GroupGEMM (Group General Matrix Multiply) approach from the Punica paper (Chen et al., 2023). It uses 2D tiling with fixed block sizes (BLOCK_M=32, BLOCK_N=32, BLOCK_K=16) and a 2D launch grid: axis 0 enumerates every combination of M (sequence) tile and N (hidden) tile, while axis 1 indexes the batch entries. Each program multiplies its tiles with tl.dot rather than the element-wise multiply-and-reduce used in BGMV. Mixed-precision casting is controlled by the CAST_TYPE flag, and optional additive accumulation into the output by the ADD_INPUTS flag. The @libentry() decorator is applied to reduce Triton kernel launch overhead.
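To make the grid layout concrete, the following sketch computes the launch grid from the block sizes stated above and decodes an axis-0 program id into a tile pair. The helper names (`cdiv`, `launch_grid`, `decode_tile`) are hypothetical, and the row-major ordering of tiles along axis 0 is an assumption; the actual kernel may order them differently.

```python
BLOCK_M, BLOCK_N = 32, 32  # block sizes stated in the description


def cdiv(a, b):
    """Ceiling division, as used when counting tiles."""
    return (a + b - 1) // b


def launch_grid(max_seq_length, hidden_size, batches):
    """Axis 0: all (M tile, N tile) combinations; axis 1: batch entries."""
    tiles_m = cdiv(max_seq_length, BLOCK_M)
    tiles_n = cdiv(hidden_size, BLOCK_N)
    return (tiles_m * tiles_n, batches)


def decode_tile(pid, hidden_size):
    """Map an axis-0 program id to (tile_m, tile_n).

    Assumption: tiles are laid out row-major, N tiles varying fastest.
    """
    tiles_n = cdiv(hidden_size, BLOCK_N)
    return pid // tiles_n, pid % tiles_n


grid = launch_grid(max_seq_length=100, hidden_size=4096, batches=3)
tile = decode_tile(130, hidden_size=4096)
```

Because axis 0 is sized from the maximum sequence length, programs assigned to M tiles beyond a shorter sequence's length have no work to do for that batch entry.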

Usage

This kernel is invoked during LoRA adapter inference in the prefill phase when applying the LoRA B matrix (expand/up-projection) to variable-length input sequences. The SGMV variant handles multiple tokens per request, making it suitable for the initial prompt processing stage. It is called from the Punica kernel orchestrator after SGMV shrink has projected activations into rank space.

Code Reference

Source Location

  • Repository: Predibase_Lorax
  • File: server/lorax_server/utils/ops/sgmv_expand.py
  • Lines: 1-192

Signature

@libentry()
@triton.jit
def _sgmv_expand_kernel(
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    b_seq_start_loc,
    seq_lens,
    lora_indices,
    xm_stride,
    xk_stride,
    l0_stride,
    lora_k_stride,
    lora_n_stride,
    cm_stride,
    cn_stride,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    ADD_INPUTS: tl.constexpr,
    CAST_TYPE: tl.constexpr,
):

@torch.inference_mode()
def sgmv_expand(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    add_inputs: bool = False,
):

Import

from lorax_server.utils.ops.sgmv_expand import sgmv_expand

I/O Contract

Inputs

| Name | Type | Required | Description |
|------|------|----------|-------------|
| inputs | torch.Tensor | Yes | Input tensor of shape (total_tokens, rank), dtype float16/bfloat16/float32. The low-rank intermediate activations from the LoRA A projection, packed across all sequences. |
| lora_b_weights | torch.Tensor | Yes | LoRA B weight matrices of shape (lora_num, hidden_size, rank) or (lora_num, 1, hidden_size, rank), dtype float16/bfloat16. The per-adapter up-projection weights. |
| output_tensor | torch.Tensor | Yes | Output tensor of shape (total_tokens, hidden_size), modified in-place. Must be contiguous. |
| b_seq_start_loc | torch.Tensor | Yes | Tensor of shape (batch_size,) containing cumulative sequence start positions. E.g., for sequence lengths [4, 6], this would be [0, 4]. |
| seq_len_tensor | torch.Tensor | Yes | Tensor of shape (batch_size,) containing the sequence length for each batch entry. |
| lora_indices_tensor | torch.Tensor | Yes | Tensor of shape (batch_size,) mapping each batch entry to its LoRA adapter index. A value of -1 means no LoRA is applied. |
| batches | int | Yes | Number of batch entries (sequences). |
| max_seq_length | int | Yes | Maximum sequence length in the batch, used to size the Triton grid. |
| add_inputs | bool | No | Whether to add results to existing output values or overwrite them (default False, i.e. overwrite). |
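The segment metadata arguments are all derivable from the per-sequence lengths. The sketch below shows one way to compute them with plain Python lists; the helper name `segment_metadata` is hypothetical, and in practice the resulting lists would be converted to tensors on the same device as the inputs.

```python
from itertools import accumulate


def segment_metadata(seq_lens):
    """Derive start offsets and sizes from per-sequence lengths (hypothetical helper)."""
    # Exclusive prefix sum: each sequence starts where the previous one ended.
    starts = list(accumulate(seq_lens, initial=0))[:-1]
    return {
        "b_seq_start_loc": starts,
        "batches": len(seq_lens),
        "max_seq_length": max(seq_lens, default=0),
        "total_tokens": sum(seq_lens),
    }


meta = segment_metadata([4, 6])
```

For sequence lengths [4, 6] this reproduces the start positions [0, 4] given in the table, with 10 packed tokens in total.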

Outputs

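| Name | Type | Description |
|------|------|-------------|
| output_tensor | torch.Tensor | Modified in-place. For each batch entry b and each token m within that entry's segment, the corresponding output row receives the product of the input row and the transpose of lora_b_weights[lora_indices[b]], either overwriting the row or adding to it depending on add_inputs. |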

Usage Examples

# Called internally by LoRA kernel orchestrator (punica.py)
from lorax_server.utils.ops.sgmv_expand import sgmv_expand

# Apply LoRA B (expand) weights during prefill with variable-length sequences
sgmv_expand(
    inputs=lora_a_output,           # (total_tokens, rank)
    lora_b_weights=lora_b_weights,  # (num_loras, hidden_size, rank)
    output_tensor=output,           # (total_tokens, hidden_size), modified in-place
    b_seq_start_loc=seq_starts,     # (batch_size,)
    seq_len_tensor=seq_lens,        # (batch_size,)
    lora_indices_tensor=indices,    # (batch_size,)
    batches=batch_size,
    max_seq_length=max_seq_len,
    add_inputs=True,
)

Related Pages

Page Connections

Principle
Implementation
Heuristic
Environment