Implementation:Predibase Lorax Causal LM
| Knowledge Sources | |
|---|---|
| Domains | Model_Architecture, Inference |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Implements the non-flash causal language model inference wrapper, including padded batch management and autoregressive token generation, serving as the base class for standard (non-flash-attention) decoder-only model inference in the LoRax server.
Description
This module provides the standard (non-paged-attention) causal language model wrapper using HuggingFace AutoModelForCausalLM. It manages padded batching with explicit attention masks and past key-value caching.
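The padded-batch layout described above can be sketched in pure Python. This is a minimal sketch under stated assumptions, not LoRax code. It uses nested lists in place of torch tensors and assumes left padding. `build_padded_batch` is a hypothetical helper. It reserves `padding_right_offset` zeroed mask columns for tokens that have not been generated yet. It derives position IDs from a running count of real tokens, and gives padded slots a dummy position of 1, which is also an assumption.

```python
from typing import List, Tuple

Matrix = List[List[int]]


def build_padded_batch(
    token_ids: List[List[int]], pad_token_id: int, padding_right_offset: int
) -> Tuple[Matrix, Matrix, Matrix]:
    """Left-pad prompts; return input ids, attention mask, and position ids."""
    max_input_length = max(len(ids) for ids in token_ids)
    input_ids: Matrix = []
    attention_mask: Matrix = []
    position_ids: Matrix = []
    for ids in token_ids:
        pad = max_input_length - len(ids)
        # Left padding aligns every prompt's last token in the final input column.
        input_ids.append([pad_token_id] * pad + ids)
        # 0 = padding, 1 = real token; the trailing zeros are reserved for new tokens.
        mask = [0] * pad + [1] * len(ids) + [0] * padding_right_offset
        attention_mask.append(mask)
        # Position = (running count of real tokens) - 1; padded slots get a dummy 1.
        positions, seen = [], 0
        for m in mask[:max_input_length]:
            seen += m
            positions.append(seen - 1 if m else 1)
        position_ids.append(positions)
    return input_ids, attention_mask, position_ids
```

Take prompts `[5, 6, 7]` and `[8]` with `padding_right_offset=2`. The second row's mask is `[0, 0, 1, 0, 0]` and its position IDs are `[1, 1, 0]`.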
Key classes:
- CausalLMBatch (extends Batch, dataclass): manages batch state for padded causal LM inference.
  - Tracks `input_ids`, `attention_mask`, `position_ids`, and `past_key_values` as the core decoder state.
  - Maintains per-request metadata: `next_token_choosers`, `stopping_criterias`, `prefix_offsets`, `read_offsets`.
  - Includes `adapter_meta` (AdapterBatchMetadata) for LoRA adapter segment tracking.
  - `from_pb`: constructs the batch from a protobuf request. This covers tokenization, attention mask allocation (padded to the maximum input length plus the padding right offset), and adapter index setup.
  - `filter`: removes completed requests by slicing the tensors and past key-value caches. Handles both the standard key layout and BLOOM's transposed key layout (`keys_head_dim_last` flag).
  - `concatenate`: merges multiple batches, possibly of different sizes. It pads their past key-value caches to matching dimensions and concatenates them along the batch axis.
- CausalLM (extends Model): the model wrapper.
  - Loads models via `AutoModelForCausalLM.from_pretrained`, with support for bitsandbytes 8-bit quantization and multi-GPU device mapping.
  - Resolves the pad token ID from the model config or from special tokens.
  - Disables dynamic adapter loading (`dynamic_adapter_loading_enabled = False`).
  - `forward`: runs the model forward pass with optional position IDs and adapter data, and returns logits and updated past key values.
  - `generate_token`: performs one full autoregressive generation step and supports sharding across the world size. It runs the forward pass, applies the token choosers, evaluates stopping criteria, and handles prefill logprobs. It then updates the batch state (input IDs, attention mask, position IDs, past key values) and returns the resulting generations.
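The interplay of `filter` and `concatenate` can be illustrated with a toy batch. Each request's "cache" is a list of slots, left-padded to a common length. This is a simplified sketch, not LoRax code:

- `ToyCacheBatch`, `make_batch`, `filter_batch`, and `concatenate_batches` are hypothetical names.
- A `None` slot stands in for padding.
- Returning `None` when no requests remain is an assumption that mirrors the `Optional` return type of `filter`.

The real methods operate on per-layer key/value tensors and also slice attention masks and position IDs.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional

PAD = None  # placeholder for a padded (masked-out) cache slot


@dataclass
class ToyCacheBatch:
    request_ids: List[int]
    past: List[List[Optional[str]]]  # one left-padded cache row per request
    requests_idx_mapping: Dict[int, int]  # request id -> row index


def make_batch(request_ids: List[int], past: List[List[Optional[str]]]) -> ToyCacheBatch:
    mapping = {rid: i for i, rid in enumerate(request_ids)}
    return ToyCacheBatch(list(request_ids), past, mapping)


def filter_batch(batch: ToyCacheBatch, keep_ids: List[int]) -> Optional[ToyCacheBatch]:
    """Keep only the listed requests and drop left padding none of them needs."""
    if not keep_ids:
        return None
    rows = [batch.past[batch.requests_idx_mapping[rid]] for rid in keep_ids]
    # Real (non-padding) entries sit at the right end of each row.
    longest = max(sum(slot is not PAD for slot in row) for row in rows)
    rows = [row[len(row) - longest:] for row in rows]
    # Row indices change, so the id -> index mapping is rebuilt.
    return make_batch(keep_ids, rows)


def concatenate_batches(batches: List[ToyCacheBatch]) -> ToyCacheBatch:
    """Left-pad every cache row to the longest row, then stack along the batch axis."""
    target = max(len(row) for b in batches for row in b.past)
    request_ids: List[int] = []
    rows: List[List[Optional[str]]] = []
    for b in batches:
        request_ids.extend(b.request_ids)
        rows.extend([PAD] * (target - len(row)) + row for row in b.past)
    return make_batch(request_ids, rows)
```

Left padding keeps the most recent cache entry of every request in the same (last) column. The new token for every row can therefore be written at one shared position.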
Usage
CausalLM serves as the base inference wrapper for standard decoder-only models that do not use flash attention. It is used for models loaded through AutoModelForCausalLM and is also subclassed by GalacticaSharded for Galactica-specific behavior. For flash-attention-optimized models, the separate FlashCausalLM class is used instead.
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/models/causal_lm.py
- Lines: 1-751
Signature
```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type

import torch

# Batch, Model, Generation, NextTokenChooser, StoppingCriteria, AdapterBatchMetadata
# and generate_pb2 are LoRax-internal types imported from the lorax_server package.


@dataclass
class CausalLMBatch(Batch):
    batch_id: int
    requests: List[generate_pb2.Request]
    requests_idx_mapping: Dict[int, int]
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
    past_key_values: Optional[List[Tuple]]
    all_input_ids: List[torch.Tensor]
    input_lengths: List[int]
    prefix_offsets: List[int]
    read_offsets: List[int]
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]
    adapter_meta: AdapterBatchMetadata
    max_input_length: int
    padding_right_offset: int
    max_tokens: int
    keys_head_dim_last: bool = True

    @classmethod
    def from_pb(cls, pb, tokenizer, tokenizers, processor, config, dtype, device) -> "CausalLMBatch":
        ...

    def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
        ...

    @classmethod
    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
        ...


class CausalLM(Model):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        compile: bool = False,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        ...

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        ...

    def forward(
        self, input_ids, attention_mask, position_ids, past_key_values=None, adapter_data=None
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        ...

    def generate_token(self, batch: CausalLMBatch) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
        ...
```
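The `keys_head_dim_last` flag matters because the sequence axis sits in a different place in the two key layouts. This sketch assumes two layouts:

- The standard layout is `[batch, heads, seq, head_dim]`.
- BLOOM's layout is `[batch * heads, head_dim, seq]`, as in older Hugging Face BLOOM implementations. Treat this as an assumption.

Under those assumptions, a hypothetical `trim_keys` helper over nested lists shows how a key cache is cut to its last `past_kv_length` positions:

```python
from typing import Any, List


def _last(seq: List[Any], n: int) -> List[Any]:
    """Last n items of seq (empty when n == 0, unlike seq[-0:])."""
    return seq[max(0, len(seq) - n):]


def trim_keys(keys: List[Any], past_kv_length: int, keys_head_dim_last: bool = True) -> List[Any]:
    """Keep the last `past_kv_length` sequence positions of one layer's key cache.

    keys_head_dim_last=True:  keys[batch][head][seq][head_dim]  (sequence is the 3rd axis)
    keys_head_dim_last=False: keys[batch * head][head_dim][seq] (sequence is the last axis)
    """
    if keys_head_dim_last:
        # Each head holds a list of per-position vectors; keep the newest positions.
        return [[_last(head, past_kv_length) for head in sample] for sample in keys]
    # BLOOM-style: each head_dim row is itself a sequence, so slice its tail.
    return [[_last(dim_row, past_kv_length) for dim_row in row] for row in keys]
```

The helper guards against `past_kv_length == 0` explicitly because `seq[-0:]` returns the whole list rather than an empty one.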
Import
```python
from lorax_server.models.causal_lm import CausalLM, CausalLMBatch
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model_id | str | Yes | HuggingFace model identifier |
| revision | Optional[str] | No | Model revision/commit hash |
| quantize | Optional[str] | No | Quantization method (e.g., "bitsandbytes") |
| compile | bool | No | Whether to compile the model (not supported by CausalLM; if set, this is logged and compilation is skipped) |
| dtype | Optional[torch.dtype] | No | Model dtype (defaults to float16 on GPU, float32 on CPU) |
| trust_remote_code | bool | No | Whether to trust remote code in model loading |
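The dtype default from the table above and the pad-token fallback mentioned under CausalLM can be sketched as two small functions. Both are hypothetical, and dtypes are plain strings instead of `torch.dtype`. The fallback order is an assumption not confirmed by this page: model config `pad_token_id`, then config `eos_token_id`, then tokenizer `eos_token_id`. The real class may instead register a new special pad token when nothing is found; this sketch signals that case by raising.

```python
from types import SimpleNamespace
from typing import Any, Optional


def resolve_dtype(dtype: Optional[str], cuda_available: bool) -> str:
    """Apply the documented default: float16 on GPU, float32 on CPU."""
    if dtype is not None:
        return dtype
    return "float16" if cuda_available else "float32"


def resolve_pad_token_id(tokenizer: Any, model_config: Any) -> int:
    """Hypothetical fallback chain used when the tokenizer has no pad token."""
    if tokenizer.pad_token_id is not None:
        return tokenizer.pad_token_id
    for candidate in (model_config.pad_token_id, model_config.eos_token_id, tokenizer.eos_token_id):
        if candidate is not None:
            return candidate
    # Assumed last resort: the real code may add a new pad special token here.
    raise ValueError("no pad token id available")


if __name__ == "__main__":
    tok = SimpleNamespace(pad_token_id=None, eos_token_id=2)
    cfg = SimpleNamespace(pad_token_id=None, eos_token_id=50256)
    print(resolve_dtype(None, cuda_available=False), resolve_pad_token_id(tok, cfg))  # float32 50256
```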
Outputs
| Name | Type | Description |
|---|---|---|
| generations | List[Generation] | Generated token information for each request in the batch |
| next_batch | Optional[CausalLMBatch] | Updated batch for next generation step (None if all requests complete) |
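The output contract is a list of generations plus a batch that becomes `None` once every request has finished. This suggests a simple driver loop. The sketch below is a toy, not LoRax code:

- `ToyDecodeBatch`, `generate_token`, `drop_stopped`, and `run_to_completion` are hypothetical.
- Token selection and stopping are stand-ins for NextTokenChooser and StoppingCriteria.
- It assumes finished requests are removed between steps by a filter-like call.

Each step fills the next reserved attention-mask column and advances each request's position by one.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class ToyRequest:
    request_id: int
    max_new_tokens: int
    generated: List[int] = field(default_factory=list)


@dataclass
class ToyDecodeBatch:
    requests: List[ToyRequest]
    attention_mask: List[List[int]]  # prompt columns + padding_right_offset reserved columns
    last_position: List[int]  # position id of each request's latest token
    padding_right_offset: int


Generation = Tuple[int, int, bool]  # (request_id, token_id, stopped)


def generate_token(batch: ToyDecodeBatch) -> Tuple[List[Generation], Optional[ToyDecodeBatch]]:
    """One decode step; returns None instead of a batch once every request has stopped."""
    # The token produced in this step fills the first reserved column on the right.
    col = len(batch.attention_mask[0]) - batch.padding_right_offset
    generations: List[Generation] = []
    for i, req in enumerate(batch.requests):
        token = 100 + batch.last_position[i]  # arbitrary stand-in for NextTokenChooser
        req.generated.append(token)
        stopped = len(req.generated) >= req.max_new_tokens  # stand-in for StoppingCriteria
        generations.append((req.request_id, token, stopped))
        batch.attention_mask[i][col] = 1
        batch.last_position[i] += 1
    batch.padding_right_offset -= 1
    if all(stopped for _, _, stopped in generations):
        return generations, None
    return generations, batch


def drop_stopped(batch: ToyDecodeBatch, generations: List[Generation]) -> ToyDecodeBatch:
    """Between steps, remove finished requests (the role a filter call plays)."""
    keep = [i for i, (_, _, stopped) in enumerate(generations) if not stopped]
    return ToyDecodeBatch(
        [batch.requests[i] for i in keep],
        [batch.attention_mask[i] for i in keep],
        [batch.last_position[i] for i in keep],
        batch.padding_right_offset,
    )


def run_to_completion(batch: Optional[ToyDecodeBatch]) -> int:
    """Drive decode steps until generate_token reports that no batch remains."""
    steps = 0
    while batch is not None:
        generations, batch = generate_token(batch)
        steps += 1
        if batch is not None:
            batch = drop_stopped(batch, generations)
    return steps
```

In this toy, `padding_right_offset` starts at the largest `max_new_tokens` in the batch. This guarantees that every still-running request has a reserved mask column for its next token.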
Usage Examples
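```python
# Internal LoRax server usage: CausalLM is normally instantiated by the model
# registry for non-flash causal models rather than directly by user code.
import torch

from lorax_server.models.causal_lm import CausalLM

# Downloads the weights on first use; float16 is intended for GPU inference.
model = CausalLM(model_id="gpt2", dtype=torch.float16)
```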