Implementation:Predibase Lorax Galactica

From Leeroopedia


Knowledge Sources
Domains: Model_Architecture, Inference
Last Updated: 2026-02-08 00:00 GMT

Overview

Implements a wrapper for the Galactica scientific language model. It adds custom tokenization for special scientific sequences (DNA, SMILES, amino acids) and extends the CausalLM base class with sharded OPT model loading for tensor-parallel inference.

Description

This module provides specialized handling for the Galactica model family, which uses custom token sequences for scientific notation. It includes custom tokenization preprocessing and a sharded model loader using the OPT architecture.

Key functions:

  • escape_custom_split_sequence - Applies custom splitting to text for Galactica's tokenization, inserting split markers between characters within special sequences like [START_DNA]...[END_DNA], [START_SMILES]...[END_SMILES], and [START_AMINO]...[END_AMINO].
  • _insert_split_marker - Regex callback that inserts the SPLIT_MARKER between each character of a matched special sequence, enabling per-character tokenization of scientific notation.
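The interaction between these two functions can be sketched as follows. This is a minimal illustration, not the module's code: the marker value "<SPLIT>" and the exact regular expression are assumptions chosen for readability, and the real SPLIT_MARKER constant in galactica.py may differ. The sketch places the marker before every character of the tagged body and once more before the closing tag, so each character can be tokenized on its own.

```python
import re

# Hypothetical placeholder; the real SPLIT_MARKER value in galactica.py may differ.
SPLIT_MARKER = "<SPLIT>"

# Assumed pattern: a start tag, the shortest possible body, and the matching end tag.
CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|AMINO)\])(.*?)(\[END_\2\])")


def _insert_split_marker(match: re.Match) -> str:
    """Regex callback: put SPLIT_MARKER before every character of the body."""
    start_token, _kind, sequence, end_token = match.groups()
    marked = "".join(SPLIT_MARKER + char for char in sequence)
    return f"{start_token}{marked}{SPLIT_MARKER}{end_token}"


def escape_custom_split_sequence(text: str) -> str:
    """Rewrite only the tagged spans; all other text is left untouched."""
    return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)


print(escape_custom_split_sequence("Seq: [START_DNA]ACG[END_DNA]"))
# prints: Seq: [START_DNA]<SPLIT>A<SPLIT>C<SPLIT>G<SPLIT>[END_DNA]
```

Text outside a tagged span, or a span whose start and end tags do not match, passes through unchanged.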

Key classes:

  • GalacticaCausalLMBatch (extends CausalLMBatch) - Overrides from_pb to apply escape_custom_split_sequence to each input before tokenization, ensuring proper handling of scientific sequence tokens.
  • GalacticaSharded (extends CausalLM) - A tensor-parallel model wrapper that:
    • Initializes distributed processing with initialize_torch_distributed.
    • Loads the model configuration with tp_parallel=True and applies quantization settings.
    • Uses the custom OPTForCausalLM implementation (from opt_modeling) instead of HuggingFace's AutoModel, loaded from safetensors weight files.
    • Overrides batch_type to return GalacticaCausalLMBatch.
    • Overrides decode to keep special tokens (does not skip them) since Galactica uses them for custom parsing rules.
    • Overrides forward to call the model without position IDs (OPT handles positions internally).
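The three overrides listed for GalacticaSharded can be illustrated with stand-in classes. Everything in this sketch is hypothetical (StubTokenizer, stub_model, BaseCausalLM, GalacticaLike and the batch classes are invented for illustration and are not LoRAX classes); it only shows the override pattern: a different batch type, decoding that keeps special tokens, and a forward pass that accepts position_ids but does not pass them to the model.

```python
from typing import Callable, List, Type


class StubTokenizer:
    """Hypothetical stand-in for a HuggingFace tokenizer."""

    vocab = {0: "</s>", 1: "[START_DNA]", 2: "A", 3: "C", 4: "[END_DNA]"}
    special_ids = {0, 1, 4}

    def decode(self, ids: List[int], skip_special_tokens: bool = True) -> str:
        return "".join(
            self.vocab[i]
            for i in ids
            if not (skip_special_tokens and i in self.special_ids)
        )


def stub_model(**kwargs) -> List[str]:
    """Hypothetical model: reports which keyword arguments it received."""
    return sorted(kwargs)


class BaseBatch:
    """Stand-in for the generic batch class."""


class GalacticaLikeBatch(BaseBatch):
    """Stand-in for a batch class that escapes inputs before tokenizing."""


class BaseCausalLM:
    def __init__(self, tokenizer: StubTokenizer, model: Callable):
        self.tokenizer = tokenizer
        self.model = model

    @property
    def batch_type(self) -> Type[BaseBatch]:
        return BaseBatch

    def decode(self, generated_ids: List[int]) -> str:
        return self.tokenizer.decode(generated_ids, skip_special_tokens=True)

    def forward(self, input_ids, attention_mask, position_ids, past_key_values=None):
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
        )


class GalacticaLike(BaseCausalLM):
    @property
    def batch_type(self) -> Type[BaseBatch]:
        return GalacticaLikeBatch

    def decode(self, generated_ids: List[int]) -> str:
        # Keep special tokens: they carry meaning for downstream parsing.
        return self.tokenizer.decode(generated_ids, skip_special_tokens=False)

    def forward(self, input_ids, attention_mask, position_ids, past_key_values=None):
        # position_ids is accepted for interface compatibility but not forwarded.
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
        )
```

With these stand-ins, GalacticaLike(StubTokenizer(), stub_model).decode([1, 2, 3, 4]) keeps the [START_DNA] and [END_DNA] tags, while the base class would return only the characters between them.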

Usage

GalacticaSharded is instantiated by the LoRAX model registry when loading Galactica models (e.g., "facebook/galactica-6.7b"). The custom tokenization ensures that scientific notation sequences are split correctly for the model's vocabulary, while the sharded loading enables multi-GPU inference through tensor parallelism.
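As a rough illustration of registry dispatch, the sketch below routes a model identifier to a Galactica-specific constructor. The function names (get_model, build_galactica, build_default) and the prefix-matching rule are assumptions made for this example; the actual selection logic in the registry is not shown on this page and may differ.

```python
from typing import Optional, Tuple

ModelSpec = Tuple[str, str, Optional[str]]


def build_galactica(model_id: str, revision: Optional[str]) -> ModelSpec:
    """Hypothetical stand-in for constructing GalacticaSharded."""
    return ("GalacticaSharded", model_id, revision)


def build_default(model_id: str, revision: Optional[str]) -> ModelSpec:
    """Hypothetical stand-in for the generic causal LM path."""
    return ("CausalLM", model_id, revision)


def get_model(model_id: str, revision: Optional[str] = None) -> ModelSpec:
    # Assumed rule: Galactica checkpoints are recognized by their id prefix.
    if model_id.startswith("facebook/galactica"):
        return build_galactica(model_id, revision)
    return build_default(model_id, revision)
```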

Code Reference

Source Location

  • Repository: Predibase_Lorax
  • File: server/lorax_server/models/galactica.py
  • Lines: 1-227

Signature

def escape_custom_split_sequence(text):
    ...

class GalacticaCausalLMBatch(CausalLMBatch):
    @classmethod
    def from_pb(
        cls,
        pb: generate_pb2.Batch,
        tokenizer: PreTrainedTokenizerBase,
        tokenizers: TokenizerManager,
        processor,
        config,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "GalacticaCausalLMBatch":
        ...

class GalacticaSharded(CausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        compile: bool = False,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        ...
    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        ...
    def decode(self, generated_ids: List[int]) -> str:
        ...
    def forward(self, input_ids, attention_mask, position_ids, past_key_values=None):
        ...

Import

from lorax_server.models.galactica import GalacticaSharded

I/O Contract

Inputs

Name | Type | Required | Description
model_id | str | Yes | HuggingFace model identifier for a Galactica model
revision | Optional[str] | No | Model revision or commit hash
quantize | Optional[str] | No | Quantization method (e.g., "bitsandbytes")
compile | bool | No | Whether to compile the model (not supported; the request is logged and skipped)
dtype | Optional[torch.dtype] | No | Model dtype (defaults to float16 on GPU, float32 on CPU)
trust_remote_code | bool | No | Whether to trust remote code when loading the model

Outputs

Name | Type | Description
logits | torch.Tensor | Next-token logits over the vocabulary
past_key_values | List[Tuple[torch.Tensor, torch.Tensor]] | Updated KV cache for autoregressive decoding
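How the two outputs feed an autoregressive loop can be sketched with a toy model. The stub below is entirely hypothetical: it uses plain Python lists in place of tensors, and stub_forward simply favors the token after the last input token. The point is the data flow: the first call processes the full prompt, and each later call passes only the newest token together with the cache returned by the previous call.

```python
from typing import List, Optional, Tuple

VOCAB_SIZE = 5
Past = List[int]  # stand-in for the real list of (key, value) tensor pairs


def stub_forward(
    input_ids: List[int], past: Optional[Past] = None
) -> Tuple[List[float], Past]:
    """Hypothetical model: scores token (last_id + 1) % VOCAB_SIZE highest."""
    new_past = (past or []) + input_ids  # the cache grows by the tokens just seen
    logits = [0.0] * VOCAB_SIZE
    logits[(input_ids[-1] + 1) % VOCAB_SIZE] = 1.0
    return logits, new_past


def greedy_generate(prompt: List[int], max_new_tokens: int) -> List[int]:
    generated: List[int] = []
    logits, past = stub_forward(prompt)  # first step sees the whole prompt
    for _ in range(max_new_tokens):
        next_id = max(range(len(logits)), key=logits.__getitem__)  # argmax
        generated.append(next_id)
        # Later steps feed only the new token plus the cached state.
        logits, past = stub_forward([next_id], past)
    return generated


print(greedy_generate([0, 1], 3))  # prints [2, 3, 4]
```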

Usage Examples

# Internal LoRAX server usage
from lorax_server.models.galactica import GalacticaSharded

# Instantiated by model registry for Galactica models
# model = GalacticaSharded(model_id="facebook/galactica-6.7b")
