Implementation:Mlc ai Mlc llm Llama Model
| Knowledge Sources | |
|---|---|
| Domains | Model_Architecture, LLM |
| Last Updated | 2026-02-09 19:00 GMT |
Overview
Implements the Llama architecture for causal language modeling within the MLC LLM framework, supporting grouped-query attention, RoPE with llama3-style scaling, pipeline parallelism, disaggregated inference, and optional tied word embeddings.
Description
This module provides the TVM Relax-based implementation of Meta's Llama model architecture (Llama 2, Llama 3, Llama 3.1, Llama 3.2, and compatible models). It is one of the most feature-rich model implementations in MLC LLM, serving as the reference architecture for many derivative models.
Key architectural features include:
- Grouped-query attention (GQA): Supports configurable numbers of query and key-value heads via `num_attention_heads` and `num_key_value_heads`.
- Fused QKV projection: A single `qkv_proj` linear layer produces all query, key, and value projections in one operation.
- SiLU-gated MLP: Uses a fused `gate_up_proj` combining gate and up projections, followed by SiLU activation and a `down_proj`.
- RoPE with llama3 scaling: Supports standard RoPE and llama3-specific RoPE scaling with configurable `factor`, `low_freq_factor`, `high_freq_factor`, and `original_max_position_embeddings`.
- Tied word embeddings: Optionally shares the embedding table with the LM head via a custom `LlamaEmbedding` class that transposes weights for the output projection.
- Pipeline parallelism: Supports partitioning layers across pipeline stages with `pipeline_parallel_stages` and automatic boundary insertion.
- Disaggregated inference: Provides dedicated methods for extracting last hidden states (`prefill_to_last_hidden_states`, `batch_forward_to_last_hidden_states`, etc.), enabling disaggregated prefill/decode workflows.
- RMSNorm: Uses RMSNorm for input and post-attention normalization.
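The llama3 RoPE scaling leaves high-frequency RoPE components unchanged, divides low-frequency components by `factor`, and interpolates between the two in a middle band whose wavelength bounds are `original_max_position_embeddings / high_freq_factor` and `original_max_position_embeddings / low_freq_factor`. The pure-Python sketch below follows that rule as published with Llama 3.1 (the same formula appears in Hugging Face transformers). `llama3_inv_freq` is a hypothetical helper for illustration, not the code MLC LLM itself runs.

```python
import math
from typing import List


def llama3_inv_freq(
    head_dim: int,
    base: float,
    factor: float,
    low_freq_factor: float,
    high_freq_factor: float,
    original_max_position_embeddings: int,
) -> List[float]:
    """Hypothetical helper: llama3-style scaled RoPE inverse frequencies."""
    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
    high_freq_wavelen = original_max_position_embeddings / high_freq_factor
    scaled = []
    for i in range(0, head_dim, 2):
        freq = 1.0 / (base ** (i / head_dim))  # standard RoPE inverse frequency
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            scaled.append(freq)  # high-frequency band: unchanged
        elif wavelen > low_freq_wavelen:
            scaled.append(freq / factor)  # low-frequency band: fully scaled
        else:
            # Middle band: blend linearly between the scaled and unscaled frequency.
            smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            scaled.append((1 - smooth) * freq / factor + smooth * freq)
    return scaled
```

With the Llama 3.1 values from the usage example (`head_dim` 128, base 500000, `factor` 8.0, `low_freq_factor` 1.0, `high_freq_factor` 4.0, `original_max_position_embeddings` 8192), the call returns 64 frequencies. The first (1.0) is left unchanged, and the last is divided by 8.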
The module defines LlamaModel (embedding + decoder layers with pipeline partitioning + final RMSNorm), which is wrapped by LlamaForCausalLM; the wrapper adds the LM head and the full suite of inference methods.
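Each decoder layer places attention and the MLP in pre-norm residual blocks. RMSNorm is applied before each sub-block, and the sub-block output is added back to its input. The sketch below shows that pattern for the MLP half on plain Python lists. It assumes bias-free linear layers and that the first half of the fused `gate_up_proj` output is the gate. `rms_norm`, `gated_ffn`, and `mlp_residual` are hypothetical names, not MLC LLM APIs.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # row-major weights, shape [out_features][in_features]


def rms_norm(x: Vector, weight: Vector, eps: float = 1e-5) -> Vector:
    # Divide by the root mean square (no mean subtraction), then apply the learned scale.
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms for w, v in zip(weight, x)]


def linear(x: Vector, weight: Matrix) -> Vector:
    # Bias-free projection y = W x.
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


def silu(v: float) -> float:
    return v / (1.0 + math.exp(-v))


def gated_ffn(x: Vector, gate_up_weight: Matrix, down_weight: Matrix) -> Vector:
    # One fused projection yields 2 * intermediate_size values: [gate | up] (assumed order).
    fused = linear(x, gate_up_weight)
    half = len(fused) // 2
    gate, up = fused[:half], fused[half:]
    return linear([silu(g) * u for g, u in zip(gate, up)], down_weight)


def mlp_residual(h: Vector, norm_weight: Vector, gate_up_weight: Matrix, down_weight: Matrix) -> Vector:
    # Pre-norm residual: h + MLP(RMSNorm(h)); the attention half follows the same shape.
    mlp_out = gated_ffn(rms_norm(h, norm_weight), gate_up_weight, down_weight)
    return [a + b for a, b in zip(h, mlp_out)]
```

Fusing the gate and up projections into one matrix means a single matmul produces both halves. This is the motivation for `gate_up_proj`, mirroring the fused `qkv_proj` on the attention side.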
Usage
Use this module when compiling Llama family models (Llama 2, Llama 3, Llama 3.1, Llama 3.2, CodeLlama, TinyLlama, SmolLM, and other LlamaForCausalLM-architecture models) for deployment with MLC LLM.
Code Reference
Source Location
- Repository: Mlc_ai_Mlc_llm
- File: python/mlc_llm/model/llama/llama_model.py
Signature
```python
import dataclasses
from typing import Any, Dict, Optional

from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import Tensor

from mlc_llm.nn import PagedKVCache
from mlc_llm.support.config import ConfigBase


@dataclasses.dataclass
class LlamaConfig(ConfigBase):
    hidden_size: int
    intermediate_size: int
    num_attention_heads: int
    num_hidden_layers: int
    rms_norm_eps: float
    vocab_size: int
    tie_word_embeddings: bool = False
    position_embedding_base: int = 0
    rope_scaling: Optional[Dict[str, Any]] = None
    context_window_size: int = 0
    prefill_chunk_size: int = 0
    num_key_value_heads: int = 0
    head_dim: int = 0
    tensor_parallel_shards: int = 1
    pipeline_parallel_stages: int = 1
    disaggregation: bool = False
    ...  # remaining fields elided


class LlamaForCausalLM(nn.Module):
    def __init__(self, config: LlamaConfig): ...
    def embed(self, input_ids: Tensor): ...
    def get_logits(self, hidden_states: Tensor): ...
    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): ...
    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): ...
    def batch_prefill(self, input_embeds, logit_positions, paged_kv_cache): ...
    def batch_decode(self, input_embeds, paged_kv_cache): ...
    def batch_verify(self, input_embeds, paged_kv_cache): ...
    def prefill_to_last_hidden_states(self, input_embed, paged_kv_cache): ...
    def batch_forward_to_last_hidden_states(self, input_embeds, paged_kv_cache): ...
    def batch_select_last_hidden_states(self, hidden_states, logit_positions): ...
    def create_paged_kv_cache(self, *args, **kwargs): ...  # parameters elided in this summary
    def get_default_spec(self): ...
```
Import
```python
from mlc_llm.model.llama.llama_model import LlamaConfig, LlamaForCausalLM
```
I/O Contract
Primary Classes
| Class | Role | Key Characteristics |
|---|---|---|
| LlamaConfig | Model configuration | Supports rope_scaling (llama3 type), pipeline parallelism, disaggregation |
| LlamaEmbedding | Shared embedding | Extends nn.Embedding with lm_head_forward for tied weights |
| LlamaFFN | Gated feed-forward | SiLU-gated via fused gate_up_proj + down_proj |
| LlamaAttention | GQA attention | Fused qkv_proj, no bias, head_dim**-0.5 scaling |
| LlamaDecoderLayer | Transformer block | input_layernorm + self_attn + post_attention_layernorm + mlp |
| LlamaModel | Core model | embed_tokens + layers with pipeline partition + norm |
| LlamaForCausalLM | Top-level model | Full inference suite including disaggregated methods |
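The `LlamaAttention` row combines three details: a fused `qkv_proj`, grouped-query attention, and `head_dim**-0.5` scaling. The sketch below derives the resulting projection sizes and the query-to-KV head mapping. It assumes the standard GQA convention, in which consecutive groups of query heads share one KV head. It also assumes that `head_dim` falls back to `hidden_size // num_attention_heads` when left at its default of 0. `qkv_layout`, `QKVLayout`, and `kv_head_for` are hypothetical helpers.

```python
from dataclasses import dataclass


@dataclass
class QKVLayout:
    head_dim: int
    q_size: int      # output features of qkv_proj that belong to queries
    kv_size: int     # output features for keys (and, separately, for values)
    fused_size: int  # total output features of the fused qkv_proj
    scale: float     # attention score scaling, head_dim ** -0.5


def qkv_layout(hidden_size: int, num_attention_heads: int, num_key_value_heads: int, head_dim: int = 0) -> QKVLayout:
    # Assumption: head_dim == 0 means "derive it from hidden_size".
    if head_dim == 0:
        head_dim = hidden_size // num_attention_heads
    q_size = num_attention_heads * head_dim
    kv_size = num_key_value_heads * head_dim
    return QKVLayout(head_dim, q_size, kv_size, q_size + 2 * kv_size, head_dim ** -0.5)


def kv_head_for(query_head: int, num_attention_heads: int, num_key_value_heads: int) -> int:
    # Under GQA, each consecutive group of query heads reads the same KV head.
    assert num_attention_heads % num_key_value_heads == 0
    group_size = num_attention_heads // num_key_value_heads
    return query_head // group_size
```

For the Llama 3.1 8B config in the usage example, this gives a `head_dim` of 128 and a fused output of 4096 + 2 × 1024 = 6144 features. Each KV head serves four query heads, so the KV cache holds a quarter of the keys and values that full multi-head attention would store.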
Forward Methods
| Method | Input | Output |
|---|---|---|
| embed | Tensor[seq_len] (int32) | Tensor[seq_len, hidden_size] |
| get_logits | Tensor[seq_len, hidden_size] | Tensor[seq_len, vocab_size] (float32) |
| prefill | Tensor[1, seq_len, hidden_size], PagedKVCache | (Tensor[1, 1, vocab_size], PagedKVCache) |
| decode | Tensor[1, 1, hidden_size], PagedKVCache | (Tensor[1, 1, vocab_size], PagedKVCache) |
| batch_prefill | Tensor[1, seq_len, hidden_size], Tensor[batch_size], PagedKVCache | (Tensor, PagedKVCache) |
| batch_decode | Tensor[batch_size, 1, hidden_size], PagedKVCache | (Tensor, PagedKVCache) |
| prefill_to_last_hidden_states | Tensor[1, seq_len, hidden_size], PagedKVCache | (Tensor[1, seq_len, hidden_size], PagedKVCache) |
| batch_select_last_hidden_states | Tensor[seq_len, hidden_size], Tensor[batch_size] | Tensor[batch_size, hidden_size] |
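In these shapes, `batch_prefill` takes every sequence packed along a single `seq_len` axis plus a `Tensor[batch_size]` of `logit_positions`. `batch_select_last_hidden_states` then reduces `[seq_len, hidden_size]` to `[batch_size, hidden_size]`. The pure-Python sketch below shows that bookkeeping. It assumes that `logit_positions` holds the index of each sequence's last token in the packed axis. `last_token_positions` and `select_rows` are hypothetical helpers, not MLC LLM functions.

```python
from itertools import accumulate
from typing import List, Sequence


def last_token_positions(seq_lens: Sequence[int]) -> List[int]:
    # Sequences are packed end to end, so each one ends at (running total - 1).
    return [end - 1 for end in accumulate(seq_lens)]


def select_rows(hidden_states: Sequence[Sequence[float]], positions: Sequence[int]) -> List[List[float]]:
    # Gather one row per sequence: [seq_len, hidden] -> [batch_size, hidden].
    return [list(hidden_states[p]) for p in positions]


# Three sequences of lengths 3, 1, and 2 packed into seq_len = 6.
positions = last_token_positions([3, 1, 2])
packed = [[float(i)] for i in range(6)]  # toy hidden states with hidden_size = 1
last_hidden = select_rows(packed, positions)
```

Only those rows need to pass through the LM head. This avoids computing a `vocab_size`-wide projection for every prompt token during prefill.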
Usage Examples
```python
from mlc_llm.model.llama.llama_model import LlamaConfig, LlamaForCausalLM

# Creating a Llama 3.1 8B config
config = LlamaConfig(
    hidden_size=4096,
    intermediate_size=14336,
    num_attention_heads=32,
    num_hidden_layers=32,
    rms_norm_eps=1e-5,
    vocab_size=128256,
    num_key_value_heads=8,
    position_embedding_base=500000,
    rope_scaling={
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    context_window_size=131072,
)
model = LlamaForCausalLM(config)
```
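As a sanity check on the configuration above, the following sketch counts parameters from the config fields alone. It assumes bias-free projections, one RMSNorm weight vector per norm, and an untied LM head (the `tie_word_embeddings` default). `llama_param_count` is a hypothetical helper, not part of MLC LLM.

```python
def llama_param_count(
    hidden_size: int,
    intermediate_size: int,
    num_attention_heads: int,
    num_hidden_layers: int,
    vocab_size: int,
    num_key_value_heads: int,
    head_dim: int = 0,
    tie_word_embeddings: bool = False,
) -> int:
    head_dim = head_dim or hidden_size // num_attention_heads  # assumed default derivation
    qkv_proj = hidden_size * (num_attention_heads + 2 * num_key_value_heads) * head_dim
    o_proj = num_attention_heads * head_dim * hidden_size
    gate_up_proj = hidden_size * 2 * intermediate_size
    down_proj = intermediate_size * hidden_size
    layer_norms = 2 * hidden_size  # input_layernorm + post_attention_layernorm
    per_layer = qkv_proj + o_proj + gate_up_proj + down_proj + layer_norms
    embed_tokens = vocab_size * hidden_size
    lm_head = 0 if tie_word_embeddings else vocab_size * hidden_size
    final_norm = hidden_size
    return num_hidden_layers * per_layer + embed_tokens + final_norm + lm_head


print(llama_param_count(4096, 14336, 32, 32, 128256, 8))  # 8030261248
```

With the Llama 3.1 8B values, the count is 8,030,261,248 (about 8.03 billion). Passing `tie_word_embeddings=True` removes one `vocab_size × hidden_size` matrix, which is roughly 525 million parameters at this vocabulary size.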