Implementation:Predibase Lorax Metadata Kernels
| Knowledge Sources | |
|---|---|
| Domains | GPU_Kernels, Attention |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
A collection of Triton kernels that accelerate metadata processing in the LoRAX inference pipeline: block table format conversion, position and slot ID preparation, slot filtering, and input ID copying.
Description
This module provides a set of Triton kernels that accelerate metadata manipulation operations which would otherwise require expensive CPU-GPU synchronization or slow Python loops. It includes a has_triton() helper that checks for Triton availability (currently CUDA-only due to a ROCm compatibility issue). When Triton is unavailable, the Python wrapper functions fall back to CPU implementations.
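The check-then-fall-back pattern described above can be sketched as follows. This is a minimal illustration, not the module's code: `triton_usable` and `choose_backend` are hypothetical names, and the conditions the real has_triton() tests may differ.

```python
import importlib.util


def triton_usable(cuda_available: bool, is_rocm: bool) -> bool:
    """Hypothetical stand-in for has_triton(); the real helper may differ."""
    # Assumption: Triton is only used on CUDA devices and skipped on ROCm.
    if not cuda_available or is_rocm:
        return False
    # find_spec returns None when the package cannot be located.
    return importlib.util.find_spec("triton") is not None


def choose_backend(cuda_available: bool, is_rocm: bool) -> str:
    """Return the code path a wrapper would take under this sketch."""
    return "triton" if triton_usable(cuda_available, is_rocm) else "cpu"
```

Each wrapper can then branch once on the result and keep the kernel launch and the fallback loop side by side.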
The five Triton kernels and their corresponding Python wrapper functions are:
triton_block_tables_to_ragged: Converts padded 2D block tables (batch x max_blocks) into a ragged/flat 1D format compatible with FlashInfer. For each batch entry, it copies the valid block indices (determined by cumulative sequence lengths) from the padded table into contiguous segments of the ragged output. The CPU fallback iterates with Python loops.
triton_block_tables_to_padded: The inverse operation, converting a ragged 1D block table back into padded 2D format. Each batch entry's contiguous segment is copied into its row of the padded output table.
triton_copy_next_input_ids_inplace: Copies newly generated token IDs into the all_input_ids buffer during autoregressive decoding. Handles variable numbers of accepted tokens per request (for speculative decoding) using cu_accepted_ids for indexing. Includes a decode_mask that prevents writing for requests still in the prefill phase.
triton_prepare_position_slot_ids: Computes position IDs and slot indices for each token in a batch. Position IDs are computed as cache_length + token_offset, and slot indices as slot_start + cache_length + token_offset. These are needed for rotary position embeddings and KV cache slot addressing.
triton_slots_filtering: Filters and rearranges slot indices from a source layout into a destination layout based on start offsets and cumulative slot counts. Used for reorganizing slot assignments during batch processing.
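The two block table conversions can be illustrated with a pure-Python reference that works on lists instead of tensors. This is a sketch under stated assumptions: the number of valid entries in each row is taken to be cache_length + input_length, the function names `to_ragged` and `to_padded` are hypothetical, and the padding value 0 is chosen only for illustration (a kernel writing into a preallocated table would leave the padding untouched).

```python
from typing import List


def to_ragged(block_tables: List[List[int]], seq_lens: List[int]) -> List[int]:
    """Concatenate the first seq_lens[i] entries of every padded row."""
    ragged: List[int] = []
    for row, n in zip(block_tables, seq_lens):
        ragged.extend(row[:n])
    return ragged


def to_padded(ragged: List[int], seq_lens: List[int], max_blocks: int,
              pad: int = 0) -> List[List[int]]:
    """Inverse of to_ragged: split the flat list back into fixed-width rows."""
    padded: List[List[int]] = []
    offset = 0
    for n in seq_lens:
        row = ragged[offset:offset + n]
        padded.append(row + [pad] * (max_blocks - n))
        offset += n
    return padded


# Two requests: cache_length + input_length is 3 and 2 respectively (assumed).
padded_tables = [[7, 8, 9, 0], [3, 4, 0, 0]]
seq_lens = [3, 2]
ragged_tables = to_ragged(padded_tables, seq_lens)
round_trip = to_padded(ragged_tables, seq_lens, max_blocks=4)
```

The round trip only reproduces the padded table when the original padding matches the `pad` value, which is why the example pads with zeros.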
All kernels use a 2D grid: axis 0 tiles over the data dimension (using configurable BLOCK_SIZE), and axis 1 iterates over the batch dimension.
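To make the 2D grid concrete, the sketch below simulates it in plain Python for the position and slot ID computation. It is an illustration of the formulas stated above, not the kernel itself: `prepare_ids_reference` is a hypothetical name, the inner `lane` loop stands in for the vectorized offsets of one Triton program, and the `if` plays the role of the load/store mask.

```python
from typing import List, Tuple


def cdiv(a: int, b: int) -> int:
    """Ceiling division, used to size grid axis 0."""
    return (a + b - 1) // b


def prepare_ids_reference(max_input_length: int, cache_lengths: List[int],
                          cu_seqlen: List[int], cu_slots: List[int],
                          block_size: int = 2) -> Tuple[List[int], List[int]]:
    total_tokens = cu_seqlen[-1]
    position_ids = [0] * total_tokens
    slot_indices = [0] * total_tokens
    for bid in range(len(cache_lengths)):                        # grid axis 1: batch
        for pid in range(cdiv(max_input_length, block_size)):    # grid axis 0: tiles
            seq_start = cu_seqlen[bid]
            seq_len = cu_seqlen[bid + 1] - seq_start
            for lane in range(block_size):
                token_offset = pid * block_size + lane
                if token_offset < seq_len:                       # mask out-of-range lanes
                    position = cache_lengths[bid] + token_offset
                    position_ids[seq_start + token_offset] = position
                    slot_indices[seq_start + token_offset] = cu_slots[bid] + position
    return position_ids, slot_indices


# Request 0: 2 cached tokens, 3 new tokens. Request 1: 0 cached, 2 new tokens.
positions, slots = prepare_ids_reference(
    max_input_length=3,
    cache_lengths=[2, 0],
    cu_seqlen=[0, 3, 5],
    cu_slots=[0, 8, 12],   # assumed: 8 and 4 slots reserved per request
)
```

Sizing axis 0 by the longest sequence means shorter sequences launch some fully masked tiles, which trades a little wasted work for a launch that needs no per-request information on the host.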
Usage
These kernels are invoked during inference batch preparation to avoid CPU-GPU synchronization points. They are called by the model's metadata preparation routines before attention computation to set up block tables (for paged KV cache), position embeddings, slot assignments, and to update input ID buffers after token generation.
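The input ID update after token generation can be sketched in plain Python as below. The details are assumptions made for illustration: new tokens are written starting at cache_length + input_length in each request's row, and the decode mask skips any write whose destination still lies inside the prompt. `copy_next_input_ids_reference` is a hypothetical name.

```python
from typing import List


def copy_next_input_ids_reference(all_input_ids: List[List[int]],
                                  cache_lengths: List[int],
                                  input_lengths: List[int],
                                  prompt_lengths: List[int],
                                  next_input_ids: List[int],
                                  cu_accepted_ids: List[int]) -> None:
    """Write accepted tokens into all_input_ids in place."""
    for bid in range(len(all_input_ids)):
        start = cu_accepted_ids[bid]
        end = cu_accepted_ids[bid + 1]
        base = cache_lengths[bid] + input_lengths[bid]   # assumed write position
        for offset in range(end - start):
            dst = base + offset
            # Assumed decode mask: skip requests that are still prefilling.
            if dst >= prompt_lengths[bid]:
                all_input_ids[bid][dst] = next_input_ids[start + offset]


all_ids = [
    [1, 2, 3, 4, 0, 0, 0, 0],    # request 0: decoding, prompt length 3
    [5, 6, 7, 8, 9, 10, 0, 0],   # request 1: prompt length 6, only 2 tokens processed
]
copy_next_input_ids_reference(
    all_ids,
    cache_lengths=[3, 0],
    input_lengths=[1, 2],
    prompt_lengths=[3, 6],
    next_input_ids=[50, 51, 99],   # request 0 accepted two tokens, request 1 one
    cu_accepted_ids=[0, 2, 3],
)
```

Request 0 receives both accepted tokens, while request 1 is left untouched because its write position falls inside the prompt.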
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/models/metadata_kernels.py
- Lines: 1-329
Signature
def has_triton() -> bool:

def block_tables_to_padded(
    max_blocks: int,
    cu_seqlen: torch.Tensor,
    block_tables: torch.Tensor,
    block_tables_ragged: torch.Tensor,
):

def block_tables_to_ragged(
    *,
    block_tables: torch.Tensor,
    input_lengths: List[int],
    cache_lengths: List[int],
    input_lengths_tensor: torch.Tensor,
    cache_lengths_tensor: torch.Tensor,
    max_current_length: int,
) -> torch.Tensor:

def copy_next_input_ids_inplace(
    max_next_input_ids: int,
    all_input_ids: torch.Tensor,
    cache_lengths: torch.Tensor,
    input_lengths: torch.Tensor,
    prompt_lengths: torch.Tensor,
    next_input_ids: torch.Tensor,
    cu_accepted_ids: torch.Tensor,
):

def prepare_position_slot_ids(
    max_input_length: int,
    cache_lengths: torch.Tensor,
    cu_seqlen: torch.Tensor,
    cu_slots: torch.Tensor,
    position_ids: torch.Tensor,
    slot_indices: torch.Tensor,
):

def slots_filtering(
    max_slots: int,
    slots: torch.Tensor,
    filtered_slots: torch.Tensor,
    cu_slots: torch.Tensor,
    slots_start: torch.Tensor,
):

@triton.jit
def triton_block_tables_to_ragged(cu_seqlen_ptr, block_tables_ptr,
    block_tables_ragged_ptr, stride_block_tables, BLOCK_SIZE: "tl.constexpr"):

@triton.jit
def triton_block_tables_to_padded(cu_seqlen_ptr, block_tables_ptr,
    block_tables_ragged_ptr, stride_block_tables, BLOCK_SIZE: "tl.constexpr"):

@triton.jit
def triton_copy_next_input_ids_inplace(all_input_ids_ptr, cache_lengths_ptr,
    input_lengths_ptr, prompt_lengths_ptr, next_input_ids_ptr,
    cu_accepted_ids_ptr, stride_all_input_ids, BLOCK_SIZE: "tl.constexpr"):

@triton.jit
def triton_prepare_position_slot_ids(cache_lengths_ptr, cu_seqlen_ptr,
    cu_slots_ptr, position_ids_ptr, slot_indices_ptr, BLOCK_SIZE: "tl.constexpr"):

@triton.jit
def triton_slots_filtering(slots_ptr, filtered_slots_ptr, slots_start_ptr,
    cu_slots_ptr, BLOCK_SIZE: "tl.constexpr"):
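The slot filtering operation is the least self-explanatory of the signatures above, so a plain-Python reference may help. It is a sketch under assumptions: `slots_start[i]` is read as the offset of request i's slots in the source layout, and `cu_slots` as the cumulative slot counts of the destination layout. `slots_filtering_reference` is a hypothetical name.

```python
from typing import List


def slots_filtering_reference(slots: List[int], cu_slots: List[int],
                              slots_start: List[int]) -> List[int]:
    """Gather each kept request's slots into a compact destination list."""
    filtered = [0] * cu_slots[-1]
    for bid in range(len(slots_start)):
        dst_start = cu_slots[bid]
        count = cu_slots[bid + 1] - dst_start
        src_start = slots_start[bid]
        for offset in range(count):
            filtered[dst_start + offset] = slots[src_start + offset]
    return filtered


# Source layout: three requests with four slots each.
source_slots = [10, 11, 12, 13, 20, 21, 22, 23, 30, 31, 32, 33]
# Keep request 0 (all four slots) and request 2 (its first two slots).
kept_slots = slots_filtering_reference(
    source_slots,
    cu_slots=[0, 4, 6],
    slots_start=[0, 8],
)
```

In the kernel version, the inner loop would be tiled by BLOCK_SIZE on grid axis 0 and the outer loop would become grid axis 1.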
Import
from lorax_server.models.metadata_kernels import (
    has_triton,
    block_tables_to_padded,
    block_tables_to_ragged,
    copy_next_input_ids_inplace,
    prepare_position_slot_ids,
    slots_filtering,
)
I/O Contract
Inputs (block_tables_to_ragged)
| Name | Type | Required | Description |
|---|---|---|---|
| block_tables | torch.Tensor | Yes | Padded 2D block table of shape (batch, max_blocks), dtype int32. |
| input_lengths | List[int] | Yes | List of input lengths per request. |
| cache_lengths | List[int] | Yes | List of cache lengths per request. |
| input_lengths_tensor | torch.Tensor | Yes | Tensor version of input_lengths for cumulative sum computation. |
| cache_lengths_tensor | torch.Tensor | Yes | Tensor version of cache_lengths. |
| max_current_length | int | Yes | Maximum total length (input + cache) across the batch. |
Inputs (prepare_position_slot_ids)
| Name | Type | Required | Description |
|---|---|---|---|
| max_input_length | int | Yes | Maximum input length, used to size the Triton grid. |
| cache_lengths | torch.Tensor | Yes | Cache lengths per request, used as position offset. |
| cu_seqlen | torch.Tensor | Yes | Cumulative sequence lengths for indexing into packed tensors. |
| cu_slots | torch.Tensor | Yes | Cumulative slot counts for KV cache slot addressing. |
| position_ids | torch.Tensor | Yes | Output tensor for computed position IDs, modified in-place. |
| slot_indices | torch.Tensor | Yes | Output tensor for computed slot indices, modified in-place. |
Outputs
| Name | Type | Description |
|---|---|---|
| block_tables_ragged | torch.Tensor | (from block_tables_to_ragged) Flat 1D ragged block table for FlashInfer compatibility. |
| position_ids | torch.Tensor | (from prepare_position_slot_ids) Computed position IDs for rotary embeddings, modified in-place. |
| slot_indices | torch.Tensor | (from prepare_position_slot_ids) Computed KV cache slot indices, modified in-place. |
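Several inputs in these tables are cumulative lengths. The sketch below shows one common way such arrays are built and read, using plain lists. The leading-zero convention and the helper name `cumulative_with_zero` are assumptions made for illustration. Under this convention, request i owns the half-open range [cu[i], cu[i+1]) of the packed tensor.

```python
from itertools import accumulate
from typing import List


def cumulative_with_zero(lengths: List[int]) -> List[int]:
    """Prefix sums with a leading 0, so len(result) == len(lengths) + 1."""
    return [0] + list(accumulate(lengths))


input_lengths = [3, 2]
cache_lengths = [2, 0]

# Offsets into tensors packed by new input tokens only.
cu_seqlen = cumulative_with_zero(input_lengths)

# Offsets into tensors packed by total length (cache + input).
cu_total = cumulative_with_zero(
    [i + c for i, c in zip(input_lengths, cache_lengths)]
)

# Token range owned by request 1 in the input-packed layout.
request_1_range = (cu_seqlen[1], cu_seqlen[2])
```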
Usage Examples
from lorax_server.models.metadata_kernels import (
    block_tables_to_ragged,
    prepare_position_slot_ids,
)

# Convert padded block tables to ragged format for FlashInfer
ragged_tables = block_tables_to_ragged(
    block_tables=block_tables,
    input_lengths=input_lengths,
    cache_lengths=cache_lengths,
    input_lengths_tensor=input_lengths_tensor,
    cache_lengths_tensor=cache_lengths_tensor,
    max_current_length=max_len,
)

# Prepare position and slot IDs on GPU
prepare_position_slot_ids(
    max_input_length=max_input_len,
    cache_lengths=cache_lengths_tensor,
    cu_seqlen=cu_seqlen,
    cu_slots=cu_slots,
    position_ids=position_ids,
    slot_indices=slot_indices,
)