

Implementation:Predibase Lorax Medusa LoRA Adapter

From Leeroopedia


Knowledge Sources
Domains: Speculative_Decoding, LoRA
Last Updated: 2026-02-08 00:00 GMT

Overview

Implements a combined Medusa + LoRA adapter that enables speculative decoding via Medusa heads while simultaneously applying LoRA fine-tuning to the base model, with unified configuration, weight loading, and batch type registration.

Description

This module provides a composite adapter that combines Medusa speculative decoding with LoRA parameter-efficient fine-tuning. This allows a single adapter to both accelerate inference through speculative decoding and customize the model's behavior through LoRA.

MedusaLoraModuleMap (dataclass): A simple container that holds two separate module maps: lora_module_map for the LoRA weight mappings and medusa_module_map for the Medusa head weight mappings.

MedusaLoraConfig (AdapterConfig): Configuration dataclass that wraps both a LoraConfig and a MedusaConfig. Key methods:

  • map_weights_for_model(): Delegates to both sub-configs to map adapter weights into their respective module maps, returning a MedusaLoraModuleMap that bundles both.
  • load_batched_adapter_weights(): Loads LoRA weights and Medusa weights independently from their respective module maps, then combines them into a MedusaLoraWeights instance.
  • load() (classmethod): Factory that creates the config from an adapter ID and config dict, loading the LoRA config from the Hugging Face Hub and the Medusa config from the provided dictionary.
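The delegation pattern described above can be sketched with toy stand-ins. Everything below is hypothetical and illustrative, not Lorax's code. ToySubConfig replaces both LoraConfig and MedusaConfig and claims adapter weights by name prefix. A module map is reduced to a flat dict of weight name to values. The returned name set is the union of both halves, which is an assumption, since this page does not say how the two sets are combined. load() is omitted because the real one fetches the LoRA config from the Hugging Face Hub.

```python
from dataclasses import dataclass
from typing import Dict, List, Set, Tuple

# Hypothetical simplification: a "module map" is just {weight name: values}.
# Lorax's real ModuleMap is richer; this keeps the sketch runnable without torch.
ModuleMap = Dict[str, List[float]]


@dataclass
class MedusaLoraModuleMap:
    # Local re-creation of the container: two maps kept side by side.
    lora_module_map: ModuleMap
    medusa_module_map: ModuleMap


@dataclass
class ToySubConfig:
    """Hypothetical stand-in for LoraConfig/MedusaConfig: claims weights by name prefix."""
    prefix: str

    def map_weights_for_model(self, adapter_weights: ModuleMap) -> Tuple[ModuleMap, Set[str]]:
        module_map = {n: w for n, w in adapter_weights.items() if n.startswith(self.prefix)}
        return module_map, set(module_map)

    def load_batched_adapter_weights(self, module_map: ModuleMap) -> List[str]:
        # Stand-in for building weight tensors: return the loaded weight names in order.
        return sorted(module_map)


@dataclass
class ToyMedusaLoraWeights:
    lora_weights: List[str]
    medusa_weights: List[str]


@dataclass
class ToyMedusaLoraConfig:
    lora_config: ToySubConfig
    medusa_config: ToySubConfig

    def map_weights_for_model(self, adapter_weights: ModuleMap) -> Tuple[MedusaLoraModuleMap, Set[str]]:
        # Delegate to each sub-config; the two maps are bundled, never merged.
        lora_map, lora_names = self.lora_config.map_weights_for_model(adapter_weights)
        medusa_map, medusa_names = self.medusa_config.map_weights_for_model(adapter_weights)
        return MedusaLoraModuleMap(lora_map, medusa_map), lora_names | medusa_names

    def load_batched_adapter_weights(self, module_map: MedusaLoraModuleMap) -> ToyMedusaLoraWeights:
        # Load each half independently from its own map, then wrap both in one object.
        lora = self.lora_config.load_batched_adapter_weights(module_map.lora_module_map)
        medusa = self.medusa_config.load_batched_adapter_weights(module_map.medusa_module_map)
        return ToyMedusaLoraWeights(lora_weights=lora, medusa_weights=medusa)


config = ToyMedusaLoraConfig(ToySubConfig("layers."), ToySubConfig("medusa."))
adapter_weights = {
    "layers.0.q_proj.lora_A": [0.1],
    "layers.0.q_proj.lora_B": [0.2],
    "medusa.0.weight": [0.3],
}
module_map, claimed = config.map_weights_for_model(adapter_weights)
weights = config.load_batched_adapter_weights(module_map)
print(weights)
```

Keeping the two maps separate means each sub-config only ever sees the weights it claimed. This mirrors how the composite config can reuse the existing LoRA and Medusa loaders unchanged.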

MedusaLoraWeights (AdapterWeights): Holds both lora_weights (LoraWeights) and medusa_weights (MedusaWeights). Key features:

  • get_batch_types(): Returns both BatchLoraWeights and BatchMedusaWeights, enabling the adapter system to create separate batch objects for each adapter type during inference. This is how the composite adapter participates in both the LoRA and Medusa batching pipelines.
  • speculative_tokens: Delegates to the Medusa weights to expose the number of speculative tokens.
  • load() (classmethod): Simple factory that wraps the two weight objects.
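A minimal sketch of the weights wrapper, assuming toy placeholder classes. It shows get_batch_types() reporting two batch types and speculative_tokens delegating to the Medusa half. All Toy* names are hypothetical stand-ins for LoraWeights, MedusaWeights, BatchLoraWeights and BatchMedusaWeights. The value 3 used below is arbitrary.

```python
from typing import List, Optional


class ToyBatchLoraWeights:
    """Hypothetical placeholder for BatchLoraWeights."""


class ToyBatchMedusaWeights:
    """Hypothetical placeholder for BatchMedusaWeights."""


class ToyLoraWeights:
    """Hypothetical placeholder for LoraWeights."""


class ToyMedusaWeights:
    """Hypothetical placeholder for MedusaWeights; stores its speculative token count."""

    def __init__(self, speculative_tokens: int):
        self.speculative_tokens = speculative_tokens


class ToyMedusaLoraWeights:
    def __init__(self, lora_weights: ToyLoraWeights, medusa_weights: ToyMedusaWeights):
        self.lora_weights = lora_weights
        self.medusa_weights = medusa_weights

    @classmethod
    def get_batch_types(cls) -> List[type]:
        # One adapter, two batch types: it joins both the LoRA and the Medusa batching paths.
        return [ToyBatchLoraWeights, ToyBatchMedusaWeights]

    @property
    def speculative_tokens(self) -> int:
        # Delegate to the Medusa half, which owns the speculative-decoding settings.
        return self.medusa_weights.speculative_tokens

    @classmethod
    def load(cls, lora_weights, medusa_weights) -> Optional["ToyMedusaLoraWeights"]:
        # Simple factory wrapping the two weight objects. The documented signature
        # returns Optional[AdapterWeights]; this toy always returns an instance.
        return cls(lora_weights, medusa_weights)


weights = ToyMedusaLoraWeights.load(ToyLoraWeights(), ToyMedusaWeights(speculative_tokens=3))
print(weights.speculative_tokens)
print([t.__name__ for t in weights.get_batch_types()])
```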

Usage

This composite adapter type is used when a user supplies adapter weights that contain both LoRA weights and Medusa heads. During loading, MedusaLoraConfig.load() is called with the adapter ID and config, and the system then loads both weight types and wraps them in a MedusaLoraWeights instance. During batched inference, the adapter system calls get_batch_types() to learn that this adapter contributes to both the BatchLoraWeights and BatchMedusaWeights batches, so LoRA is applied in the transformer layers while Medusa performs speculative decoding at the output.
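The dual-pathway batching can be illustrated with a hypothetical dispatcher that groups adapters by the batch types they report. The class names and group_by_batch_type are invented for this sketch; Lorax's actual batching code is not reproduced here.

```python
from collections import defaultdict
from typing import Dict, List


class BatchLora:
    """Hypothetical placeholder for BatchLoraWeights."""


class BatchMedusa:
    """Hypothetical placeholder for BatchMedusaWeights."""


class LoraOnly:
    @classmethod
    def get_batch_types(cls):
        return [BatchLora]


class MedusaOnly:
    @classmethod
    def get_batch_types(cls):
        return [BatchMedusa]


class MedusaLora:
    @classmethod
    def get_batch_types(cls):
        # The composite adapter reports both batch types.
        return [BatchLora, BatchMedusa]


def group_by_batch_type(adapters: Dict[str, object]) -> Dict[type, List[str]]:
    """Hypothetical dispatcher: list which adapters each batch object must cover."""
    groups: Dict[type, List[str]] = defaultdict(list)
    for name, adapter in adapters.items():
        for batch_type in adapter.get_batch_types():
            groups[batch_type].append(name)
    return dict(groups)


groups = group_by_batch_type({"a": LoraOnly(), "b": MedusaOnly(), "c": MedusaLora()})
print({t.__name__: names for t, names in groups.items()})
```

In this sketch, dispatch is driven entirely by get_batch_types(), so the composite adapter needs no special case: it simply appears in both groups.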

Code Reference

Source Location

  • Repository: Predibase_Lorax
  • File: server/lorax_server/adapters/medusa_lora.py
  • Lines: 1-93

Signature

from dataclasses import dataclass
from typing import List, Optional, Set, Tuple, Type

# ModuleMap, AdapterConfig, AdapterWeights, BatchAdapterWeights, LoraConfig,
# LoraWeights, MedusaConfig and MedusaWeights are Lorax-internal types defined
# elsewhere in the lorax_server package.

@dataclass
class MedusaLoraModuleMap:
    lora_module_map: ModuleMap
    medusa_module_map: ModuleMap

@dataclass
class MedusaLoraConfig(AdapterConfig):
    lora_config: LoraConfig
    medusa_config: MedusaConfig

    def map_weights_for_model(self, adapter_weights, weight_names) -> Tuple[MedusaLoraModuleMap, Set[str]]: ...

    def load_batched_adapter_weights(self, model, module_map, layer_type, unused_weight_names, dynamic): ...

    @classmethod
    def load(cls, adapter_id: str, config: dict, api_token: str) -> "MedusaLoraConfig": ...

class MedusaLoraWeights(AdapterWeights):
    def __init__(self, lora_weights: LoraWeights, medusa_weights: MedusaWeights): ...

    @classmethod
    def get_batch_types(cls) -> List[Type[BatchAdapterWeights]]: ...

    @property
    def speculative_tokens(self) -> int: ...

    @classmethod
    def load(cls, lora_weights, medusa_weights) -> Optional[AdapterWeights]: ...

Import

from lorax_server.adapters.medusa_lora import MedusaLoraConfig, MedusaLoraWeights

I/O Contract

Inputs

Name | Type | Required | Description
--- | --- | --- | ---
adapter_id | str | Yes (for load) | Hugging Face adapter identifier used to load the LoRA config
config | dict | Yes (for load) | Configuration dict containing the Medusa parameters (medusa_num_heads, medusa_num_layers, version)
api_token | str | Yes (for load) | Hugging Face API token for accessing the adapter
lora_weights | LoraWeights | Yes | Loaded LoRA adapter weights
medusa_weights | MedusaWeights | Yes | Loaded Medusa adapter weights
model | Model | Yes (for load_batched_adapter_weights) | The base model instance
module_map | MedusaLoraModuleMap | Yes (for load_batched_adapter_weights) | Combined module map for both adapter types

Outputs

Name | Type | Description
--- | --- | ---
MedusaLoraWeights | AdapterWeights | Combined adapter weights containing both LoRA and Medusa components
get_batch_types() | List[Type[BatchAdapterWeights]] | Returns [BatchLoraWeights, BatchMedusaWeights] for dual-pathway batching
speculative_tokens | int | Number of speculative tokens, exposed via the Medusa weights

Usage Examples

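# Loading a combined Medusa + LoRA adapter
from lorax_server.adapters.medusa_lora import MedusaLoraConfig

# The LoRA config is fetched from the Hugging Face Hub using adapter_id and
# api_token; the Medusa config is built from the dict passed as `config`.
config = MedusaLoraConfig.load(
    adapter_id="org/model-medusa-lora",
    config={
        "base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
        "medusa_num_heads": 3,
        "medusa_num_layers": 1,
        "version": 2,
    },
    api_token="hf_xxx",
)

# `model`, `adapter_weights` and `weight_names` are supplied by the adapter
# loading pipeline and are not defined in this snippet.
# Map the raw adapter weights into separate LoRA and Medusa module maps.
module_map, _ = config.map_weights_for_model(adapter_weights, weight_names)

# Load both weight types and combine them into a MedusaLoraWeights instance.
weights = config.load_batched_adapter_weights(
    model=model,
    module_map=module_map,
    layer_type="default",
    unused_weight_names=set(),
    dynamic=True,
)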
