
Implementation:Predibase Lorax Get Model Factory

From Leeroopedia


Knowledge Sources
Domains Model_Architecture, Model_Serving
Last Updated 2026-02-08 02:00 GMT

Overview

Concrete tool for selecting and instantiating transformer model architectures provided by the LoRAX model factory.

Description

The get_model() function is the central model factory in LoRAX. It reads the model's HuggingFace configuration, determines the model_type, resolves the dtype, and instantiates the correct Model subclass. It supports 20+ architectures including Llama, Mistral, Mixtral, Gemma, Qwen, Phi, GPT-2, BLOOM, T5, and vision-language models (LLaVA-NeXT, Mllama).
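The dispatch described above can be sketched as a registry lookup keyed on the HuggingFace model_type. This is a simplified illustration, not the actual LoRAX source (which uses a long if/elif chain over server/lorax_server/models/__init__.py); the Model stubs and get_model_sketch helper here are hypothetical stand-ins.

```python
from dataclasses import dataclass
from typing import Optional


# Hypothetical stand-ins for the real Model subclasses; the names mirror
# LoRAX's FlashLlama / FlashMistral, but these classes are stubs.
@dataclass
class Model:
    model_id: str
    dtype: str


class FlashLlama(Model):
    pass


class FlashMistral(Model):
    pass


# Registry mapping HuggingFace model_type -> Model subclass, analogous to
# the architecture dispatch inside get_model(). Only two entries shown.
_MODEL_REGISTRY = {
    "llama": FlashLlama,
    "mistral": FlashMistral,
}


def get_model_sketch(model_id: str, model_type: str,
                     dtype: Optional[str] = None) -> Model:
    """Simplified factory: resolve dtype, then look up the architecture class."""
    resolved_dtype = dtype or "float16"  # assumed default, for illustration only
    try:
        cls = _MODEL_REGISTRY[model_type]
    except KeyError:
        raise ValueError(f"Unsupported model_type: {model_type!r}")
    return cls(model_id=model_id, dtype=resolved_dtype)


model = get_model_sketch("mistralai/Mistral-7B-Instruct-v0.1", "mistral")
print(type(model).__name__, model.dtype)  # → FlashMistral float16
```

In the real factory the model_type comes from the model's HuggingFace config rather than being passed explicitly, and unsupported architectures similarly raise an error at startup.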

Usage

Called once during server startup by the Python gRPC shard process, when the launcher spawns a new shard. Not called directly by end users.

Code Reference

Source Location

  • Repository: LoRAX
  • File: server/lorax_server/models/__init__.py
  • Lines: 44-394

Signature

def get_model(
    model_id: str,
    adapter_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    compile: bool,
    dtype: Optional[str],
    trust_remote_code: bool,
    source: str,
    adapter_source: str,
    merge_adapter_weights: bool,
    embedding_dim: Optional[int] = None,
) -> Model:
    """
    Factory function that instantiates the appropriate Model subclass
    based on the model_type from the HuggingFace config.

    Args:
        model_id: HuggingFace model ID or local path
        adapter_id: Initial adapter to load with model
        revision: Model revision/commit hash
        sharded: Whether to shard across GPUs
        quantize: Quantization method (bitsandbytes/gptq/awq/eetq/hqq/fp8)
        compile: Whether to use torch.compile
        dtype: Target dtype (float16/bfloat16)
        trust_remote_code: Allow custom model code
        source: Model source ("hub" or "s3")
        adapter_source: Adapter source type
        merge_adapter_weights: Merge adapter into base at load time
        embedding_dim: Custom embedding dimension (for embedding models)
    """

Import

from lorax_server.models import get_model

I/O Contract

Inputs

Name                   Type           Required  Description
model_id               str            Yes       HuggingFace model ID or local path
adapter_id             str            Yes       Initial adapter to load with the model
revision               Optional[str]  No        Model revision/commit hash
sharded                bool           Yes       Whether to shard model across GPUs
quantize               Optional[str]  No        Quantization: bitsandbytes, gptq, awq, eetq, hqq, fp8
compile                bool           Yes       Enable torch.compile optimization
dtype                  Optional[str]  No        Target dtype: float16, bfloat16
trust_remote_code      bool           Yes       Allow execution of custom model code
source                 str            Yes       Model source: "hub" or "s3"
adapter_source         str            Yes       Adapter source type
merge_adapter_weights  bool           Yes       Merge adapter weights into base model
embedding_dim          Optional[int]  No        Custom embedding dimension
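The dtype input is a plain string that the factory resolves before instantiating the model. A minimal sketch of that resolution, assuming float16 as the fallback default; the actual LoRAX defaulting logic (which maps to torch dtypes and may consider hardware) can differ, and resolve_dtype here is a hypothetical helper:

```python
from typing import Optional

# dtype strings accepted per the table above; the real factory converts
# these to torch dtypes (torch.float16 / torch.bfloat16).
_VALID_DTYPES = {"float16", "bfloat16"}


def resolve_dtype(dtype: Optional[str]) -> str:
    """Validate a dtype string, falling back to an assumed default of float16."""
    if dtype is None:
        return "float16"  # assumed default, for illustration only
    if dtype not in _VALID_DTYPES:
        raise ValueError(f"dtype must be one of {sorted(_VALID_DTYPES)}, got {dtype!r}")
    return dtype
```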

Outputs

Name   Type   Description
model  Model  Instantiated Model subclass (FlashLlama, FlashMistral, etc.)

Usage Examples

Internal Server Usage

from lorax_server.models import get_model

# Called during shard initialization
model = get_model(
    model_id="mistralai/Mistral-7B-Instruct-v0.1",
    adapter_id="",
    revision=None,
    sharded=False,
    quantize=None,
    compile=False,
    dtype="float16",
    trust_remote_code=False,
    source="hub",
    adapter_source="hub",
    merge_adapter_weights=False,
)
# Returns FlashMistral instance ready for inference
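For a sharded, quantized deployment the same call takes different flags. Since actually loading weights requires GPUs and network access, this sketch only assembles and sanity-checks the keyword arguments; the parameter names follow the signature documented above, but the specific combination shown is illustrative, not a tested configuration.

```python
# Quantization methods listed in the I/O contract above.
SUPPORTED_QUANTIZE = {"bitsandbytes", "gptq", "awq", "eetq", "hqq", "fp8"}

# Illustrative kwargs for a sharded, AWQ-quantized load.
kwargs = dict(
    model_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
    adapter_id="",
    revision=None,
    sharded=True,            # shard the model across visible GPUs
    quantize="awq",          # one of the supported quantization methods
    compile=False,
    dtype="bfloat16",
    trust_remote_code=False,
    source="hub",
    adapter_source="hub",
    merge_adapter_weights=False,
)

assert kwargs["quantize"] in SUPPORTED_QUANTIZE

# Inside the shard process this would be:
#   model = get_model(**kwargs)
```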

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
