Implementation:Mlc ai Mlc llm Gemma3 Model
| Knowledge Sources | |
|---|---|
| Domains | Model_Architecture, LLM |
| Last Updated | 2026-02-09 19:00 GMT |
Overview
Implements the Gemma3 architecture for causal language modeling within the MLC LLM framework, supporting both standalone text models and multimodal conditional generation with sliding window attention patterns.
Description
This module provides the TVM Relax-based implementation of Google's Gemma3 model architecture. It defines two top-level model classes: Gemma3LanguageModel (for text-only usage when the config has no separate text_config) and Gemma3ForCausalLM (for multimodal conditional generation). Both share the same internal components: Gemma3TextModel contains stacked Gemma3DecoderLayer modules, each consisting of a Gemma3Attention block with query/key RMSNorm and a Gemma3MLP block using GeLU activation with a gated up-projection.
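The gated feed-forward step can be illustrated without TVM. The following is a minimal sketch in plain Python, not the module's code: it assumes the fused gate_up_proj output is split into a gate half followed by an up half, and that the activation is the tanh approximation of GeLU named by hidden_activation. The helper names gelu_tanh and gated_mlp_activation are hypothetical.

```python
import math
from typing import List


def gelu_tanh(x: float) -> float:
    """Tanh approximation of GeLU (the "gelu_pytorch_tanh" variant)."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))


def gated_mlp_activation(gate_up: List[float]) -> List[float]:
    """Split a fused gate/up vector and return gelu(gate) * up.

    Assumption: the first half is the gate and the second half is the up
    projection. The result would then be fed to down_proj.
    """
    half = len(gate_up) // 2
    gate, up = gate_up[:half], gate_up[half:]
    return [gelu_tanh(g) * u for g, u in zip(gate, up)]


# Fused vector with gate = [0.0, 1.0] and up = [2.0, 3.0].
activated = gated_mlp_activation([0.0, 1.0, 2.0, 3.0])
print(activated)
```

Fusing the gate and up projections into one matrix multiply is a common way to reduce kernel launches; the split afterwards is a cheap slice.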
Key architectural features include:
- Sliding window attention pattern: Layers use alternating "mha_sliding" and "mha" attention kinds based on a configurable sliding_window_pattern parameter (by default, every 6th layer uses full attention).
- Query/Key normalization: Separate RMSNorm layers (q_norm, k_norm) are applied to query and key projections before attention.
- Pre-attention scalar: Uses query_pre_attn_scalar for softmax scaling instead of the standard head_dim**-0.5.
- Post-matmul normalization: Four RMSNorm layers per decoder layer (input, post-attention, pre-feedforward, post-feedforward).
- GemmaEmbedding reuse: Token embeddings are imported from the Gemma model and used as a tied lm_head via weight transposition.
- Tensor parallelism: Full support for sharding attention projections and MLP layers across multiple GPUs.
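Two of the features above reduce to small formulas. The sketch below is illustrative only and makes two assumptions that the text does not spell out: layers are counted from 1 when applying sliding_window_pattern, and the scalar enters the softmax as query_pre_attn_scalar**-0.5. The function names are hypothetical.

```python
from typing import List


def attention_kinds(num_hidden_layers: int, sliding_window_pattern: int = 6) -> List[str]:
    """Return the attention kind for each layer.

    Assumption: with 1-based layer numbering, every
    `sliding_window_pattern`-th layer uses full attention ("mha") and all
    other layers use sliding window attention ("mha_sliding").
    """
    return [
        "mha" if (i + 1) % sliding_window_pattern == 0 else "mha_sliding"
        for i in range(num_hidden_layers)
    ]


def softmax_scale(query_pre_attn_scalar: float) -> float:
    """Scale applied to attention scores in place of head_dim ** -0.5.

    Assumption: the scalar is used as query_pre_attn_scalar ** -0.5.
    """
    return query_pre_attn_scalar ** -0.5


kinds = attention_kinds(num_hidden_layers=12)
print(kinds)
print(softmax_scale(256))
```

With a pattern of 6, most layers attend only within the sliding window, which bounds their KV cache usage, while the periodic full-attention layers keep access to the whole context.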
The configuration is handled through Gemma3Config which wraps a nested Gemma3TextConfig, automatically propagating context window, prefill chunk, and sliding window settings.
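As a rough illustration of this propagation, the sketch below uses trimmed stand-in dataclasses rather than the real Gemma3Config and Gemma3TextConfig. The class names, the "unset means <= 0" convention, and the direction of the copy (nested values filling unset top-level values) are all assumptions made for the example.

```python
import dataclasses
from typing import Any, Dict, Optional, Union


@dataclasses.dataclass
class TextConfigSketch:
    """Hypothetical, trimmed stand-in for Gemma3TextConfig."""

    hidden_size: int
    context_window_size: int = 0
    prefill_chunk_size: int = 0
    sliding_window_size: Optional[int] = None


@dataclasses.dataclass
class ConfigSketch:
    """Hypothetical, trimmed stand-in for Gemma3Config."""

    text_config: Union[TextConfigSketch, Dict[str, Any]]
    context_window_size: int = -1
    prefill_chunk_size: int = -1
    sliding_window_size: int = -1

    def __post_init__(self) -> None:
        # Accept a plain dict for the nested config and wrap it.
        if isinstance(self.text_config, dict):
            self.text_config = TextConfigSketch(**self.text_config)
        # Assumption: top-level values that are unset (<= 0) are filled
        # from the nested text config when it provides a value.
        for key in ("context_window_size", "prefill_chunk_size", "sliding_window_size"):
            nested = getattr(self.text_config, key)
            if getattr(self, key) <= 0 and nested is not None:
                setattr(self, key, nested)


cfg = ConfigSketch(
    text_config={
        "hidden_size": 1152,
        "context_window_size": 8192,
        "prefill_chunk_size": 2048,
        "sliding_window_size": 512,
    }
)
print(cfg.context_window_size, cfg.prefill_chunk_size, cfg.sliding_window_size)
```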
Usage
Use this module when compiling Gemma3 family models (e.g., Gemma3 1B, 4B, 12B, 27B) for deployment with MLC LLM. The Gemma3ForCausalLM class is selected for multimodal Gemma3 configs, while Gemma3LanguageModel handles text-only configs that lack a separate text_config field.
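The selection rule described here can be written down as a small helper. This is a hypothetical sketch of the rule as stated, not MLC LLM's dispatch code; the framework may choose the class by other means (for example, by the model type recorded in the config).

```python
from typing import Any, Dict


def select_model_class_name(config: Dict[str, Any]) -> str:
    """Hypothetical helper: name the top-level class for a config dict.

    Assumption: a config carrying a nested "text_config" is treated as
    multimodal; one without it is treated as text-only.
    """
    if config.get("text_config") is not None:
        return "Gemma3ForCausalLM"
    return "Gemma3LanguageModel"


print(select_model_class_name({"text_config": {"hidden_size": 1152}}))
print(select_model_class_name({"hidden_size": 1152}))
```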
Code Reference
Source Location
- Repository: Mlc_ai_Mlc_llm
- File: python/mlc_llm/model/gemma3/gemma3_model.py
Signature
@dataclasses.dataclass
class Gemma3TextConfig(ConfigBase):
    hidden_size: int
    intermediate_size: int
    num_hidden_layers: int
    attention_bias: bool = False
    num_attention_heads: int = 8
    num_key_value_heads: int = 4
    head_dim: int = 256
    rms_norm_eps: float = 1e-6
    hidden_activation: Optional[str] = "gelu_pytorch_tanh"
    position_embedding_base: int = 1_000_000
    sliding_window_size: int = None
    sliding_window_pattern = 6
    ...

@dataclasses.dataclass
class Gemma3Config(ConfigBase):
    text_config: Gemma3TextConfig = None
    vocab_size: int = 262_208
    tensor_parallel_shards: int = 1
    max_batch_size: int = 1
    ...

class Gemma3ForCausalLM(nn.Module):
    def __init__(self, config: Gemma3Config): ...
    def embed(self, input_ids: Tensor): ...
    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): ...
    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): ...
    def batch_prefill(self, input_embeds, logit_positions, paged_kv_cache): ...
    def batch_decode(self, input_embeds, paged_kv_cache): ...
    def batch_verify(self, input_embeds, paged_kv_cache): ...
    def create_paged_kv_cache(self, ...): ...
    def get_default_spec(self): ...
Import
from mlc_llm.model.gemma3.gemma3_model import Gemma3Config, Gemma3ForCausalLM, Gemma3LanguageModel
I/O Contract
Primary Classes
| Class | Role | Key Parameters |
|---|---|---|
| Gemma3TextConfig | Nested config for text model | hidden_size, intermediate_size, num_hidden_layers, head_dim, sliding_window_pattern |
| Gemma3Config | Top-level model config | text_config, vocab_size, tensor_parallel_shards |
| Gemma3MLP | Feed-forward network | Gated GeLU activation with gate_up_proj and down_proj |
| Gemma3Attention | Multi-head attention | Q/K RMSNorm, query_pre_attn_scalar scaling |
| Gemma3DecoderLayer | Transformer block | 4 RMSNorm layers, attention + MLP with post-matmul norms |
| Gemma3TextModel | Stacked decoder layers | embed_tokens, layers, norm |
| Gemma3LanguageModel | Text-only top-level model | model, PagedKVCache with sliding window |
| Gemma3ForCausalLM | Multimodal top-level model | language_model, PagedKVCache with sliding window |
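The four-norm layout of Gemma3DecoderLayer is easiest to see as data flow. The sketch below uses plain Python lists and a weightless RMSNorm, so it shows only the ordering of operations. It assumes each sub-block is normalized on the way in and on the way out, with the residual added after the post-norm; the attention and MLP blocks are passed in as arbitrary callables.

```python
import math
from typing import Callable, List

Vector = List[float]


def rms_norm(x: Vector, eps: float = 1e-6) -> Vector:
    """RMSNorm without a learned weight, to keep the sketch short."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


def add(a: Vector, b: Vector) -> Vector:
    """Element-wise residual addition."""
    return [x + y for x, y in zip(a, b)]


def decoder_layer(
    x: Vector,
    attention: Callable[[Vector], Vector],
    mlp: Callable[[Vector], Vector],
) -> Vector:
    """One decoder block with four norms (ordering is an assumption)."""
    h = rms_norm(x)      # input norm
    h = attention(h)
    h = rms_norm(h)      # post-attention norm
    x = add(x, h)        # first residual connection
    h = rms_norm(x)      # pre-feedforward norm
    h = mlp(h)
    h = rms_norm(h)      # post-feedforward norm
    return add(x, h)     # second residual connection


# Identity sub-blocks stand in for the real attention and MLP.
out = decoder_layer([3.0, 4.0], attention=lambda h: h, mlp=lambda h: h)
print(out)
```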
Forward Methods
| Method | Input | Output |
|---|---|---|
| embed | Tensor[seq_len] (int32) | Tensor[1, seq_len, hidden_size] |
| prefill | Tensor[1, seq_len, hidden_size], PagedKVCache | (Tensor[1, 1, vocab_size], PagedKVCache) |
| decode | Tensor[1, 1, hidden_size], PagedKVCache | (Tensor[1, 1, vocab_size], PagedKVCache) |
| batch_prefill | Tensor[1, seq_len, hidden_size], Tensor[batch_size], PagedKVCache | (Tensor[batch_size, vocab_size], PagedKVCache) |
| batch_decode | Tensor[batch_size, 1, hidden_size], PagedKVCache | (Tensor[batch_size, vocab_size], PagedKVCache) |
| batch_verify | Tensor[1, seq_len, hidden_size], PagedKVCache | (Tensor[1, seq_len, vocab_size], PagedKVCache) |
Usage Examples
from mlc_llm.model.gemma3.gemma3_model import Gemma3Config, Gemma3ForCausalLM

# Creating a Gemma3 config from a HuggingFace-style dictionary
config_dict = {
    # Nested settings for the text model
    "text_config": {
        "hidden_size": 1152,
        "intermediate_size": 6912,
        "num_hidden_layers": 26,
        "num_attention_heads": 4,
        "num_key_value_heads": 1,
        "head_dim": 256,
        "rms_norm_eps": 1e-6,
        "hidden_activation": "gelu_pytorch_tanh",
        "sliding_window": 512,
        "sliding_window_pattern": 6,
    },
    # Top-level settings
    "vocab_size": 262144,
    "tensor_parallel_shards": 1,
}
config = Gemma3Config.from_dict(config_dict)

# Build the multimodal top-level model from the parsed config
model = Gemma3ForCausalLM(config)