

Implementation:Predibase Lorax Causal LM

From Leeroopedia


Knowledge Sources
Domains: Model_Architecture, Inference
Last Updated: 2026-02-08 00:00 GMT

Overview

Implements the non-flash causal language model inference wrapper, including padded batch management and autoregressive token generation, serving as the base class for standard (non-flash-attention) decoder-only model inference in the LoRax server.

Description

This module provides the standard (non-paged-attention) causal language model wrapper using HuggingFace AutoModelForCausalLM. It manages padded batching with explicit attention masks and past key-value caching.
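The padded layout can be illustrated without tensors. The sketch below assumes left padding (the usual choice for decoder-only batching) and an attention mask preallocated with `padding_right_offset` spare columns for tokens generated later, matching the allocation described for `from_pb`. Position IDs follow one common convention: the running count of real tokens minus one, with padded slots given a dummy value. `build_padded_batch` is a hypothetical helper, not LoRax code, and the real batch works on torch tensors.

```python
from typing import List, Tuple


def build_padded_batch(
    token_ids: List[List[int]], pad_token_id: int, padding_right_offset: int
) -> Tuple[List[List[int]], List[List[int]], List[List[int]]]:
    """Hypothetical helper: left-pad prompts and build mask and position IDs."""
    max_input_length = max(len(ids) for ids in token_ids)
    input_ids, attention_mask, position_ids = [], [], []
    for ids in token_ids:
        pad = max_input_length - len(ids)
        input_ids.append([pad_token_id] * pad + list(ids))
        # Mask layout: [left padding | prompt | padding_right_offset spare columns]
        row_mask = [0] * pad + [1] * len(ids) + [0] * padding_right_offset
        attention_mask.append(row_mask)
        # Position = (real tokens seen so far) - 1; padded slots get a dummy 1
        positions, seen = [], 0
        for m in row_mask[:max_input_length]:
            seen += m
            positions.append(seen - 1 if m else 1)
        position_ids.append(positions)
    return input_ids, attention_mask, position_ids


ids, mask, pos = build_padded_batch([[5, 6, 7], [8]], pad_token_id=0, padding_right_offset=2)
```

Reserving the spare columns up front lets each decode step switch on one more mask column instead of reallocating the mask.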

Key classes:

  • CausalLMBatch (extends Batch, dataclass) - Manages batch state for padded causal LM inference:
    • Tracks input_ids, attention_mask, position_ids, and past_key_values as core decoder state.
    • Maintains per-request metadata: next_token_choosers, stopping_criterias, prefix_offsets, read_offsets.
    • Includes adapter_meta (AdapterBatchMetadata) for LoRA adapter segment tracking.
    • from_pb - Constructs the batch from protobuf with tokenization, attention mask allocation (padded to max input length + padding right offset), and adapter index setup.
    • filter - Removes completed requests, slicing tensors and past key-value caches. Handles both standard key layout and BLOOM's transposed key layout (keys_head_dim_last flag).
    • concatenate - Merges multiple batches by padding past key-value caches to matching dimensions and concatenating along the batch axis, supporting heterogeneous batch sizes.
  • CausalLM (extends Model) - The model wrapper that:
    • Loads models via AutoModelForCausalLM.from_pretrained with support for bitsandbytes 8-bit quantization and multi-GPU device mapping.
    • Handles pad token ID resolution from model config or special tokens.
    • Disables dynamic adapter loading (dynamic_adapter_loading_enabled = False).
    • forward - Runs model forward pass with optional position IDs and adapter data, returning logits and updated past key values.
    • generate_token - Full autoregressive generation step: runs forward pass, applies token choosers, evaluates stopping criteria, manages prefill logprobs, updates batch state (input IDs, attention mask, position IDs, past key values), and returns generations with sharding support across world size.
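The alignment behind `concatenate` can be sketched on attention masks alone. This is a minimal illustration, assuming each batch's mask is laid out as [left padding | inputs | padding_right_offset spare columns]. Rows from shorter batches are left-padded so that every row shares the largest input length and the largest right offset. `concatenate_masks` is a hypothetical helper; the real method must also bring the past key-value caches to matching dimensions, which this sketch omits.

```python
from typing import List, Tuple

Mask = List[List[int]]


def concatenate_masks(batches: List[Tuple[Mask, int, int]]) -> Tuple[Mask, int, int]:
    """Merge (attention_mask, max_input_length, padding_right_offset) triples."""
    max_input_length = max(b[1] for b in batches)
    padding_right_offset = max(b[2] for b in batches)
    merged: Mask = []
    for mask, input_len, right_offset in batches:
        left_pad = max_input_length - input_len
        for row in mask:
            # Keep only this batch's input region; explicit indices avoid the
            # row[:-0] pitfall when right_offset is 0
            end = len(row) - right_offset
            input_region = row[end - input_len : end]
            merged.append([0] * left_pad + input_region + [0] * padding_right_offset)
    return merged, max_input_length, padding_right_offset


batch_a = ([[1, 1, 1, 0]], 3, 1)
batch_b = ([[0, 1, 0, 0], [1, 1, 0, 0]], 2, 2)
merged, max_len, right = concatenate_masks([batch_a, batch_b])
```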

Usage

CausalLM serves as the base inference wrapper for standard decoder-only models that do not use flash attention. It is used for models loaded through AutoModelForCausalLM and is also subclassed by GalacticaSharded for Galactica-specific behavior. For flash-attention-optimized models, the separate FlashCausalLM class is used instead.

Code Reference

Source Location

  • Repository: Predibase_Lorax
  • File: server/lorax_server/models/causal_lm.py
  • Lines: 1-751

Signature

@dataclass
class CausalLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]
    requests_idx_mapping: Dict[int, int]
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    past_key_values: Optional[List[Tuple]]
    all_input_ids: List[torch.Tensor]
    input_lengths: List[int]
    prefix_offsets: List[int]
    read_offsets: List[int]
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]
    adapter_meta: AdapterBatchMetadata
    max_input_length: int
    padding_right_offset: int
    max_tokens: int
    keys_head_dim_last: bool = True

    @classmethod
    def from_pb(cls, pb, tokenizer, tokenizers, processor, config, dtype, device) -> "CausalLMBatch":
        ...
    def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
        ...
    @classmethod
    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
        ...

class CausalLM(Model):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        compile: bool = False,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        ...
    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        ...
    def forward(self, input_ids, attention_mask, position_ids, past_key_values=None, adapter_data=None) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        ...
    def generate_token(self, batch: CausalLMBatch) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
        ...
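The trimming performed by `filter` can be approximated in plain Python. This sketch assumes the same mask layout as above. After dropping finished requests, it removes left-padding columns that the longest surviving request no longer needs and shrinks the right offset to the most tokens any survivor may still generate. This is a simplification modeled on TGI-style servers, not confirmed LoRax behavior. `ToyBatch`, `filter_batch`, and the `remaining_tokens` field are hypothetical. Slicing of the past key-value caches, including BLOOM's `keys_head_dim_last` layout, is omitted. In this sketch, an empty request list returns `None`.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class ToyBatch:
    """Hypothetical, torch-free stand-in for a padded causal LM batch."""
    requests_idx_mapping: Dict[int, int]  # request id -> row index
    input_lengths: List[int]              # current length of each row
    remaining_tokens: List[int]           # tokens each row may still generate
    attention_mask: List[List[int]]       # [left pad | inputs | right offset]
    max_input_length: int
    padding_right_offset: int


def filter_batch(batch: ToyBatch, request_ids: List[int]) -> Optional[ToyBatch]:
    if not request_ids:
        return None  # nothing survives in this sketch
    keep = [batch.requests_idx_mapping[rid] for rid in request_ids]
    input_lengths = [batch.input_lengths[i] for i in keep]
    remaining = [batch.remaining_tokens[i] for i in keep]
    max_input_length = max(input_lengths)
    new_right = max(remaining)  # spare columns the survivors can still use
    width = len(batch.attention_mask[0])
    start = width - batch.padding_right_offset - max_input_length
    stop = width - batch.padding_right_offset + new_right
    return ToyBatch(
        requests_idx_mapping={rid: pos for pos, rid in enumerate(request_ids)},
        input_lengths=input_lengths,
        remaining_tokens=remaining,
        attention_mask=[batch.attention_mask[i][start:stop] for i in keep],
        max_input_length=max_input_length,
        padding_right_offset=new_right,
    )
```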

Import

from lorax_server.models.causal_lm import CausalLM, CausalLMBatch

I/O Contract

Inputs (CausalLM constructor)

  • model_id (str, required) - HuggingFace model identifier
  • revision (Optional[str], optional) - Model revision or commit hash
  • quantize (Optional[str], optional) - Quantization method (e.g., "bitsandbytes" for 8-bit loading)
  • compile (bool, optional) - Whether to compile the model; compilation is not supported, so a request is logged and skipped
  • dtype (Optional[torch.dtype], optional) - Model dtype; defaults to float16 on GPU and float32 on CPU
  • trust_remote_code (bool, optional) - Whether to trust remote code when loading the model
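The dtype and compile defaults listed above reduce to a couple of branches. This is a minimal sketch that uses dtype names as strings so it runs without torch. `resolve_dtype` and `resolve_compile` are hypothetical helpers, and the log message is illustrative rather than LoRax's actual wording.

```python
import logging
from typing import Optional

logger = logging.getLogger(__name__)


def resolve_dtype(dtype: Optional[str], cuda_available: bool) -> str:
    """Return the explicit dtype, or the documented default for the device."""
    if dtype is not None:
        return dtype
    return "float16" if cuda_available else "float32"


def resolve_compile(compile_requested: bool, model_id: str) -> bool:
    """Compilation is unsupported here: log any request and always skip it."""
    if compile_requested:
        logger.info("Compilation is not supported for %s; skipping.", model_id)
    return False
```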

Outputs (generate_token return value)

  • generations (List[Generation]) - Generated token information for each request in the batch
  • next_batch (Optional[CausalLMBatch]) - Updated batch for the next generation step; None when all requests are complete
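Because `generate_token` returns a list of generations and an optional next batch that becomes `None` once every request is done, a caller can drive generation with a simple loop. The sketch below assumes a single request and replaces the forward pass and token choosers with a trivial rule (next token = previous token + 1). Each non-final step switches on one spare right-hand mask column, which is an assumption modeled on the `padding_right_offset` bookkeeping. `ToyDecodeBatch`, `toy_generate_token`, and `generate_all` are hypothetical names.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class ToyDecodeBatch:
    """Hypothetical single-request batch; not LoRax's CausalLMBatch."""
    all_input_ids: List[int]   # full token history for the request
    attention_mask: List[int]  # [prompt mask | padding_right_offset spare columns]
    padding_right_offset: int
    max_new_tokens: int
    generated: List[int] = field(default_factory=list)


def toy_generate_token(batch: ToyDecodeBatch) -> Tuple[List[int], Optional[ToyDecodeBatch]]:
    # Stand-in for forward pass + token chooser: next token = last token + 1
    next_token = batch.all_input_ids[-1] + 1
    batch.generated.append(next_token)
    generations = [next_token]
    if len(batch.generated) >= batch.max_new_tokens:
        return generations, None  # stopping criterion met: no next batch
    # A real decode step would feed only next_token and reuse past_key_values
    batch.all_input_ids.append(next_token)
    # Switch on the first spare mask column for the token just added
    batch.attention_mask[len(batch.attention_mask) - batch.padding_right_offset] = 1
    batch.padding_right_offset -= 1
    return generations, batch


def generate_all(batch: ToyDecodeBatch) -> List[int]:
    """Call generate_token-style steps until no next batch is returned."""
    outputs: List[int] = []
    next_batch: Optional[ToyDecodeBatch] = batch
    while next_batch is not None:
        generations, next_batch = toy_generate_token(next_batch)
        outputs.extend(generations)
    return outputs


batch = ToyDecodeBatch(all_input_ids=[1, 2], attention_mask=[1, 1, 0, 0, 0],
                       padding_right_offset=3, max_new_tokens=3)
tokens = generate_all(batch)
```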

Usage Examples

# Internal LoRax server usage
import torch
from lorax_server.models.causal_lm import CausalLM

# Normally instantiated by the model registry for non-flash causal models
# model = CausalLM(model_id="gpt2", dtype=torch.float16)
