Implementation:Predibase Lorax Galactica
| Knowledge Sources | |
|---|---|
| Domains | Model_Architecture, Inference |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Implements the Galactica scientific language model wrapper with custom tokenization handling for special scientific sequences (DNA, SMILES, amino acids), extending the CausalLM base class with sharded OPT model loading and tensor-parallel inference.
Description
This module provides specialized handling for the Galactica model family, which uses custom token sequences for scientific notation. It includes custom tokenization preprocessing and a sharded model loader using the OPT architecture.
Key functions:
- escape_custom_split_sequence - Applies custom splitting to text for Galactica's tokenization, inserting split markers between characters within special sequences like [START_DNA]...[END_DNA], [START_SMILES]...[END_SMILES], and [START_AMINO]...[END_AMINO].
- _insert_split_marker - Regex callback that inserts the SPLIT_MARKER between each character of a matched special sequence, enabling per-character tokenization of scientific notation.
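The escaping step can be sketched as follows. This is a minimal illustration, not the LoRAX source: the regex `CUSTOM_SEQ_RE` and the value assigned to `SPLIT_MARKER` are assumptions chosen for the example, and only the three tag pairs named above are matched. In this sketch a marker is placed before every character of the sequence and once more before the closing tag.

```python
import re

# Hypothetical marker value; the real SPLIT_MARKER string in galactica.py may differ.
SPLIT_MARKER = "<SPLIT>"

# Assumed pattern: opening tag, lazily matched body, matching closing tag.
# The backreference \2 forces [START_DNA] to pair with [END_DNA], and so on.
CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|AMINO)\])(.*?)(\[END_\2\])", re.DOTALL)


def _insert_split_marker(m: re.Match) -> str:
    """Regex callback: prefix every character of the matched body with SPLIT_MARKER."""
    start_token, _, sequence, end_token = m.groups()
    # DOTALL lets "." match newlines too, so no character escapes the split.
    sequence = re.sub(r"(.)", lambda c: SPLIT_MARKER + c.group(1), sequence, flags=re.DOTALL)
    # A final marker separates the last character from the closing tag.
    return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}"


def escape_custom_split_sequence(text: str) -> str:
    """Apply per-character splitting to every special scientific sequence in text."""
    return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)


print(escape_custom_split_sequence("Gene: [START_DNA]ACG[END_DNA]"))
# → Gene: [START_DNA]<SPLIT>A<SPLIT>C<SPLIT>G<SPLIT>[END_DNA]
```

Text outside the tags, and tags whose opening and closing types do not match, pass through unchanged.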
Key classes:
- GalacticaCausalLMBatch (extends CausalLMBatch) - Overrides from_pb to apply escape_custom_split_sequence to each input before tokenization, ensuring that scientific sequence tokens are handled correctly.
- GalacticaSharded (extends CausalLM) - A tensor-parallel model wrapper that:
  - Initializes distributed processing with initialize_torch_distributed.
  - Loads the model configuration with tp_parallel=True and applies quantization settings.
  - Uses the custom OPTForCausalLM implementation (from opt_modeling) instead of HuggingFace's AutoModel, loaded from safetensors weight files.
  - Overrides batch_type to return GalacticaCausalLMBatch.
  - Overrides decode to keep special tokens (it does not skip them), since Galactica uses them for custom parsing rules.
  - Overrides forward to call the model without position IDs (OPT handles positions internally).
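Tensor parallelism, as used by GalacticaSharded, splits each layer's weights across GPUs so that every rank computes part of the output. The sketch below is a generic, pure-Python illustration of a column-parallel split (the weight is sliced along its output dimension and the partial results are concatenated). The helper names are hypothetical, and the sketch does not reproduce LoRAX's sharded layer classes or its distributed communication.

```python
from typing import List

Matrix = List[List[float]]  # row-major, shape [out_features][in_features]


def matvec(weight: Matrix, x: List[float]) -> List[float]:
    """Compute y = W @ x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def column_shard(weight: Matrix, rank: int, world_size: int) -> Matrix:
    """Return the contiguous block of output rows owned by `rank`."""
    out_features = len(weight)
    if out_features % world_size != 0:
        raise ValueError("out_features must be divisible by world_size")
    block = out_features // world_size
    return weight[rank * block:(rank + 1) * block]


def tensor_parallel_matvec(weight: Matrix, x: List[float], world_size: int) -> List[float]:
    """Simulate `world_size` ranks; each multiplies its shard by the full input."""
    partials = [matvec(column_shard(weight, r, world_size), x) for r in range(world_size)]
    # On real GPUs this concatenation would be an all-gather collective.
    return [y for part in partials for y in part]


W = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
x = [1.0, 1.0]
print(tensor_parallel_matvec(W, x, world_size=2))  # → [3.0, 7.0, 11.0, 15.0]
```

Each rank holds only `1 / world_size` of the weight, which is what lets a model too large for one GPU be served across several.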
Usage
GalacticaSharded is instantiated by the LoRAX model registry when loading Galactica models (e.g., "facebook/galactica-6.7b"). The custom tokenization ensures that scientific notation sequences are split correctly for the model's vocabulary, while sharded loading enables multi-GPU inference through tensor parallelism.
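A registry that routes Galactica checkpoints to GalacticaSharded could look like the sketch below. The `MODEL_REGISTRY` table, the `facebook/galactica` prefix rule, the `CausalLM` fallback, and `get_model_class` are all hypothetical and are not LoRAX's actual registry API; class names are returned as strings so the example runs without LoRAX installed.

```python
from typing import Dict

# Hypothetical prefix-to-class table; LoRAX's real registry logic may differ.
MODEL_REGISTRY: Dict[str, str] = {
    "facebook/galactica": "GalacticaSharded",
}

# Hypothetical fallback for model ids that match no prefix.
DEFAULT_MODEL_CLASS = "CausalLM"


def get_model_class(model_id: str) -> str:
    """Return the wrapper class name for a model identifier (sketch)."""
    for prefix, class_name in MODEL_REGISTRY.items():
        if model_id.startswith(prefix):
            return class_name
    return DEFAULT_MODEL_CLASS


print(get_model_class("facebook/galactica-6.7b"))  # → GalacticaSharded
```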
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/models/galactica.py
- Lines: 1-227
Signature
def escape_custom_split_sequence(text):
    ...

class GalacticaCausalLMBatch(CausalLMBatch):
    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        tokenizers: TokenizerManager,
        processor,
        config,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "GalacticaCausalLMBatch":
        ...

class GalacticaSharded(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        compile: bool = False,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        ...

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        ...

    def decode(self, generated_ids: List[int]) -> str:
        ...

    def forward(self, input_ids, attention_mask, position_ids, past_key_values=None):
        ...
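The `forward` override can omit position IDs because OPT derives them from the attention mask. The sketch below mirrors the rule used by Hugging Face's `OPTLearnedPositionalEmbedding` (cumulative sum of the mask, zeroed at padding, minus one, then an offset of 2 into the embedding table); whether LoRAX's custom `opt_modeling` uses exactly this rule is an assumption. It uses plain lists instead of tensors so it runs without PyTorch, and `opt_position_indices` is a hypothetical helper name.

```python
from itertools import accumulate
from typing import List

OPT_POSITION_OFFSET = 2  # HF OPT reserves the first two rows of the position table


def opt_position_indices(attention_mask: List[List[int]], past_key_values_length: int = 0) -> List[List[int]]:
    """Derive OPT position-embedding indices from a (batch, seq) 0/1 attention mask."""
    result = []
    for row in attention_mask:
        # Running count of real tokens, zeroed at padding, then shifted by -1:
        # real tokens get 0, 1, 2, ... and padding gets -1.
        positions = [count * m - 1 for count, m in zip(accumulate(row), row)]
        # During incremental decoding only the positions of the new tokens are needed.
        positions = positions[past_key_values_length:]
        result.append([p + OPT_POSITION_OFFSET for p in positions])
    return result


# A left-padded row and an unpadded row: both real sequences start at index 2.
print(opt_position_indices([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]]))
# → [[1, 1, 2, 3, 4], [2, 3, 4, 5, 6]]
```

Because the first real token of every row maps to the same index regardless of left padding, the server has no need to pass explicit `position_ids`.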
Import
from lorax_server.models.galactica import GalacticaSharded
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model_id | str | Yes | HuggingFace model identifier for a Galactica model |
| revision | Optional[str] | No | Model revision/commit hash |
| quantize | Optional[str] | No | Quantization method (e.g., "bitsandbytes") |
| compile | bool | No | Whether to compile the model (not supported; a message is logged and compilation is skipped) |
| dtype | Optional[torch.dtype] | No | Model dtype (defaults to float16 on GPU, float32 on CPU) |
| trust_remote_code | bool | No | Whether to trust remote code in model loading |
Outputs
| Name | Type | Description |
|---|---|---|
| logits | torch.Tensor | Next-token logits over the vocabulary |
| past_key_values | List[Tuple[torch.Tensor, torch.Tensor]] | Updated KV cache for autoregressive decoding |
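The two outputs are consumed together during autoregressive decoding: the logits choose the next token, and the returned cache is passed back so that later steps only process the newest token. The sketch below shows that loop with greedy selection. `toy_forward`, its scoring rule, and the list-of-ids cache are hypothetical stand-ins (a real cache holds per-layer key/value tensors), not GalacticaSharded's actual behavior.

```python
from typing import List, Optional, Tuple

VOCAB_SIZE = 5
# Simplification: a real cache is a per-layer list of (key, value) tensors;
# here it is just the list of token ids processed so far.
Cache = List[int]


def toy_forward(input_ids: List[int], past_key_values: Optional[Cache]) -> Tuple[List[float], Cache]:
    """Hypothetical stand-in for a forward pass: returns (next-token logits, updated cache)."""
    cache = (past_key_values or []) + input_ids
    # Toy rule: the token after the most recent one (mod VOCAB_SIZE) gets the highest score.
    favoured = (cache[-1] + 1) % VOCAB_SIZE
    logits = [1.0 if tok == favoured else 0.0 for tok in range(VOCAB_SIZE)]
    return logits, cache


def greedy_generate(prompt: List[int], max_new_tokens: int) -> List[int]:
    """Greedy decoding: prefill once, then feed one new token per step with the cache."""
    if max_new_tokens <= 0:
        return []
    logits, past = toy_forward(prompt, None)  # prefill over the whole (non-empty) prompt
    generated: List[int] = []
    while True:
        next_id = max(range(VOCAB_SIZE), key=lambda t: logits[t])  # argmax over the vocabulary
        generated.append(next_id)
        if len(generated) == max_new_tokens:
            return generated
        # Only the newest token is passed in; earlier context comes from the cache.
        logits, past = toy_forward([next_id], past)


print(greedy_generate([0, 1], max_new_tokens=3))  # → [2, 3, 4]
```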
Usage Examples
# Internal LoRAX server usage
from lorax_server.models.galactica import GalacticaSharded
# Instantiated by model registry for Galactica models
# model = GalacticaSharded(model_id="facebook/galactica-6.7b")
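The reason the `decode` override keeps special tokens can be seen with a toy tokenizer. Everything in this sketch (`VOCAB`, `SPECIAL_IDS`, `toy_decode`) is hypothetical and only mimics the `skip_special_tokens` flag of Hugging Face tokenizers; the tags are flagged as special here by assumption, and this is not Galactica's vocabulary.

```python
from typing import Dict, List, Set

# Toy vocabulary for illustration only; the tags are treated as special by assumption.
VOCAB: Dict[int, str] = {0: "[START_DNA]", 1: "A", 2: "C", 3: "G", 4: "[END_DNA]"}
SPECIAL_IDS: Set[int] = {0, 4}


def toy_decode(ids: List[int], skip_special_tokens: bool) -> str:
    """Join token strings, dropping ids in SPECIAL_IDS when skip_special_tokens is True."""
    return "".join(
        VOCAB[i] for i in ids if not (skip_special_tokens and i in SPECIAL_IDS)
    )


ids = [0, 1, 2, 3, 4]
print(toy_decode(ids, skip_special_tokens=True))   # → ACG
print(toy_decode(ids, skip_special_tokens=False))  # → [START_DNA]ACG[END_DNA]
```

With skipping enabled, the tags that delimit the sequence disappear, so any downstream parsing that relies on them would no longer find them.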