Implementation: Predibase LoRAX Get Model Factory
| Knowledge Sources | |
|---|---|
| Domains | Model_Architecture, Model_Serving |
| Last Updated | 2026-02-08 02:00 GMT |
Overview
Factory function in LoRAX that selects and instantiates the appropriate transformer model architecture implementation.
Description
The get_model() function is the central model factory in LoRAX. It reads the model's HuggingFace configuration, determines the model_type, resolves the dtype, and instantiates the correct Model subclass. It supports 20+ architectures including Llama, Mistral, Mixtral, Gemma, Qwen, Phi, GPT-2, BLOOM, T5, and vision-language models (LLaVA-NeXT, Mllama).
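The core pattern is dispatch on the config's `model_type` field. A minimal, self-contained sketch of that pattern (class names and the registry are illustrative stand-ins, not LoRAX's actual internals):

```python
# Minimal sketch of the dispatch pattern get_model() uses: read model_type
# from the model config, look up the matching Model subclass, instantiate it.
# The registry contents and class bodies here are hypothetical placeholders.
from dataclasses import dataclass


@dataclass
class ModelConfig:
    model_type: str


class Model: ...
class FlashLlama(Model): ...
class FlashMistral(Model): ...


# Hypothetical registry mapping model_type -> Model subclass
_MODEL_REGISTRY = {
    "llama": FlashLlama,
    "mistral": FlashMistral,
}


def get_model_sketch(config: ModelConfig) -> Model:
    try:
        cls = _MODEL_REGISTRY[config.model_type]
    except KeyError:
        raise NotImplementedError(
            f"Unsupported model_type: {config.model_type!r}"
        )
    return cls()
```

The real factory performs the same lookup across 20+ architectures and also threads through sharding, quantization, and dtype options before construction.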
Usage
Called once per shard during server startup: the launcher spawns a Python gRPC shard process, which invokes get_model() internally. It is not called directly by end users.
Code Reference
Source Location
- Repository: LoRAX
- File: server/lorax_server/models/__init__.py
- Lines: 44-394
Signature
```python
def get_model(
    model_id: str,
    adapter_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    compile: bool,
    dtype: Optional[str],
    trust_remote_code: bool,
    source: str,
    adapter_source: str,
    merge_adapter_weights: bool,
    embedding_dim: Optional[int] = None,
) -> Model:
    """
    Factory function that instantiates the appropriate Model subclass
    based on the model_type from the HuggingFace config.

    Args:
        model_id: HuggingFace model ID or local path
        adapter_id: Initial adapter to load with model
        revision: Model revision/commit hash
        sharded: Whether to shard across GPUs
        quantize: Quantization method (bitsandbytes/gptq/awq/eetq/hqq/fp8)
        compile: Whether to use torch.compile
        dtype: Target dtype (float16/bfloat16)
        trust_remote_code: Allow custom model code
        source: Model source ("hub" or "s3")
        adapter_source: Adapter source type
        merge_adapter_weights: Merge adapter into base at load time
        embedding_dim: Custom embedding dimension (for embedding models)
    """
```
Import
```python
from lorax_server.models import get_model
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model_id | str | Yes | HuggingFace model ID or local path |
| adapter_id | str | Yes | Initial adapter to load with the model |
| revision | Optional[str] | No | Model revision/commit hash |
| sharded | bool | Yes | Whether to shard model across GPUs |
| quantize | Optional[str] | No | Quantization: bitsandbytes, gptq, awq, eetq, hqq, fp8 |
| compile | bool | Yes | Enable torch.compile optimization |
| dtype | Optional[str] | No | Target dtype: float16, bfloat16 |
| trust_remote_code | bool | Yes | Allow execution of custom model code |
| source | str | Yes | Model source: "hub" or "s3" |
| adapter_source | str | Yes | Adapter source type |
| merge_adapter_weights | bool | Yes | Merge adapter weights into base model |
| embedding_dim | Optional[int] | No | Custom embedding dimension |
Outputs
| Name | Type | Description |
|---|---|---|
| model | Model | Instantiated Model subclass (FlashLlama, FlashMistral, etc.) |
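The table above enumerates the accepted values for `quantize`. An illustrative validation helper for that argument (not LoRAX's actual validation code):

```python
# Illustrative check of the quantize argument against the values listed
# in the inputs table; the helper name and error message are hypothetical.
from typing import Optional

VALID_QUANTIZE = {"bitsandbytes", "gptq", "awq", "eetq", "hqq", "fp8"}


def check_quantize(quantize: Optional[str]) -> Optional[str]:
    if quantize is not None and quantize not in VALID_QUANTIZE:
        raise ValueError(f"Unknown quantization method: {quantize!r}")
    return quantize
```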
Usage Examples
Internal Server Usage
```python
from lorax_server.models import get_model

# Called during shard initialization
model = get_model(
    model_id="mistralai/Mistral-7B-Instruct-v0.1",
    adapter_id="",
    revision=None,
    sharded=False,
    quantize=None,
    compile=False,
    dtype="float16",
    trust_remote_code=False,
    source="hub",
    adapter_source="hub",
    merge_adapter_weights=False,
)
# Returns a FlashMistral instance ready for inference
```