
Implementation:Predibase Lorax Metadata Kernels

From Leeroopedia


Knowledge Sources
Domains: GPU_Kernels, Attention
Last Updated: 2026-02-08 00:00 GMT

Overview

Collection of Triton kernels for GPU-accelerated metadata processing operations including block table format conversion, position/slot ID preparation, slot filtering, and input ID copying used in the LoRaX inference pipeline.

Description

This module provides a set of Triton kernels that accelerate metadata manipulation which would otherwise require costly CPU-GPU synchronization or slow Python loops. A has_triton() helper reports whether Triton is available (currently CUDA only, because of a ROCm compatibility issue) so that callers can fall back to non-Triton code paths when it is not. Within the module itself, block_tables_to_ragged carries its own Python-loop fallback.
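As a rough illustration of the gating described above, the sketch below mimics the availability check in plain Python. It is an assumption-laden stand-in, not the module's code: the hypothetical `has_triton_sketch` receives the backend as a `system` string instead of detecting it, approximates Triton availability with `importlib.util.find_spec("triton")`, and memoises the answer with `functools.lru_cache` so the check runs once per backend.

```python
import functools
import importlib.util


@functools.lru_cache(maxsize=None)
def has_triton_sketch(system: str) -> bool:
    """Hypothetical stand-in for has_triton(); `system` is an assumed input."""
    # Only CUDA is accepted; ROCm and other backends report False.
    if system != "cuda":
        return False
    # Approximation: treat "triton package is importable" as "Triton available".
    return importlib.util.find_spec("triton") is not None


print(has_triton_sketch("rocm"))  # False
```

Caching matters because the check is consulted on every batch preparation, while its answer cannot change during the life of the process.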

The five Triton kernels and their corresponding Python wrapper functions are:

triton_block_tables_to_ragged: Converts a padded 2D block table (batch x max_blocks) into the flat 1D ragged format expected by FlashInfer. For each request, it copies the first cache_length + input_length entries of that request's row into a contiguous segment of the ragged output, with segment boundaries given by cumulative sequence lengths. When Triton is unavailable, the block_tables_to_ragged wrapper performs the same copy with a Python loop instead.
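The indexing can be checked with a pure-Python model. `block_tables_to_ragged_ref` is a hypothetical helper that works on nested lists rather than int32 CUDA tensors, but it uses the same cumulative-length bookkeeping: request b's segment spans `cu_seqlen[b]` to `cu_seqlen[b + 1]`, and each request contributes `cache_length + input_length` entries.

```python
from itertools import accumulate
from typing import List


def block_tables_to_ragged_ref(
    block_tables: List[List[int]],
    input_lengths: List[int],
    cache_lengths: List[int],
) -> List[int]:
    """Pure-Python model of the padded -> ragged conversion (hypothetical)."""
    # Each request owns cache_length + input_length valid entries of its row.
    seq_lens = [i + c for i, c in zip(input_lengths, cache_lengths)]
    # cu_seqlen[b] is where request b's segment starts in the flat output.
    cu_seqlen = [0, *accumulate(seq_lens)]
    ragged = [0] * cu_seqlen[-1]
    for b, row in enumerate(block_tables):
        start, end = cu_seqlen[b], cu_seqlen[b + 1]
        # Copy only the valid prefix of the padded row; padding is dropped.
        ragged[start:end] = row[: end - start]
    return ragged


padded = [
    [7, 8, 9, 0],  # request 0: 1 cached + 2 new tokens -> 3 valid entries
    [4, 5, 0, 0],  # request 1: 2 new tokens, nothing cached
]
print(block_tables_to_ragged_ref(padded, input_lengths=[2, 2], cache_lengths=[1, 0]))
# [7, 8, 9, 4, 5]
```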

triton_block_tables_to_padded: The inverse operation, converting a ragged 1D block table back into padded 2D format. Each batch entry's contiguous segment is copied into its row of the padded output table.
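A matching pure-Python sketch of the inverse, under the same assumptions (hypothetical `block_tables_to_padded_ref`, nested lists instead of tensors). It takes `cu_seqlen` directly, as the wrapper does. One simplification: the real wrapper writes into a caller-provided `block_tables` tensor and leaves entries past each segment untouched, whereas this sketch starts from a zero-filled table.

```python
from typing import List


def block_tables_to_padded_ref(
    max_blocks: int, cu_seqlen: List[int], ragged: List[int]
) -> List[List[int]]:
    """Pure-Python model of the ragged -> padded conversion (hypothetical)."""
    batch = len(cu_seqlen) - 1
    # Zero-filled (batch, max_blocks) table; the real code reuses an existing tensor.
    padded = [[0] * max_blocks for _ in range(batch)]
    for b in range(batch):
        start, end = cu_seqlen[b], cu_seqlen[b + 1]
        # Request b's contiguous segment becomes the prefix of row b.
        padded[b][: end - start] = ragged[start:end]
    return padded


print(block_tables_to_padded_ref(4, [0, 3, 5], [7, 8, 9, 4, 5]))
# [[7, 8, 9, 0], [4, 5, 0, 0]]
```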

triton_copy_next_input_ids_inplace: Copies newly generated token IDs into the all_input_ids buffer during autoregressive decoding. Handles variable numbers of accepted tokens per request (for speculative decoding) using cu_accepted_ids for indexing. Includes a decode_mask that prevents writing for requests still in the prefill phase.
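The copy rule can be modelled in plain Python. `copy_next_input_ids_ref` is hypothetical and mutates nested lists rather than the 2D `all_input_ids` tensor. It assumes that request b's accepted tokens are `next_input_ids[cu_accepted_ids[b]:cu_accepted_ids[b + 1]]`, that they are written starting at column `cache_length + input_length` of row b, and that the decode mask skips any column below `prompt_length`.

```python
from typing import List


def copy_next_input_ids_ref(
    all_input_ids: List[List[int]],
    cache_lengths: List[int],
    input_lengths: List[int],
    prompt_lengths: List[int],
    next_input_ids: List[int],
    cu_accepted_ids: List[int],
) -> None:
    """Pure-Python model of the in-place next-token copy (hypothetical)."""
    for b, row in enumerate(all_input_ids):
        # New tokens go right after everything this request has already processed.
        write_pos = cache_lengths[b] + input_lengths[b]
        start, end = cu_accepted_ids[b], cu_accepted_ids[b + 1]
        for j in range(end - start):
            # decode_mask: skip columns that still belong to the prompt (prefill).
            if write_pos + j >= prompt_lengths[b]:
                row[write_pos + j] = next_input_ids[start + j]


rows = [[1, 2, 3, 0, 0, 0], [4, 5, 6, 7, 8, 0]]
copy_next_input_ids_ref(
    rows,
    cache_lengths=[2, 0],
    input_lengths=[1, 2],
    prompt_lengths=[3, 5],
    next_input_ids=[10, 11, 12],
    cu_accepted_ids=[0, 2, 3],  # request 0 accepted 2 tokens, request 1 one
)
print(rows)
# [[1, 2, 3, 10, 11, 0], [4, 5, 6, 7, 8, 0]]
```

Request 0 has consumed its whole 3-token prompt, so both accepted tokens land in columns 3 and 4; request 1 has only processed 2 of its 5 prompt tokens, so the mask suppresses its write.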

triton_prepare_position_slot_ids: Computes position IDs and slot indices for each token in a batch. Position IDs are computed as cache_length + token_offset, and slot indices as slot_start + cache_length + token_offset. These are needed for rotary position embeddings and KV cache slot addressing.
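The two formulas can be checked with a small pure-Python model. `prepare_position_slot_ids_ref` is a hypothetical helper that returns lists instead of writing into preallocated tensors; it assumes `cu_seqlen` holds cumulative input lengths (with a leading 0) and that `cu_slots[b]` is the first slot allocated to request b.

```python
from typing import List, Tuple


def prepare_position_slot_ids_ref(
    cache_lengths: List[int], cu_seqlen: List[int], cu_slots: List[int]
) -> Tuple[List[int], List[int]]:
    """Pure-Python model of position / slot ID preparation (hypothetical)."""
    total = cu_seqlen[-1]
    position_ids = [0] * total
    slot_indices = [0] * total
    for b, cache_length in enumerate(cache_lengths):
        start, end = cu_seqlen[b], cu_seqlen[b + 1]
        for offset in range(end - start):
            # Position continues after the tokens already in the KV cache.
            position_ids[start + offset] = cache_length + offset
            # Slot = request's first slot + cache_length + token offset.
            slot_indices[start + offset] = cu_slots[b] + cache_length + offset
    return position_ids, slot_indices


print(prepare_position_slot_ids_ref([0, 4], [0, 3, 4], [0, 8, 16]))
# ([0, 1, 2, 4], [0, 1, 2, 12])
```

In the example, the second request has 4 cached tokens, so its single new token gets position 4 and slot 8 + 4 + 0 = 12.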

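triton_slots_filtering: Gathers slot indices from a source layout into a compacted destination layout. For each request b, it copies cu_slots[b + 1] - cu_slots[b] consecutive entries of slots, starting at slots_start[b], into filtered_slots starting at cu_slots[b]. This is used to reorganize slot assignments during batch processing, for example when a batch is filtered down to a subset of its requests.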
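A pure-Python sketch of the same gather, with the hypothetical name `slots_filtering_ref`. It assumes request b copies `cu_slots[b + 1] - cu_slots[b]` consecutive slots starting at `slots_start[b]`, and it returns a new list where the real wrapper fills a preallocated `filtered_slots` tensor. The example keeps the first and third of three requests.

```python
from typing import List


def slots_filtering_ref(
    slots: List[int], slots_start: List[int], cu_slots: List[int]
) -> List[int]:
    """Pure-Python model of slot filtering (hypothetical)."""
    filtered = [0] * cu_slots[-1]
    for b, src in enumerate(slots_start):
        count = cu_slots[b + 1] - cu_slots[b]
        # Source offset comes from slots_start, destination from cu_slots.
        filtered[cu_slots[b] : cu_slots[b + 1]] = slots[src : src + count]
    return filtered


# Original layout: request 0 -> [10, 11, 12], request 1 -> [20, 21], request 2 -> [30..33]
slots = [10, 11, 12, 20, 21, 30, 31, 32, 33]
print(slots_filtering_ref(slots, slots_start=[0, 5], cu_slots=[0, 3, 7]))
# [10, 11, 12, 30, 31, 32, 33]
```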

All kernels use a 2D grid: axis 0 tiles over the data dimension (using configurable BLOCK_SIZE), and axis 1 iterates over the batch dimension.
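The launch pattern can be emulated on the host to see why one grid sized by the longest request is enough. `cdiv` and `covered_offsets` are hypothetical helpers for this sketch: the grid is assumed to be `(cdiv(max_len, BLOCK_SIZE), batch)`, each program instance handles one BLOCK_SIZE tile of one request, and lanes past that request's length are masked off.

```python
from typing import List


def cdiv(a: int, b: int) -> int:
    # Ceiling division, playing the role of triton.cdiv when sizing the grid.
    return (a + b - 1) // b


def covered_offsets(seq_lens: List[int], block_size: int) -> List[List[int]]:
    """Emulate the 2D launch and record which offsets each request writes."""
    grid = (cdiv(max(seq_lens), block_size), len(seq_lens))
    covered: List[List[int]] = [[] for _ in seq_lens]
    for pid in range(grid[0]):  # axis 0: tiles over the data dimension
        for bid in range(grid[1]):  # axis 1: one request per program column
            block_arange = range(pid * block_size, (pid + 1) * block_size)
            # Masked lanes (offset >= request length) do nothing.
            covered[bid].extend(o for o in block_arange if o < seq_lens[bid])
    return covered


print(covered_offsets([5, 2], block_size=4))
# [[0, 1, 2, 3, 4], [0, 1]]
```

With lengths [5, 2] and BLOCK_SIZE 4 the grid is (2, 2); the program for tile 1 of the second request is fully masked, yet every valid offset is written exactly once. The idle programs are the cost of launching a single kernel for the whole batch.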

Usage

These kernels are invoked during inference batch preparation to avoid CPU-GPU synchronization points. They are called by the model's metadata preparation routines before attention computation to set up block tables (for paged KV cache), position embeddings, slot assignments, and to update input ID buffers after token generation.

Code Reference

Source Location

  • Repository: Predibase_Lorax
  • File: server/lorax_server/models/metadata_kernels.py
  • Lines: 1-329

Signature

import torch
import triton
import triton.language as tl
from typing import List

def has_triton() -> bool: ...

def block_tables_to_padded(
    max_blocks: int,
    cu_seqlen: torch.Tensor,
    block_tables: torch.Tensor,
    block_tables_ragged: torch.Tensor,
): ...

def block_tables_to_ragged(
    *,
    block_tables: torch.Tensor,
    input_lengths: List[int],
    cache_lengths: List[int],
    input_lengths_tensor: torch.Tensor,
    cache_lengths_tensor: torch.Tensor,
    max_current_length: int,
) -> torch.Tensor: ...

def copy_next_input_ids_inplace(
    max_next_input_ids: int,
    all_input_ids: torch.Tensor,
    cache_lengths: torch.Tensor,
    input_lengths: torch.Tensor,
    prompt_lengths: torch.Tensor,
    next_input_ids: torch.Tensor,
    cu_accepted_ids: torch.Tensor,
): ...

def prepare_position_slot_ids(
    max_input_length: int,
    cache_lengths: torch.Tensor,
    cu_seqlen: torch.Tensor,
    cu_slots: torch.Tensor,
    position_ids: torch.Tensor,
    slot_indices: torch.Tensor,
): ...

def slots_filtering(
    max_slots: int,
    slots: torch.Tensor,
    filtered_slots: torch.Tensor,
    cu_slots: torch.Tensor,
    slots_start: torch.Tensor,
): ...

@triton.jit
def triton_block_tables_to_ragged(cu_seqlen_ptr, block_tables_ptr,
    block_tables_ragged_ptr, stride_block_tables, BLOCK_SIZE: "tl.constexpr"): ...

@triton.jit
def triton_block_tables_to_padded(cu_seqlen_ptr, block_tables_ptr,
    block_tables_ragged_ptr, stride_block_tables, BLOCK_SIZE: "tl.constexpr"): ...

@triton.jit
def triton_copy_next_input_ids_inplace(all_input_ids_ptr, cache_lengths_ptr,
    input_lengths_ptr, prompt_lengths_ptr, next_input_ids_ptr,
    cu_accepted_ids_ptr, stride_all_input_ids, BLOCK_SIZE: "tl.constexpr"): ...

@triton.jit
def triton_prepare_position_slot_ids(cache_lengths_ptr, cu_seqlen_ptr,
    cu_slots_ptr, position_ids_ptr, slot_indices_ptr, BLOCK_SIZE: "tl.constexpr"): ...

@triton.jit
def triton_slots_filtering(slots_ptr, filtered_slots_ptr, slots_start_ptr,
    cu_slots_ptr, BLOCK_SIZE: "tl.constexpr"): ...

Import

from lorax_server.models.metadata_kernels import (
    has_triton,
    block_tables_to_padded,
    block_tables_to_ragged,
    copy_next_input_ids_inplace,
    prepare_position_slot_ids,
    slots_filtering,
)

I/O Contract

Inputs (block_tables_to_ragged)

| Name | Type | Required | Description |
|---|---|---|---|
| block_tables | torch.Tensor | Yes | Padded 2D block table of shape (batch, max_blocks), dtype int32. |
| input_lengths | List[int] | Yes | Input length of each request. |
| cache_lengths | List[int] | Yes | Cache length of each request. |
| input_lengths_tensor | torch.Tensor | Yes | Tensor version of input_lengths; added to cache_lengths_tensor and cumulatively summed to locate each request's segment in the ragged output. |
| cache_lengths_tensor | torch.Tensor | Yes | Tensor version of cache_lengths. |
| max_current_length | int | Yes | Maximum total length (input + cache) across the batch, used to size the Triton grid. |

Inputs (prepare_position_slot_ids)

| Name | Type | Required | Description |
|---|---|---|---|
| max_input_length | int | Yes | Maximum input length in the batch, used to size the Triton grid. |
| cache_lengths | torch.Tensor | Yes | Cache length of each request, used as the position offset. |
| cu_seqlen | torch.Tensor | Yes | Cumulative input lengths with a leading 0 (length batch + 1); entry b is where request b's tokens start in the packed position_ids and slot_indices. |
| cu_slots | torch.Tensor | Yes | Cumulative slot counts; entry b is the first KV cache slot allocated to request b. |
| position_ids | torch.Tensor | Yes | Output tensor for computed position IDs, modified in place. |
| slot_indices | torch.Tensor | Yes | Output tensor for computed slot indices, modified in place. |

Outputs

| Name | Type | Description |
|---|---|---|
| block_tables_ragged | torch.Tensor | Returned by block_tables_to_ragged: flat 1D ragged block table for FlashInfer compatibility. |
| position_ids | torch.Tensor | Written by prepare_position_slot_ids: position IDs for rotary embeddings, modified in place. |
| slot_indices | torch.Tensor | Written by prepare_position_slot_ids: KV cache slot indices, modified in place. |

Usage Examples

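import torch

from lorax_server.models.metadata_kernels import (
    block_tables_to_ragged,
    prepare_position_slot_ids,
)

device = torch.device("cuda")
# Illustrative batch: a fresh 3-token prompt, then 1 new token after 4 cached.
input_lengths = [3, 1]
cache_lengths = [0, 4]
input_lengths_tensor = torch.tensor(input_lengths, dtype=torch.int32, device=device)
cache_lengths_tensor = torch.tensor(cache_lengths, dtype=torch.int32, device=device)
max_len = max(i + c for i, c in zip(input_lengths, cache_lengths))
max_input_len = max(input_lengths)

# Padded block table of shape (batch, max_blocks), dtype int32 (illustrative contents)
block_tables = torch.arange(16, dtype=torch.int32, device=device).view(2, 8)

# Convert padded block tables to ragged format for FlashInfer
ragged_tables = block_tables_to_ragged(
    block_tables=block_tables,
    input_lengths=input_lengths,
    cache_lengths=cache_lengths,
    input_lengths_tensor=input_lengths_tensor,
    cache_lengths_tensor=cache_lengths_tensor,
    max_current_length=max_len,
)

# Cumulative input lengths and slot offsets, each with a leading 0 (illustrative slot layout)
cu_seqlen = torch.zeros(len(input_lengths) + 1, dtype=torch.int32, device=device)
cu_seqlen[1:] = torch.cumsum(input_lengths_tensor, dim=0)
cu_slots = torch.tensor([0, 8, 16], dtype=torch.int32, device=device)

# Preallocated outputs, one entry per new token; both are filled in place
position_ids = torch.empty(sum(input_lengths), dtype=torch.int32, device=device)
slot_indices = torch.empty_like(position_ids)

# Prepare position and slot IDs on GPU
prepare_position_slot_ids(
    max_input_length=max_input_len,
    cache_lengths=cache_lengths_tensor,
    cu_seqlen=cu_seqlen,
    cu_slots=cu_slots,
    position_ids=position_ids,
    slot_indices=slot_indices,
)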
