Implementation:Mlc ai Mlc llm Llama Model

From Leeroopedia


Knowledge Sources
Domains: Model_Architecture, LLM
Last Updated: 2026-02-09 19:00 GMT

Overview

Implements the Llama architecture for causal language modeling within the MLC LLM framework, supporting grouped-query attention, RoPE with llama3-style scaling, pipeline parallelism, disaggregated inference, and optional tied word embeddings.

Description

This module provides the TVM Relax-based implementation of Meta's Llama model architecture (Llama 2, Llama 3, Llama 3.1, Llama 3.2, and compatible models). It is one of the most feature-rich model implementations in MLC LLM, serving as the reference architecture for many derivative models.

Key architectural features include:

  • Grouped-query attention (GQA): Supports configurable numbers of query and key-value heads via num_attention_heads and num_key_value_heads.
  • Fused QKV projection: A single qkv_proj linear layer produces all query, key, and value projections in one operation.
  • SiLU-gated MLP: Uses a fused gate_up_proj combining gate and up projections, followed by SiLU activation and a down_proj.
  • RoPE with llama3 scaling: Supports standard RoPE and llama3-specific rope scaling with configurable factor, low_freq_factor, high_freq_factor, and original_max_position_embeddings.
  • Tied word embeddings: Optionally shares the embedding table with the LM head via a custom LlamaEmbedding class that transposes weights for the output projection.
  • Pipeline parallelism: Supports partitioning layers across pipeline stages with pipeline_parallel_stages and automatic boundary insertion.
  • Disaggregated inference: Provides dedicated methods for extracting last hidden states (prefill_to_last_hidden_states, batch_forward_to_last_hidden_states, etc.) enabling disaggregated prefill/decode workflows.
  • RMSNorm: Uses RMSNorm for input and post-attention normalization.
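The llama3 rope scaling parameters listed above can be illustrated with a small standalone sketch. It follows the publicly documented llama3 frequency-scaling scheme (short wavelengths kept, long wavelengths divided by factor, a linear blend in between) and is not taken from the MLC LLM source; the function name llama3_scaled_inv_freq is hypothetical, and the actual TVM implementation may differ in detail.

```python
import math


def llama3_scaled_inv_freq(head_dim, base, factor, low_freq_factor,
                           high_freq_factor, original_max_position_embeddings):
    """Return one RoPE inverse frequency per rotary pair after llama3-style scaling.

    Hypothetical helper for illustration; not an MLC LLM API.
    """
    low_freq_wavelen = original_max_position_embeddings / low_freq_factor
    high_freq_wavelen = original_max_position_embeddings / high_freq_factor
    scaled = []
    for i in range(0, head_dim, 2):
        inv_freq = 1.0 / (base ** (i / head_dim))
        wavelen = 2.0 * math.pi / inv_freq
        if wavelen < high_freq_wavelen:
            # High-frequency components are left unchanged.
            scaled.append(inv_freq)
        elif wavelen > low_freq_wavelen:
            # Low-frequency components are slowed down by `factor`.
            scaled.append(inv_freq / factor)
        else:
            # Mid band: interpolate linearly between scaled and unscaled.
            smooth = (original_max_position_embeddings / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            scaled.append((1.0 - smooth) * inv_freq / factor + smooth * inv_freq)
    return scaled


# Parameters matching a Llama 3.1 style rope_scaling entry with head_dim 128.
freqs = llama3_scaled_inv_freq(128, 500000.0, 8.0, 1.0, 4.0, 8192)
```

With these parameters the highest-frequency pair keeps its original frequency, while the lowest-frequency pair is divided by the full factor of 8.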

The module defines LlamaModel (embedding, decoder layers with pipeline partitioning, and a final RMSNorm). LlamaForCausalLM wraps it, adding the LM head and the full suite of inference methods.
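As a rough illustration of pipeline partitioning, the sketch below splits decoder layers evenly across stages and lists the layer indices where a stage boundary would be inserted. The ceiling-division split and both helper names are assumptions made for this example; the source states only that layers are partitioned by pipeline_parallel_stages with automatic boundary insertion.

```python
def pipeline_partition_points(num_hidden_layers, pipeline_parallel_stages):
    """Return the first layer index of each stage (assumed even, ceil-divided split)."""
    layers_per_stage = (
        num_hidden_layers + pipeline_parallel_stages - 1
    ) // pipeline_parallel_stages
    return [stage * layers_per_stage for stage in range(pipeline_parallel_stages)]


def stage_of_layer(layer_id, partition_points):
    """Return the index of the stage that owns `layer_id`."""
    stage = 0
    for idx, start in enumerate(partition_points):
        if layer_id >= start:
            stage = idx
    return stage


points = pipeline_partition_points(32, 4)
# A boundary is needed before the first layer of every stage except stage 0.
boundaries = [layer_id for layer_id in range(32) if layer_id != 0 and layer_id in points]
```

For a 32-layer model on 4 stages, each stage holds 8 consecutive layers under this assumption.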

Usage

Use this module when compiling Llama family models (Llama 2, Llama 3, Llama 3.1, Llama 3.2, CodeLlama, TinyLlama, SmolLM, and other LlamaForCausalLM-architecture models) for deployment with MLC LLM.

Code Reference

Source Location

Signature

@dataclasses.dataclass
class LlamaConfig(ConfigBase):
    hidden_size: int
    intermediate_size: int
    num_attention_heads: int
    num_hidden_layers: int
    rms_norm_eps: float
    vocab_size: int
    tie_word_embeddings: bool = False
    position_embedding_base: int = 0
    rope_scaling: Optional[Dict[str, Any]] = None
    context_window_size: int = 0
    prefill_chunk_size: int = 0
    num_key_value_heads: int = 0
    head_dim: int = 0
    tensor_parallel_shards: int = 1
    pipeline_parallel_stages: int = 1
    disaggregation: bool = False
    ...

class LlamaForCausalLM(nn.Module):
    def __init__(self, config: LlamaConfig): ...
    def embed(self, input_ids: Tensor): ...
    def get_logits(self, hidden_states: Tensor): ...
    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): ...
    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): ...
    def batch_prefill(self, input_embeds, logit_positions, paged_kv_cache): ...
    def batch_decode(self, input_embeds, paged_kv_cache): ...
    def batch_verify(self, input_embeds, paged_kv_cache): ...
    def prefill_to_last_hidden_states(self, input_embed, paged_kv_cache): ...
    def batch_forward_to_last_hidden_states(self, input_embeds, paged_kv_cache): ...
    def batch_select_last_hidden_states(self, hidden_states, logit_positions): ...
    def create_paged_kv_cache(self, ...): ...
    def get_default_spec(self): ...

Import

from mlc_llm.model.llama.llama_model import LlamaConfig, LlamaForCausalLM

I/O Contract

Primary Classes

Class | Role | Key Characteristics
LlamaConfig | Model configuration | Supports rope_scaling (llama3 type), pipeline parallelism, disaggregation
LlamaEmbedding | Shared embedding | Extends nn.Embedding with lm_head_forward for tied weights
LlamaFFN | Gated feed-forward | SiLU-gated via fused gate_up_proj + down_proj
LlamaAttention | GQA attention | Fused qkv_proj, no bias, head_dim**-0.5 scaling
LlamaDecoderLayer | Transformer block | input_layernorm + self_attn + post_attention_layernorm + mlp
LlamaModel | Core model | embed_tokens + layers with pipeline partition + norm
LlamaForCausalLM | Top-level model | Full inference suite including disaggregated methods
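The SiLU-gated feed-forward pattern described for LlamaFFN can be sketched in plain Python on lists. This is a minimal illustration, not the TVM Relax code: the helpers silu, matvec, and gated_ffn are hypothetical, and the sketch assumes the first half of the fused gate_up_proj output is the gate and the second half is the up projection.

```python
import math


def silu(x):
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


def matvec(weight, vec):
    """Multiply a row-major matrix (list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, vec)) for row in weight]


def gated_ffn(x, gate_up_weight, down_weight):
    """Fused gate/up projection, SiLU gating, then the down projection."""
    fused = matvec(gate_up_weight, x)       # length 2 * intermediate_size
    half = len(fused) // 2
    gate, up = fused[:half], fused[half:]   # assumed layout: [gate | up]
    hidden = [silu(g) * u for g, u in zip(gate, up)]
    return matvec(down_weight, hidden)


# Toy sizes: hidden_size = 2, intermediate_size = 2.
gate_up_weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, -1.0]]
down_weight = [[1.0, 0.0], [0.0, 1.0]]
out = gated_ffn([0.0, 2.0], gate_up_weight, down_weight)
```

Fusing the two projections into one matrix lets a single matrix multiply produce both halves, which are then split before gating.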

Forward Methods

Method | Input | Output
embed | Tensor[seq_len] (int32) | Tensor[seq_len, hidden_size]
get_logits | Tensor[seq_len, hidden_size] | Tensor[seq_len, vocab_size] (float32)
prefill | Tensor[1, seq_len, hidden_size], PagedKVCache | (Tensor[1, 1, vocab_size], PagedKVCache)
decode | Tensor[1, 1, hidden_size], PagedKVCache | (Tensor[1, 1, vocab_size], PagedKVCache)
batch_prefill | Tensor[1, seq_len, hidden_size], Tensor[batch_size], PagedKVCache | (Tensor, PagedKVCache)
batch_decode | Tensor[batch_size, 1, hidden_size], PagedKVCache | (Tensor, PagedKVCache)
prefill_to_last_hidden_states | Tensor[1, seq_len, hidden_size], PagedKVCache | (Tensor[1, seq_len, hidden_size], PagedKVCache)
batch_select_last_hidden_states | Tensor[seq_len, hidden_size], Tensor[batch_size] | Tensor[batch_size, hidden_size]
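The shape contract of batch_select_last_hidden_states (pick one row per sequence out of the concatenated hidden states) can be sketched as a simple gather. The function below is a hypothetical stand-in operating on Python lists; it assumes logit_positions holds the row index of each sequence's last token, which the source implies but does not spell out.

```python
def select_last_hidden_states(hidden_states, logit_positions):
    """Gather one hidden-state row per sequence.

    hidden_states: list of seq_len rows, each of length hidden_size.
    logit_positions: list of batch_size row indices into hidden_states.
    """
    return [hidden_states[pos] for pos in logit_positions]


# Two sequences of lengths 3 and 2 concatenated into seq_len = 5 rows.
hidden_states = [[float(t), float(t) * 10.0] for t in range(5)]
logit_positions = [2, 4]  # last token of each sequence
selected = select_last_hidden_states(hidden_states, logit_positions)
```

Selecting rows before the LM head means logits are computed for batch_size rows rather than all seq_len rows.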

Usage Examples

from mlc_llm.model.llama.llama_model import LlamaConfig, LlamaForCausalLM

# Create a Llama 3.1 8B config
config = LlamaConfig(
    hidden_size=4096,
    intermediate_size=14336,
    num_attention_heads=32,
    num_hidden_layers=32,
    rms_norm_eps=1e-5,
    vocab_size=128256,
    num_key_value_heads=8,
    position_embedding_base=500000,
    rope_scaling={
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    context_window_size=131072,
)
model = LlamaForCausalLM(config)
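To make the attention geometry implied by this config concrete, the following standalone sketch derives the per-head dimension, the GQA group size, the fused qkv_proj output width, and the head_dim**-0.5 attention scale. The helper derive_attention_dims is hypothetical, and treating a value of 0 as "derive from the other fields" is an assumption about how the defaults are resolved.

```python
def derive_attention_dims(hidden_size, num_attention_heads,
                          num_key_value_heads=0, head_dim=0):
    """Compute attention shapes from config fields (illustrative helper only)."""
    # Assumption: 0 means "not set", so fall back to the usual defaults.
    if num_key_value_heads == 0:
        num_key_value_heads = num_attention_heads
    if head_dim == 0:
        head_dim = hidden_size // num_attention_heads
    return {
        "head_dim": head_dim,
        # Number of query heads that share each key/value head under GQA.
        "queries_per_kv_head": num_attention_heads // num_key_value_heads,
        # One fused projection emits all Q heads plus K and V heads.
        "qkv_out_features": (num_attention_heads + 2 * num_key_value_heads) * head_dim,
        "attention_scale": head_dim ** -0.5,
    }


dims = derive_attention_dims(hidden_size=4096, num_attention_heads=32,
                             num_key_value_heads=8)
```

For the values above, each key/value head serves four query heads, and the fused projection is narrower than it would be with 32 key/value heads.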
