

Implementation:Predibase Lorax Medusa Adapter

From Leeroopedia


Knowledge Sources
Domains: Speculative_Decoding, LoRA
Last Updated: 2026-02-08 00:00 GMT

Overview

Implements the Medusa speculative decoding adapter system, including configuration, model architecture (V1 and V2), weight management, and batched multi-adapter inference with segment-based routing.

Description

This module provides the full Medusa adapter implementation for speculative decoding in LoRAX. Medusa adds multiple prediction heads on top of a base language model to generate draft tokens in parallel, reducing autoregressive decoding latency.

MedusaConfig (AdapterConfig): Configuration dataclass with medusa_num_heads, medusa_num_layers, and version fields. The load_batched_adapter_weights() method handles both static (server-initialization) and dynamic (per-request) adapter loading. Dynamic loading is validated: it requires a pre-initialized default Medusa adapter and a matching speculative token count. The first static load sets the global _MEDUSA_ENABLED flag.

ResBlock (nn.Module): A residual block consisting of a FastLinear layer followed by SiLU activation with a skip connection: x + SiLU(linear(x)).
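The residual formula above can be illustrated without any framework. The following is a minimal sketch in plain Python, not the module's code: the real ResBlock uses torch and a FastLinear layer, while the helper names here (silu, linear, res_block) and the list-based vectors are assumptions made for illustration.

```python
import math

def silu(v: float) -> float:
    # SiLU (swish) activation: v * sigmoid(v)
    return v / (1.0 + math.exp(-v))

def linear(x, weight, bias):
    # Plain affine map standing in for FastLinear (assumption):
    # y[i] = sum_j weight[i][j] * x[j] + bias[i]
    return [sum(w * xj for w, xj in zip(row, x)) + b
            for row, b in zip(weight, bias)]

def res_block(x, weight, bias):
    # Skip connection: x + SiLU(linear(x)); weight must be square
    # so the output has the same size as the input.
    h = linear(x, weight, bias)
    return [xi + silu(hi) for xi, hi in zip(x, h)]

x = [1.0, -2.0]
zero_w = [[0.0, 0.0], [0.0, 0.0]]
zero_b = [0.0, 0.0]
# With all-zero parameters the block reduces to the identity,
# because SiLU(0) == 0 and only the skip connection remains.
out = res_block(x, zero_w, zero_b)
```

The zero-parameter case shows why the skip connection matters: a block whose linear layer contributes nothing passes the hidden state through unchanged.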

MedusaHead (nn.Module): A single Medusa prediction head composed of a stack of ResBlock layers (medusa_num_layers deep) followed by a final FastLinear output projection.

MedusaV1 (nn.Module): The original Medusa architecture with independent MedusaHead modules for each speculative position. Forward pass applies each head independently and stacks the outputs.

MedusaV2 (nn.Module): An optimized architecture that uses TensorParallelColumnLinear.load_multi() to combine all heads into a single fused linear layer (requires medusa_num_layers == 1). Supports tensor parallelism via process_group with all-gather for multi-GPU inference. Includes a batch size guard (LORAX_SPECULATION_MAX_BATCH_SIZE) that skips speculation for large batches. For multi-adapter scenarios, uses segmented_matmul via the Punica kernel.
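The batch size guard described above could look roughly like the sketch below. Only the environment variable name comes from this page; the function name should_speculate, the fallback value of 32, and the use of a less-than-or-equal comparison are assumptions, so check the source file for the actual default and boundary.

```python
import os

def should_speculate(batch_size: int, default_max: int = 32) -> bool:
    # Read the limit at call time; default_max is a hypothetical fallback,
    # not a value confirmed by the source.
    max_batch = int(os.environ.get("LORAX_SPECULATION_MAX_BATCH_SIZE", default_max))
    # Assumption: speculation runs while the batch is at or below the limit
    # and is skipped (speculative logits become None) above it.
    return batch_size <= max_batch

print(should_speculate(8))
```

Skipping speculation for large batches is a throughput trade-off: the extra head computation pays off mainly when the GPU is not already saturated by the batch.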

MedusaModel (nn.Module): Factory wrapper that selects V1 or V2 based on configuration (V1 if medusa_num_layers > 1 or the output weight tensor exists, otherwise V2).
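The selection rule stated above can be written as a small predicate. This is a sketch of the rule as described on this page, with a hypothetical function name and a boolean standing in for the check that the output weight tensor exists; the real factory inspects the loaded weights.

```python
def select_medusa_variant(medusa_num_layers: int, has_output_weight: bool) -> str:
    # V1 when heads are deeper than one layer, or when the checkpoint
    # carries a separate output projection weight; otherwise V2.
    if medusa_num_layers > 1 or has_output_weight:
        return "v1"
    return "v2"

print(select_medusa_variant(1, False))  # v2
print(select_medusa_variant(2, False))  # v1
```

This is consistent with the V2 requirement that medusa_num_layers == 1, since V2 fuses all heads into a single linear layer.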

MedusaWeights (AdapterWeights): Manages the Medusa model weights and wraps them in an InMemoryWeights container. Exposes a speculative_tokens property.

MedusaSegments (dataclass): Holds per-segment weight and bias tensors (w, b) along with segment start/end indices (s_start, s_end) for multi-adapter batched inference.

BatchMedusaWeights (BatchAdapterWeights): Manages batched inference across multiple Medusa adapters. The load() classmethod constructs segment mappings from adapter metadata, handling default adapter fallback and segment merging. The __call__ method dispatches to the default Medusa model or falls back to the plain lm_head.
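To make the segment construction concrete, the sketch below builds segments from a per-row list of adapter indices, applying a default-adapter fallback first and then merging contiguous rows. It is an illustration under assumptions, not the logic of load(): the function names, the tuple layout, and the exclusive end index are all hypothetical.

```python
from typing import List, Set, Tuple

def apply_default_fallback(adapter_indices: List[int], loaded: Set[int],
                           default_index: int) -> List[int]:
    # Rows whose adapter has no Medusa weights fall back to the default adapter.
    return [i if i in loaded else default_index for i in adapter_indices]

def build_segments(adapter_indices: List[int]) -> List[Tuple[int, int, int]]:
    # Returns (adapter_index, s_start, s_end) per contiguous run of rows.
    # Assumption: s_end is exclusive.
    segments: List[Tuple[int, int, int]] = []
    for row, idx in enumerate(adapter_indices):
        if segments and segments[-1][0] == idx:
            adapter, start, _ = segments[-1]
            segments[-1] = (adapter, start, row + 1)  # extend the current run
        else:
            segments.append((idx, row, row + 1))      # start a new run
    return segments

rows = [0, 0, 3, 3, 3, 1]  # adapter index for each row of the batch
resolved = apply_default_fallback(rows, loaded={0, 1}, default_index=0)
print(build_segments(resolved))  # [(0, 0, 5), (1, 5, 6)]
```

In this example adapter 3 has no Medusa weights, so its rows resolve to the default adapter 0 and merge with the preceding segment, leaving two segments instead of three.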

Usage

Medusa adapters are loaded either at server startup (as the default speculative decoding adapter) or dynamically per request (V2 only). During inference, BatchMedusaWeights is constructed from the batch metadata and called on the final hidden states to produce both the standard logits and the speculative logits used for tree-based verification.

Code Reference

Source Location

  • Repository: Predibase_Lorax
  • File: server/lorax_server/adapters/medusa.py
  • Lines: 1-340

Signature

@dataclass
class MedusaConfig(AdapterConfig):
    medusa_num_heads: int
    medusa_num_layers: int
    version: int
    @classmethod
    def load(cls, config: dict) -> "MedusaConfig"
    def load_batched_adapter_weights(self, model, module_map, layer_type, unused_weight_names, dynamic)

class ResBlock(torch.nn.Module):
    def __init__(self, config: MedusaConfig, prefix: str, weights: AbstractWeights)
    def forward(self, x)

class MedusaHead(torch.nn.Module):
    def __init__(self, config: MedusaConfig, prefix: str, weights: AbstractWeights)
    def forward(self, x)

class MedusaV1(torch.nn.Module):
    def __init__(self, config: MedusaConfig, weights: AbstractWeights)
    def forward(self, x, lm_head, segments: Optional[MedusaSegments] = None)

class MedusaV2(torch.nn.Module):
    def __init__(self, config: MedusaConfig, weights: AbstractWeights)
    def forward(self, x, lm_head, segments: Optional[MedusaSegments] = None)

class MedusaModel(torch.nn.Module):
    def __init__(self, config: MedusaConfig, weights: AbstractWeights)
    def forward(self, x, lm_head, segments: Optional[MedusaSegments] = None)

class MedusaWeights(AdapterWeights):
    def __init__(self, config, module_map, model)
    @classmethod
    def load(cls, config, model, module_map, layer_type, unused_weight_names)

@dataclass
class BatchMedusaWeights(BatchAdapterWeights):
    def has_adapter(self, adapter_index: int) -> bool
    def __call__(self, x, lm_head)
    @classmethod
    def load(cls, adapter_weights, meta, layer_name, prefill, prefill_head_indices)

Import

from lorax_server.adapters.medusa import MedusaConfig, MedusaWeights, BatchMedusaWeights, MedusaModel

I/O Contract

Inputs

Name | Type | Required | Description
x | torch.Tensor | Yes | Hidden states from the base model, shape (batch_size, hidden_dim)
lm_head | callable | Yes | Language model head (linear projection to vocabulary)
segments | MedusaSegments or None | No | Segment metadata for multi-adapter batched inference
config | dict | Yes (for load) | Configuration dictionary with medusa_num_heads, medusa_num_layers, version

Outputs

Name | Type | Description
logits | torch.Tensor | Standard next-token logits from the lm_head
speculative_logits | torch.Tensor or None | Speculative logits from the Medusa heads, shape (batch_size, medusa_num_heads, vocab_size), or None if speculation is skipped
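As an illustration of the output shape, the sketch below turns speculative logits into one draft token per head by taking the highest-scoring vocabulary entry. It uses nested Python lists in place of tensors and greedy selection in place of the tree-based verification the server performs, and the function name is hypothetical.

```python
from typing import List, Optional

def draft_tokens(
    speculative_logits: Optional[List[List[List[float]]]],
) -> Optional[List[List[int]]]:
    # Input layout mirrors (batch_size, medusa_num_heads, vocab_size).
    if speculative_logits is None:
        return None  # speculation was skipped for this batch
    # For every row and every head, pick the index of the largest score.
    return [
        [max(range(len(head)), key=head.__getitem__) for head in row]
        for row in speculative_logits
    ]

example = [[[0.1, 2.0, -1.0], [0.5, 0.2, 0.9]]]  # 1 row, 2 heads, vocab of 3
print(draft_tokens(example))  # [[1, 2]]
```

Callers need to handle the None case explicitly, since the speculative output is optional.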

Usage Examples

# Loading a Medusa adapter configuration
from lorax_server.adapters.medusa import MedusaConfig, MedusaWeights

config = MedusaConfig.load({
    "base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
    "medusa_num_heads": 3,
    "medusa_num_layers": 1,
    "version": 2,
})

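# During inference: batch_medusa_weights is a BatchMedusaWeights instance built from the
# batch metadata (see BatchMedusaWeights.load); hidden_states are the final hidden states.
# speculative_logits may be None when speculation is skipped.
logits, speculative_logits = batch_medusa_weights(hidden_states, lm_head)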
