Implementation:Mlc ai Mlc llm Medusa Model

From Leeroopedia
Revision as of 15:50, 16 February 2026 by Admin (auto-imported from implementations/Mlc_ai_Mlc_llm_Medusa_Model.md)


Overview

The Medusa Model module (python/mlc_llm/model/medusa/medusa_model.py) implements the Medusa speculative decoding head architecture. Medusa adds multiple lightweight prediction heads on top of a base language model. Starting from the base model's final hidden state, each head proposes a token for a different future position, so several tokens can be drafted in parallel and then verified by the base model, which speeds up inference. This module defines the model configuration (MedusaConfig), the residual building block (ResBlock), and the main MedusaModel class.

Location

  • File: python/mlc_llm/model/medusa/medusa_model.py
  • Lines: 84
  • Module: mlc_llm.model.medusa

Key Components

MedusaConfig

A dataclass extending ConfigBase that holds configuration parameters for the Medusa model.

import dataclasses
from typing import Any, Dict

from mlc_llm.support.config import ConfigBase


@dataclasses.dataclass
class MedusaConfig(ConfigBase):
    medusa_num_heads: int
    medusa_num_layers: int
    hidden_size: int
    vocab_size: int
    max_batch_size: int = 1
    tensor_parallel_shards: int = 1
    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
    # Unused; kept for compatibility with the compilation flow.
    prefill_chunk_size: int = -1
    context_window_size: int = -1

Fields:

  • medusa_num_heads (int, required): Number of Medusa prediction heads for parallel token prediction.
  • medusa_num_layers (int, required): Number of residual block layers within each Medusa head.
  • hidden_size (int, required): Dimensionality of the hidden state from the base model.
  • vocab_size (int, required): Size of the vocabulary for the output linear projection.
  • max_batch_size (int, default 1): Maximum batch size.
  • tensor_parallel_shards (int, default 1): Number of tensor-parallel shards.
  • kwargs (Dict[str, Any], default empty dict): Additional keyword arguments.
  • prefill_chunk_size (int, default -1): Unused; kept for compatibility with the compilation flow.
  • context_window_size (int, default -1): Unused; kept for compatibility with the compilation flow.
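A minimal sketch of how a flat config dictionary could map onto these fields, assuming (as mlc_llm's ConfigBase.from_dict is understood to do) that keys matching a field become constructor arguments while every other key is collected into kwargs. MedusaConfigSketch is a hypothetical standalone stand-in that does not inherit from ConfigBase, and "model_type" is only an illustrative extra key.

```python
import dataclasses
from typing import Any, Dict


@dataclasses.dataclass
class MedusaConfigSketch:
    """Hypothetical stand-in for MedusaConfig; it does NOT inherit ConfigBase."""

    medusa_num_heads: int
    medusa_num_layers: int
    hidden_size: int
    vocab_size: int
    max_batch_size: int = 1
    tensor_parallel_shards: int = 1
    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
    prefill_chunk_size: int = -1
    context_window_size: int = -1

    @classmethod
    def from_dict(cls, source: Dict[str, Any]) -> "MedusaConfigSketch":
        # Assumption: known field names become constructor arguments and
        # every unrecognized key is kept in `kwargs` instead of raising.
        names = {f.name for f in dataclasses.fields(cls)}
        known = {k: v for k, v in source.items() if k in names and k != "kwargs"}
        extra = {k: v for k, v in source.items() if k not in names}
        return cls(**known, kwargs=extra)


raw = {
    "medusa_num_heads": 5,
    "medusa_num_layers": 1,
    "hidden_size": 4096,
    "vocab_size": 32000,
    "model_type": "medusa",  # illustrative key that is not a field, so it lands in kwargs
}
config = MedusaConfigSketch.from_dict(raw)
print(config.max_batch_size, config.kwargs)  # 1 {'model_type': 'medusa'}
```

The default_factory on kwargs gives each instance its own dictionary; dataclasses rejects a plain mutable default such as `= {}` with a ValueError.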

ResBlock

A residual block with a SiLU (Sigmoid Linear Unit) activation function. This is the fundamental building block of each Medusa head.

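from tvm.relax.frontend import nn


class ResBlock(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)  # hidden_size -> hidden_size, with bias
        self.act = nn.SiLU()

    def forward(self, x):
        # Residual connection: x + SiLU(Linear(x))
        return x + self.act(self.linear(x))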

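The block computes x + SiLU(W·x + b): a linear transformation followed by SiLU activation, whose result is added back to the input (residual connection). Because the linear layer maps hidden_size to hidden_size, the output has the same shape as the input, which lets blocks be stacked and lets each layer refine the hidden representation.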
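The forward pass can be sketched in plain Python for a single vector. This is an illustrative re-implementation, not TVM code: silu, linear, and res_block are hypothetical helpers, the batch dimension is dropped, and the weight is assumed to be stored as [out_features][in_features], as in nn.Linear.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def silu(v: float) -> float:
    # Numerically stable SiLU(v) = v * sigmoid(v)
    if v >= 0:
        return v / (1.0 + math.exp(-v))
    e = math.exp(v)
    return v * e / (1.0 + e)


def linear(x: Vector, weight: Matrix, bias: Vector) -> Vector:
    # y = W x + b, with weight stored as [out_features][in_features]
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def res_block(x: Vector, weight: Matrix, bias: Vector) -> Vector:
    # ResBlock.forward: x + SiLU(Linear(x)); the output has the same length as x
    h = linear(x, weight, bias)
    return [xi + silu(hi) for xi, hi in zip(x, h)]


identity = [[1.0, 0.0], [0.0, 1.0]]
print(res_block([1.0, -2.0], identity, [0.0, 0.0]))  # approximately [1.731, -2.238]
```

With an all-zero weight and bias the block returns its input unchanged, since SiLU(0) = 0, so a stack of such blocks is an identity map.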

MedusaModel

The main model class that composes multiple Medusa prediction heads.

Constructor

class MedusaModel(nn.Module):
    def __init__(self, config: MedusaConfig):
        self.hidden_size = config.hidden_size
        self.dtype = "float32"
        self.medusa_head = nn.ModuleList(
            [
                nn.ModuleList(
                    [ResBlock(config.hidden_size) for _ in range(config.medusa_num_layers)]
                    + [nn.Linear(config.hidden_size, config.vocab_size, bias=False)]
                )
                for _ in range(config.medusa_num_heads)
            ]
        )

The model creates a nested ModuleList structure:

  • Outer list: One entry per Medusa head (medusa_num_heads total).
  • Inner list: A sequence of medusa_num_layers residual blocks followed by a single linear projection from hidden_size to vocab_size (without bias).

This architecture means each Medusa head independently processes the base model's hidden states through its own chain of residual blocks before projecting to vocabulary logits.
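To make the nested ModuleList concrete, the sketch below enumerates the parameters one Medusa configuration would create and totals them. Per head this is medusa_num_layers × (hidden_size² + hidden_size) for the ResBlocks plus vocab_size × hidden_size for the bias-free projection. The dotted names (for example medusa_head.0.0.linear.weight) are an assumption based on attribute-path naming and are not confirmed against MLC LLM's weight loader; medusa_param_shapes and count_params are hypothetical helpers.

```python
import math
from typing import List, Tuple

ParamList = List[Tuple[str, Tuple[int, ...]]]


def medusa_param_shapes(num_heads: int, num_layers: int, hidden_size: int,
                        vocab_size: int) -> ParamList:
    """Hypothetical helper: list (name, shape) for every Medusa head parameter."""
    params: ParamList = []
    for h in range(num_heads):
        for layer in range(num_layers):
            # Assumed naming: medusa_head.<head>.<index>.linear.{weight,bias}
            prefix = f"medusa_head.{h}.{layer}.linear"
            params.append((f"{prefix}.weight", (hidden_size, hidden_size)))
            params.append((f"{prefix}.bias", (hidden_size,)))
        # Last entry of the inner list: Linear(hidden_size, vocab_size, bias=False)
        params.append((f"medusa_head.{h}.{num_layers}.weight", (vocab_size, hidden_size)))
    return params


def count_params(shapes: ParamList) -> int:
    return sum(math.prod(shape) for _, shape in shapes)


# Illustrative sizes only: a 4096-wide model with a 32000-token vocabulary.
print(count_params(medusa_param_shapes(5, 1, 4096, 32000)))  # 739266560
```

In this illustrative configuration the vocabulary projections dominate: 131,072,000 of the 147,853,312 parameters in each head.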

get_default_spec

def get_default_spec(self):
    mod_spec = {
        "get_logits": {
            "hidden_states": nn.spec.Tensor(["batch_size", self.hidden_size], self.dtype),
            "$": {
                "param_mode": "packed",
                "effect_mode": "none",
            },
        },
    }
    return nn.spec.ModuleSpec.from_raw(mod_spec, self)

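Defines the TVM export specification for the model. The spec declares a single exported function, get_logits, that:

  • Accepts hidden_states with shape (batch_size, hidden_size) and dtype self.dtype, where batch_size is a symbolic dimension.
  • Uses "packed" parameter mode: all model parameters are passed together as a single tuple argument rather than as individual arguments.
  • Uses effect mode "none": the function is a pure computation and threads no side-effect state (such as a KV cache) through its inputs and outputs.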

get_logits

def get_logits(self, hidden_states: nn.Tensor):
    logits = []
    for head in self.medusa_head:
        logits.append(head(hidden_states).astype("float32"))
    return logits

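Runs each Medusa head independently on the same input hidden_states and collects the resulting logits. Calling an inner nn.ModuleList applies its modules in order, so each head passes the hidden states through its ResBlocks and then through its vocabulary projection. Each head's output is cast to float32 regardless of the model's internal dtype. The return value is a list of medusa_num_heads logit tensors, each of shape (batch_size, vocab_size).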
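An end-to-end sketch of get_logits in plain Python, for a single hidden vector rather than a batch. It is a hypothetical re-implementation assuming the nested structure described above: each head is a list of (weight, bias) ResBlock layers plus a bias-free vocabulary weight. The float32 cast is omitted because Python floats are already double precision.

```python
import math
from typing import List, Sequence, Tuple

Vector = List[float]
Matrix = List[List[float]]
Head = Tuple[Sequence[Tuple[Matrix, Vector]], Matrix]  # (ResBlock layers, vocab weight)


def silu(v: float) -> float:
    # Numerically stable SiLU(v) = v * sigmoid(v)
    if v >= 0:
        return v / (1.0 + math.exp(-v))
    e = math.exp(v)
    return v * e / (1.0 + e)


def matvec(weight: Matrix, x: Vector) -> Vector:
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def run_head(x: Vector, head: Head) -> Vector:
    res_layers, vocab_weight = head
    # ResBlocks applied in order: x <- x + SiLU(W x + b)
    for weight, bias in res_layers:
        h = matvec(weight, x)
        x = [xi + silu(hi + b) for xi, hi, b in zip(x, h, bias)]
    # Final bias-free projection hidden_size -> vocab_size
    return matvec(vocab_weight, x)


def get_logits(hidden: Vector, heads: Sequence[Head]) -> List[Vector]:
    # Every head reads the SAME hidden state; heads never feed each other.
    return [run_head(hidden, head) for head in heads]


zero, zero_bias = [[0.0, 0.0], [0.0, 0.0]], [0.0, 0.0]
heads = [
    ([(zero, zero_bias)], [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]),
    ([(zero, zero_bias)], [[0.0, 1.0], [1.0, 0.0], [-1.0, 0.0]]),
]
print(get_logits([2.0, 3.0], heads))  # [[2.0, 3.0, 5.0], [3.0, 2.0, -2.0]]
```

Because the heads share only their input, they are independent and could be evaluated in parallel; in this example head 0's highest logit is at token 2 and head 1's at token 0.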

to

def to(self, dtype: Optional[str] = None):
    super().to(dtype=dtype)
    if dtype is not None:
        self.dtype = dtype

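Overrides the base nn.Module.to method so that, besides converting the parameters, it also records the new dtype in self.dtype. That attribute starts as "float32" and is what get_default_spec uses as the dtype of hidden_states, so the exported input dtype stays consistent with the converted weights. Calling to() with dtype=None leaves self.dtype unchanged.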
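The reason for the override can be shown with two hypothetical plain-Python classes. ModuleSketch stands in for nn.Module (its to() only records the request instead of converting weights), and MedusaSketch copies the override so that the hidden_states dtype reported by its spec method follows the last non-None dtype passed to to().

```python
from typing import List, Optional, Tuple, Union


class ModuleSketch:
    """Hypothetical stand-in for nn.Module; no real weight conversion happens."""

    def to(self, dtype: Optional[str] = None) -> None:
        # The real nn.Module.to converts parameters; this sketch only records the call.
        self.converted_to = dtype


class MedusaSketch(ModuleSketch):
    def __init__(self, hidden_size: int):
        self.hidden_size = hidden_size
        self.dtype = "float32"  # default until to() is called with a dtype

    def to(self, dtype: Optional[str] = None) -> None:
        super().to(dtype=dtype)
        if dtype is not None:
            self.dtype = dtype  # remember the dtype for the export spec

    def hidden_states_spec(self) -> Tuple[List[Union[str, int]], str]:
        # Mirrors get_default_spec: the declared input dtype follows self.dtype.
        return (["batch_size", self.hidden_size], self.dtype)


m = MedusaSketch(4096)
m.to("float16")
print(m.hidden_states_spec())  # (['batch_size', 4096], 'float16')
```

Without the override, the spec would keep declaring a float32 input even after the weights were converted to float16.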

Architecture Diagram

The overall architecture can be visualized as follows, where N = medusa_num_heads and each head chains medusa_num_layers ResBlocks:

hidden_states (batch_size, hidden_size)
        |
        +---> Head 0: ResBlock -> ResBlock -> ... -> Linear(hidden_size, vocab_size) -> logits_0
        |
        +---> Head 1: ResBlock -> ResBlock -> ... -> Linear(hidden_size, vocab_size) -> logits_1
        |
        ...
        |
        +---> Head N-1: ResBlock -> ResBlock -> ... -> Linear(hidden_size, vocab_size) -> logits_N-1

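In Medusa, the base model's own LM head predicts the immediate next token, and Medusa head k (0-indexed) predicts the token k + 2 positions ahead, one step further than the previous head. Combining each head's top candidates forms a tree of possible continuations that the base model then verifies (tree-based speculative decoding).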
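A simplified sketch of the verification step, using greedy (exact-match) acceptance. It enumerates every combination of each head's top-k tokens as a candidate path and keeps the longest prefix on which the base model's own greedy prediction agrees. candidate_paths, accept_greedy, best_path, and toy_base_next are hypothetical names, and the toy base model simply predicts (last token + 2) mod 10. Unlike real Medusa decoding, this calls the base model once per token instead of scoring the whole tree in one pass with a tree attention mask, and it is not taken from MLC LLM.

```python
import itertools
from typing import Callable, List, Sequence, Tuple

Tokens = Tuple[int, ...]


def candidate_paths(topk_per_head: Sequence[Sequence[int]]) -> List[Tokens]:
    # Head i proposes tokens for speculative position i; each combination of
    # per-head choices is one root-to-leaf path of a dense candidate tree.
    return list(itertools.product(*topk_per_head))


def accept_greedy(prefix: Tokens, path: Tokens,
                  base_next: Callable[[Tokens], int]) -> List[int]:
    # Accept speculative tokens only while they equal the base model's greedy choice.
    accepted: List[int] = []
    context = prefix
    for token in path:
        if base_next(context) != token:
            break
        accepted.append(token)
        context = context + (token,)
    return accepted


def best_path(prefix: Tokens, topk_per_head: Sequence[Sequence[int]],
              base_next: Callable[[Tokens], int]) -> List[int]:
    # Longest accepted prefix over all candidate paths (the first one wins ties).
    return max((accept_greedy(prefix, p, base_next) for p in candidate_paths(topk_per_head)),
               key=len)


def toy_base_next(context: Tokens) -> int:
    # Toy stand-in for the base model: next token = (last token + 2) mod 10.
    return (context[-1] + 2) % 10


# prefix (3,) is the token already chosen by the base model's own LM head.
print(best_path((3,), [[5, 6], [7, 8]], toy_base_next))  # [5, 7]
```

The dense product grows as k^N for N heads with top-k tokens each, which is why Medusa implementations typically prune the tree to a sparse subset of paths.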

Key Design Decisions

  • Independent heads: Each Medusa head has its own parameters and processes hidden states independently, allowing diverse token predictions at different future positions.
  • Residual connections: The use of ResBlock with residual connections allows the heads to make incremental refinements to the base model's hidden states rather than learning the full transformation from scratch.
  • SiLU activation: SiLU (also known as Swish) is used for its smooth non-linearity, consistent with the activation used in many modern LLM architectures.
  • No bias in output projection: The final linear layer omits bias, matching standard LLM head conventions.
  • float32 logits: Output logits are always cast to float32 for numerical stability during token selection, regardless of the model's compute dtype.

Dependencies

  • dataclasses -- standard library for the dataclass decorator
  • typing -- standard library type hints (Any, Dict, Optional)
  • tvm.relax.frontend.nn -- TVM Relax neural network primitives (nn.Module, nn.Linear, nn.SiLU, nn.ModuleList, nn.spec)
  • mlc_llm.support.logging -- logging utilities
  • mlc_llm.support.config.ConfigBase -- base configuration class
