Implementation:Predibase Lorax Medusa LoRA Adapter
| Knowledge Sources | |
|---|---|
| Domains | Speculative_Decoding, LoRA |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Implements a combined Medusa + LoRA adapter that enables speculative decoding via Medusa heads while applying LoRA adapter weights to the base model, with unified configuration, weight loading, and batch type registration.
Description
This module provides a composite adapter that combines Medusa speculative decoding with LoRA parameter-efficient fine-tuning. This allows a single adapter to both accelerate inference through speculative decoding and customize the model's behavior through LoRA.
MedusaLoraModuleMap (dataclass): A simple container that holds two separate module maps: lora_module_map for the LoRA weight mappings and medusa_module_map for the Medusa head weight mappings.
MedusaLoraConfig (AdapterConfig): Configuration dataclass that wraps both a LoraConfig and a MedusaConfig. Key methods:
- map_weights_for_model(): Delegates to both sub-configs to map adapter weights into their respective module maps, returning a MedusaLoraModuleMap that bundles both.
- load_batched_adapter_weights(): Loads LoRA weights and Medusa weights independently from their respective module maps, then combines them into a MedusaLoraWeights instance.
- load() (classmethod): Factory that creates the config from an adapter ID and config dict, loading the LoRA config from the Hugging Face Hub and the Medusa config from the provided dictionary.
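The delegation in map_weights_for_model() can be sketched with stand-in classes. This is a minimal illustration, not the module's actual code: every `Toy*` class, the name-matching rules, the flat-dict module map, and the way the two name sets are merged (a union) are assumptions made for the example.

```python
from dataclasses import dataclass
from typing import Dict, Set, Tuple

# Assumption: a module map is modelled as a plain dict; the real ModuleMap type may differ.
ModuleMap = Dict[str, str]


class ToyLoraConfig:
    """Hypothetical stand-in for LoraConfig: claims names containing 'lora_'."""

    def map_weights_for_model(
        self, adapter_weights: Dict[str, str], weight_names: Tuple[str, ...]
    ) -> Tuple[ModuleMap, Set[str]]:
        module_map = {n: adapter_weights[n] for n in weight_names if "lora_" in n}
        return module_map, set(module_map)


class ToyMedusaConfig:
    """Hypothetical stand-in for MedusaConfig: claims names starting with 'medusa_head'."""

    def map_weights_for_model(
        self, adapter_weights: Dict[str, str], weight_names: Tuple[str, ...]
    ) -> Tuple[ModuleMap, Set[str]]:
        module_map = {
            n: adapter_weights[n] for n in weight_names if n.startswith("medusa_head")
        }
        return module_map, set(module_map)


@dataclass
class ToyMedusaLoraModuleMap:
    lora_module_map: ModuleMap
    medusa_module_map: ModuleMap


@dataclass
class ToyMedusaLoraConfig:
    lora_config: ToyLoraConfig
    medusa_config: ToyMedusaConfig

    def map_weights_for_model(
        self, adapter_weights: Dict[str, str], weight_names: Tuple[str, ...]
    ) -> Tuple[ToyMedusaLoraModuleMap, Set[str]]:
        # Each sub-config maps only the weights it recognises.
        lora_map, lora_names = self.lora_config.map_weights_for_model(
            adapter_weights, weight_names
        )
        medusa_map, medusa_names = self.medusa_config.map_weights_for_model(
            adapter_weights, weight_names
        )
        # Bundle both maps; merging the name sets with a union is an assumption.
        return ToyMedusaLoraModuleMap(lora_map, medusa_map), lora_names | medusa_names


adapter_weights = {"q_proj.lora_A": "A", "q_proj.lora_B": "B", "medusa_head.0": "H"}
config = ToyMedusaLoraConfig(ToyLoraConfig(), ToyMedusaConfig())
module_map, claimed = config.map_weights_for_model(adapter_weights, tuple(adapter_weights))
```

Keeping the two maps in separate fields means each sub-config can later load its own weights without needing to know about the other adapter type.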
MedusaLoraWeights (AdapterWeights): Holds both lora_weights (LoraWeights) and medusa_weights (MedusaWeights). Key features:
- get_batch_types(): Returns both BatchLoraWeights and BatchMedusaWeights, enabling the adapter system to create separate batch objects for each adapter type during inference. This is how the composite adapter participates in both the LoRA and Medusa batching pipelines.
- speculative_tokens: Delegates to the Medusa weights to expose the number of speculative tokens.
- load() (classmethod): Simple factory that wraps the two weight objects.
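The wrapper behaviour described above can be illustrated with a small, self-contained sketch. The `Toy*` classes are hypothetical stand-ins for the real weight and batch classes, and storing `speculative_tokens` as a plain attribute on the Medusa stand-in is an assumption made to keep the example short.

```python
from typing import List, Type


class ToyBatchLoraWeights:
    """Hypothetical stand-in for BatchLoraWeights."""


class ToyBatchMedusaWeights:
    """Hypothetical stand-in for BatchMedusaWeights."""


class ToyLoraWeights:
    """Hypothetical stand-in for LoraWeights."""


class ToyMedusaWeights:
    """Hypothetical stand-in for MedusaWeights."""

    def __init__(self, speculative_tokens: int) -> None:
        self.speculative_tokens = speculative_tokens


class ToyMedusaLoraWeights:
    def __init__(self, lora_weights: ToyLoraWeights, medusa_weights: ToyMedusaWeights) -> None:
        self.lora_weights = lora_weights
        self.medusa_weights = medusa_weights

    @classmethod
    def get_batch_types(cls) -> List[Type]:
        # Declaring both types lets the composite join both batching pipelines.
        return [ToyBatchLoraWeights, ToyBatchMedusaWeights]

    @property
    def speculative_tokens(self) -> int:
        # Delegate to the Medusa component; LoRA has no notion of speculation.
        return self.medusa_weights.speculative_tokens

    @classmethod
    def load(cls, lora_weights: ToyLoraWeights, medusa_weights: ToyMedusaWeights) -> "ToyMedusaLoraWeights":
        # Simple factory: wrap the two already-loaded weight objects.
        return cls(lora_weights, medusa_weights)


weights = ToyMedusaLoraWeights.load(ToyLoraWeights(), ToyMedusaWeights(speculative_tokens=3))
batch_types = ToyMedusaLoraWeights.get_batch_types()
```

Because get_batch_types() is a classmethod, a caller can ask which pipelines an adapter type participates in without constructing any weights.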
Usage
This adapter is used when a user provides an adapter that contains both LoRA weights and Medusa heads. During loading, MedusaLoraConfig.load() is called with the adapter ID and config. The system then loads both weight types and creates a MedusaLoraWeights instance. During batched inference, the adapter system calls get_batch_types() to determine that this adapter contributes to both BatchLoraWeights and BatchMedusaWeights batches, enabling simultaneous LoRA application in the transformer layers and Medusa speculative decoding at the output.
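One way to picture the batching step is a router that groups loaded adapters by the batch types they declare. The sketch below is an assumption about the general shape of such routing, not the LoRAX adapter system itself; `group_by_batch_type` and all `Toy*` classes are hypothetical.

```python
from collections import defaultdict
from typing import Dict, List, Type


class ToyBatchLora:
    """Hypothetical stand-in for BatchLoraWeights."""


class ToyBatchMedusa:
    """Hypothetical stand-in for BatchMedusaWeights."""


class ToyLoraOnlyWeights:
    """Hypothetical LoRA-only adapter weights."""

    @classmethod
    def get_batch_types(cls) -> List[Type]:
        return [ToyBatchLora]


class ToyMedusaLoraWeights:
    """Hypothetical composite adapter weights."""

    @classmethod
    def get_batch_types(cls) -> List[Type]:
        return [ToyBatchLora, ToyBatchMedusa]


def group_by_batch_type(adapters: Dict[int, object]) -> Dict[Type, List[int]]:
    """Map each declared batch type to the adapter indices that contribute to it."""
    groups: Dict[Type, List[int]] = defaultdict(list)
    for adapter_index, weights in adapters.items():
        for batch_type in type(weights).get_batch_types():
            groups[batch_type].append(adapter_index)
    return dict(groups)


adapters = {0: ToyLoraOnlyWeights(), 1: ToyMedusaLoraWeights()}
groups = group_by_batch_type(adapters)
```

In this toy setup the composite adapter (index 1) appears in both groups, while the LoRA-only adapter (index 0) appears only in the LoRA group.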
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/adapters/medusa_lora.py
- Lines: 1-93
Signature
@dataclass
class MedusaLoraModuleMap:
    lora_module_map: ModuleMap
    medusa_module_map: ModuleMap

@dataclass
class MedusaLoraConfig(AdapterConfig):
    lora_config: LoraConfig
    medusa_config: MedusaConfig

    def map_weights_for_model(self, adapter_weights, weight_names) -> Tuple[MedusaLoraModuleMap, Set[str]]

    def load_batched_adapter_weights(self, model, module_map, layer_type, unused_weight_names, dynamic)

    @classmethod
    def load(cls, adapter_id: str, config: dict, api_token: str) -> "MedusaLoraConfig"

class MedusaLoraWeights(AdapterWeights):
    def __init__(self, lora_weights: LoraWeights, medusa_weights: MedusaWeights)

    @classmethod
    def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]

    @property
    def speculative_tokens(self) -> int

    @classmethod
    def load(cls, lora_weights, medusa_weights) -> Optional[AdapterWeights]
Import
from lorax_server.adapters.medusa_lora import MedusaLoraConfig, MedusaLoraWeights
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| adapter_id | str | Yes (for load) | Hugging Face adapter identifier for loading LoRA config |
| config | dict | Yes (for load) | Configuration dict containing Medusa parameters (medusa_num_heads, medusa_num_layers, version) |
| api_token | str | Yes (for load) | Hugging Face API token for accessing the adapter |
| lora_weights | LoraWeights | Yes (for MedusaLoraWeights) | Loaded LoRA adapter weights |
| medusa_weights | MedusaWeights | Yes (for MedusaLoraWeights) | Loaded Medusa adapter weights |
| model | Model | Yes (for load_batched_adapter_weights) | The base model instance |
| module_map | MedusaLoraModuleMap | Yes (for load_batched_adapter_weights) | Combined module map for both adapter types |
Outputs
| Name | Type | Description |
|---|---|---|
| MedusaLoraWeights | AdapterWeights | Combined adapter weights containing both LoRA and Medusa components |
| batch_types | List[Type[BatchAdapterWeights]] | Returns [BatchLoraWeights, BatchMedusaWeights] for dual-pathway batching |
| speculative_tokens | int | Number of speculative tokens, as reported by the Medusa weights |
Usage Examples
# Loading a combined Medusa + LoRA adapter
from lorax_server.adapters.medusa_lora import MedusaLoraConfig

config = MedusaLoraConfig.load(
    adapter_id="org/model-medusa-lora",
    config={
        "base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
        "medusa_num_heads": 3,
        "medusa_num_layers": 1,
        "version": 2,
    },
    api_token="hf_xxx",
)

# The weights are loaded through the adapter pipeline.
# `model` (the base Model instance) and `module_map` (a MedusaLoraModuleMap,
# as returned by config.map_weights_for_model()) must already be defined.
weights = config.load_batched_adapter_weights(
    model=model,
    module_map=module_map,
    layer_type="default",
    unused_weight_names=set(),
    dynamic=True,
)