Implementation:Mlc ai Mlc llm Medusa Model
Overview
The Medusa Model module (python/mlc_llm/model/medusa/medusa_model.py) implements the Medusa speculative decoding head architecture. Medusa adds multiple lightweight prediction heads on top of a base language model so that several future tokens can be proposed in parallel, which speeds up inference. This module defines the model configuration, the residual block used inside each head, and the main MedusaModel class.
Location
- File: python/mlc_llm/model/medusa/medusa_model.py
- Lines: 84
- Module: mlc_llm.model.medusa
Key Components
MedusaConfig
A dataclass extending ConfigBase that holds configuration parameters for the Medusa model.
@dataclasses.dataclass
class MedusaConfig(ConfigBase):
    medusa_num_heads: int
    medusa_num_layers: int
    hidden_size: int
    vocab_size: int
    max_batch_size: int = 1
    tensor_parallel_shards: int = 1
    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
    prefill_chunk_size: int = -1
    context_window_size: int = -1
Fields:
| Field | Type | Default | Description |
|---|---|---|---|
| medusa_num_heads | int | required | Number of Medusa prediction heads for parallel token prediction. |
| medusa_num_layers | int | required | Number of residual block layers within each Medusa head. |
| hidden_size | int | required | Dimensionality of the hidden state from the base model. |
| vocab_size | int | required | Size of the vocabulary for the output linear projection. |
| max_batch_size | int | 1 | Maximum batch size. |
| tensor_parallel_shards | int | 1 | Number of tensor parallel shards. |
| kwargs | Dict[str, Any] | {} | Additional keyword arguments. |
| prefill_chunk_size | int | -1 | Unused; kept for compatibility with the compilation flow. |
| context_window_size | int | -1 | Unused; kept for compatibility with the compilation flow. |
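As a rough illustration of how the required and defaulted fields behave, the sketch below uses a stand-in dataclass, MedusaConfigSketch, that mirrors the field list above without inheriting from ConfigBase, so it runs without MLC LLM installed. The class name and the example values are assumptions made for illustration; they are not taken from any particular Medusa checkpoint.

```python
import dataclasses
from typing import Any, Dict


@dataclasses.dataclass
class MedusaConfigSketch:
    """Hypothetical stand-in mirroring the MedusaConfig fields (no ConfigBase)."""

    medusa_num_heads: int
    medusa_num_layers: int
    hidden_size: int
    vocab_size: int
    max_batch_size: int = 1
    tensor_parallel_shards: int = 1
    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
    prefill_chunk_size: int = -1
    context_window_size: int = -1


# Illustrative values only: the four required fields must be supplied,
# while every other field falls back to its default.
config = MedusaConfigSketch(
    medusa_num_heads=4,
    medusa_num_layers=1,
    hidden_size=4096,
    vocab_size=32000,
)

print(config.max_batch_size)      # 1
print(config.prefill_chunk_size)  # -1
```

Omitting any of the four required fields raises a TypeError at construction time, which is standard dataclass behavior.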
ResBlock
A residual block with a SiLU (Sigmoid Linear Unit) activation function. This is the fundamental building block of each Medusa head.
class ResBlock(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.act = nn.SiLU()

    def forward(self, x):
        return x + self.act(self.linear(x))
The block applies a linear transformation followed by SiLU activation, then adds the result to the input (residual connection). This preserves the input dimension while allowing each layer to refine the hidden representation.
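The computation x + SiLU(Wx + b) can be sketched in plain Python for a single hidden vector. This is a minimal numeric illustration, not the TVM implementation; the helper names silu and res_block and the toy weights are assumptions introduced here.

```python
import math
from typing import List


def silu(v: float) -> float:
    """SiLU(v) = v * sigmoid(v), written as v / (1 + exp(-v))."""
    return v / (1.0 + math.exp(-v))


def res_block(x: List[float], weight: List[List[float]], bias: List[float]) -> List[float]:
    """Compute x + SiLU(W x + b) for one hidden vector x."""
    linear_out = [
        sum(w_ij * x_j for w_ij, x_j in zip(row, x)) + b_i
        for row, b_i in zip(weight, bias)
    ]
    # Residual connection: the activated projection is added to the input.
    return [x_i + silu(y_i) for x_i, y_i in zip(x, linear_out)]


x = [1.0, -2.0, 0.5]
zero_w = [[0.0] * 3 for _ in range(3)]
zero_b = [0.0] * 3

# With zero weights and bias, SiLU(0) = 0, so the block returns its input.
out = res_block(x, zero_w, zero_b)
print(out == x)  # True
```

The output has the same length as the input, which is why any number of these blocks can be chained before the final projection.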
MedusaModel
The main model class that composes multiple Medusa prediction heads.
Constructor
class MedusaModel(nn.Module):
    def __init__(self, config: MedusaConfig):
        self.hidden_size = config.hidden_size
        self.dtype = "float32"
        self.medusa_head = nn.ModuleList(
            [
                nn.ModuleList(
                    [ResBlock(config.hidden_size) for _ in range(config.medusa_num_layers)]
                    + [nn.Linear(config.hidden_size, config.vocab_size, bias=False)]
                )
                for _ in range(config.medusa_num_heads)
            ]
        )
The model creates a nested ModuleList structure:
- Outer list: One entry per Medusa head (medusa_num_heads total).
- Inner list: A sequence of medusa_num_layers residual blocks followed by a single linear projection from hidden_size to vocab_size (without bias).
This architecture means each Medusa head independently processes the base model's hidden states through its own chain of residual blocks before projecting to vocabulary logits.
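Because the heads share no parameters, the total size grows linearly with the number of heads. The sketch below estimates the parameter count from the structure described above. It assumes that each ResBlock linear layer carries a bias vector (the constructor does not pass bias=False for it) and that the output projection has none. The function name medusa_param_count is hypothetical.

```python
def medusa_param_count(
    num_heads: int, num_layers: int, hidden_size: int, vocab_size: int
) -> int:
    """Estimate the number of parameters across all Medusa heads."""
    # Each ResBlock: hidden x hidden weight plus a bias vector (assumed present).
    res_block_params = hidden_size * hidden_size + hidden_size
    # Output projection: hidden x vocab weight, no bias.
    projection_params = hidden_size * vocab_size
    per_head = num_layers * res_block_params + projection_params
    return num_heads * per_head


# Toy sizes so the arithmetic is easy to follow:
# per head = 3 * (4*4 + 4) + 4*10 = 100, and there are 2 heads.
print(medusa_param_count(num_heads=2, num_layers=3, hidden_size=4, vocab_size=10))  # 200
```

For realistic sizes the vocab_size projection usually dominates, since vocab_size tends to be much larger than hidden_size.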
get_default_spec
def get_default_spec(self):
    mod_spec = {
        "get_logits": {
            "hidden_states": nn.spec.Tensor(["batch_size", self.hidden_size], self.dtype),
            "$": {
                "param_mode": "packed",
                "effect_mode": "none",
            },
        },
    }
    return nn.spec.ModuleSpec.from_raw(mod_spec, self)
Defines the TVM export specification for the model. The spec declares a single exported function get_logits that:
- Accepts hidden_states with shape (batch_size, hidden_size).
- Uses packed parameter mode (all parameters packed into a single argument).
- Uses no effects (pure computation with no side effects like KV cache updates).
get_logits
def get_logits(self, hidden_states: nn.Tensor):
    logits = []
    for head in self.medusa_head:
        logits.append(head(hidden_states).astype("float32"))
    return logits
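Runs each Medusa head independently on the same input hidden_states and collects the resulting logits. Calling a head invokes its inner ModuleList, which applies the residual blocks and then the final linear projection in order. Each head's output is cast to float32 regardless of the model's internal dtype. The return value is a list of logit tensors, one per head, each of shape (batch_size, vocab_size).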
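The control flow of get_logits can be sketched without TVM by treating each head as an ordered list of callables. This assumes that calling a head applies its layers sequentially, which is what head(hidden_states) in the source relies on. The names run_head and get_logits_sketch and the toy layers are hypothetical placeholders rather than real ResBlock or Linear modules, and the float32 cast is omitted.

```python
from typing import Callable, List, Sequence

Vector = List[float]
Layer = Callable[[Vector], Vector]


def run_head(layers: Sequence[Layer], hidden: Vector) -> Vector:
    """Apply one head's layers in order, feeding each output to the next."""
    out = hidden
    for layer in layers:
        out = layer(out)
    return out


def get_logits_sketch(heads: Sequence[Sequence[Layer]], hidden: Vector) -> List[Vector]:
    """Run every head on the same hidden vector; return one result per head."""
    return [run_head(layers, hidden) for layers in heads]


def double(v: Vector) -> Vector:
    """Placeholder layer: scale every element by 2."""
    return [2.0 * e for e in v]


def add_one(v: Vector) -> Vector:
    """Placeholder layer: shift every element by 1."""
    return [e + 1.0 for e in v]


# Two toy heads that see the same input but apply different layer chains.
heads = [[double, add_one], [add_one, double]]
logits = get_logits_sketch(heads, [1.0, 2.0])
print(logits)  # [[3.0, 5.0], [4.0, 6.0]]
```

The input is never modified between heads, so the order in which heads are evaluated does not affect the result.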
to
def to(self, dtype: Optional[str] = None):
    super().to(dtype=dtype)
    if dtype is not None:
        self.dtype = dtype
Overrides the base to method to also update the instance's dtype attribute, which is used in the export spec and initialized to "float32".
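The override pattern can be shown in isolation with a stand-in base class. BaseModuleSketch and ModelSketch are hypothetical names, and the base class here only records the requested dtype rather than converting any parameters; the point is solely how the subclass keeps its own dtype attribute in sync.

```python
from typing import Optional


class BaseModuleSketch:
    """Hypothetical stand-in for the base module; only records the dtype."""

    def __init__(self) -> None:
        self.param_dtype = "float32"

    def to(self, dtype: Optional[str] = None) -> None:
        if dtype is not None:
            self.param_dtype = dtype


class ModelSketch(BaseModuleSketch):
    """Tracks a dtype attribute alongside whatever the base class converts."""

    def __init__(self) -> None:
        super().__init__()
        self.dtype = "float32"

    def to(self, dtype: Optional[str] = None) -> None:
        super().to(dtype=dtype)
        if dtype is not None:
            self.dtype = dtype


model = ModelSketch()
model.to()           # no dtype given: the attribute keeps its initial value
unchanged = model.dtype
model.to("float16")  # dtype given: the attribute follows the conversion
changed = model.dtype
print(unchanged, changed)  # float32 float16
```

Keeping the attribute in sync matters because the export spec reads self.dtype when declaring the hidden_states input.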
Architecture Diagram
The overall architecture can be visualized as:
hidden_states (batch_size, hidden_size)
    |
    +---> Head 0:   ResBlock -> ResBlock -> ... -> Linear(hidden_size, vocab_size) -> logits_0
    |
    +---> Head 1:   ResBlock -> ResBlock -> ... -> Linear(hidden_size, vocab_size) -> logits_1
    |
    ...
    |
    +---> Head N-1: ResBlock -> ResBlock -> ... -> Linear(hidden_size, vocab_size) -> logits_{N-1}
Here N is medusa_num_heads. Each head predicts the token at a different future position, enabling tree-based speculative decoding.
Key Design Decisions
- Independent heads: Each Medusa head has its own parameters and processes hidden states independently, allowing diverse token predictions at different future positions.
- Residual connections: The use of ResBlock with residual connections allows the heads to make incremental refinements to the base model's hidden states rather than learning the full transformation from scratch.
- SiLU activation: SiLU (also known as Swish) is used for its smooth non-linearity, consistent with the activation used in many modern LLM architectures.
- No bias in output projection: The final linear layer omits bias, matching standard LLM head conventions.
- float32 logits: Output logits are always cast to float32 for numerical stability during token selection, regardless of the model's compute dtype.
Dependencies
- dataclasses -- standard library for the dataclass decorator
- tvm.relax.frontend.nn -- TVM Relax neural network primitives (nn.Module, nn.Linear, nn.SiLU, nn.ModuleList, nn.spec)
- mlc_llm.support.logging -- logging utilities
- mlc_llm.support.config.ConfigBase -- base configuration class