

Implementation:Mlc ai Mlc llm Gemma3 Model

From Leeroopedia
Revision as of 15:50, 16 February 2026 by Admin (Auto-imported from implementations/Mlc_ai_Mlc_llm_Gemma3_Model.md)


Knowledge Sources
Domains: Model_Architecture, LLM
Last Updated: 2026-02-09 19:00 GMT

Overview

Implements the Gemma3 architecture for causal language modeling within the MLC LLM framework, supporting both standalone text models and multimodal conditional generation with sliding window attention patterns.

Description

This module provides the TVM Relax-based implementation of Google's Gemma3 model architecture. It defines two top-level model classes: Gemma3LanguageModel (for text-only usage when the config has no separate text_config) and Gemma3ForCausalLM (for multimodal conditional generation). Both share the same internal components: Gemma3TextModel contains stacked Gemma3DecoderLayer modules, each consisting of a Gemma3Attention block with query/key RMSNorm and a Gemma3MLP block using GeLU activation with a gated up-projection.
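
The gated GeLU feed-forward block can be illustrated with a small pure-Python sketch. It assumes that gate_up_proj is a fused matrix whose first half of output rows is the gate projection and whose second half is the up projection, and it uses the tanh-approximated GeLU that the "gelu_pytorch_tanh" setting refers to. The helper names (gelu_tanh, matvec, gated_gelu_mlp) are hypothetical; this is not the TVM Relax code in the module.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # row-major: matrix[out_index][in_index]


def gelu_tanh(x: float) -> float:
    """GeLU using the tanh approximation (the "gelu_pytorch_tanh" variant)."""
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def matvec(matrix: Matrix, vector: Vector) -> Vector:
    """Multiply a row-major matrix by a vector."""
    return [sum(w * v for w, v in zip(row, vector)) for row in matrix]


def gated_gelu_mlp(x: Vector, gate_up_proj: Matrix, down_proj: Matrix) -> Vector:
    """Toy gated-GeLU feed-forward: down_proj(gelu(gate) * up).

    gate_up_proj has 2 * intermediate_size rows; this sketch assumes the
    first half is the gate projection and the second half the up projection.
    """
    gate_up = matvec(gate_up_proj, x)
    half = len(gate_up) // 2
    gate, up = gate_up[:half], gate_up[half:]
    hidden = [gelu_tanh(g) * u for g, u in zip(gate, up)]  # element-wise gating
    return matvec(down_proj, hidden)


# Toy sizes: hidden_size = 2, intermediate_size = 3.
gate_up_proj = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0],   # gate rows
                [1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]   # up rows
down_proj = [[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]]
out = gated_gelu_mlp([1.0, 0.0], gate_up_proj, down_proj)
print(out)
```

Fusing the gate and up projections into one matrix lets a single matrix multiply produce both halves, which are then split before gating.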

Key architectural features include:

  • Sliding window attention pattern: Each layer's attention kind is either "mha_sliding" (sliding-window attention) or "mha" (full attention), chosen by the configurable sliding_window_pattern parameter; with the default value of 6, every sixth layer uses full attention and the remaining layers use sliding-window attention.
  • Query/Key normalization: Separate RMSNorm layers (q_norm, k_norm) are applied to query and key projections before attention.
  • Pre-attention scalar: Attention scores are scaled by query_pre_attn_scalar**-0.5 instead of the standard head_dim**-0.5.
  • Post-matmul (sandwich) normalization: Four RMSNorm layers per decoder layer (input, post-attention, pre-feedforward, post-feedforward), so both the attention and MLP sublayers are normalized on their inputs and on their outputs.
  • GemmaEmbedding reuse: The token embedding class is imported from the Gemma model implementation, and its weight matrix, transposed, doubles as a tied lm_head.
  • Tensor parallelism: Attention projections and MLP layers can be sharded across multiple GPUs, with the shard count set by tensor_parallel_shards.
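
The per-layer attention kind can be sketched as follows. This assumes the Hugging Face Gemma3 convention, where layer i (0-based) uses full attention when i + 1 is a multiple of sliding_window_pattern; the exact indexing expression used in the MLC module is not shown on this page, and layer_attention_kinds is a hypothetical helper.

```python
from typing import List


def layer_attention_kinds(num_hidden_layers: int, sliding_window_pattern: int = 6) -> List[str]:
    """Assign "mha_sliding" or "mha" to each decoder layer.

    Assumption: layer i is full attention ("mha") when (i + 1) is a multiple
    of sliding_window_pattern, and sliding-window attention otherwise.
    """
    return [
        "mha" if (i + 1) % sliding_window_pattern == 0 else "mha_sliding"
        for i in range(num_hidden_layers)
    ]


kinds = layer_attention_kinds(26)  # 26 decoder layers
full_layers = [i for i, kind in enumerate(kinds) if kind == "mha"]
print(full_layers)  # [5, 11, 17, 23]
```

Under this assumption, a 26-layer model gets four full-attention layers and 22 sliding-window layers, so most layers only need to keep a bounded window of keys and values.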

Configuration is handled by Gemma3Config, which wraps a nested Gemma3TextConfig and automatically propagates the context window, prefill chunk, and sliding window settings.

Usage

Use this module when compiling Gemma3 family models (e.g., Gemma3 1B, 4B, 12B, 27B) for deployment with MLC LLM. The Gemma3ForCausalLM class is selected for multimodal Gemma3 configs, while Gemma3LanguageModel handles text-only configs that lack a separate text_config field.
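
The class-selection rule can be restated as a small hypothetical helper. It only encodes the rule in the paragraph above (a nested text_config means Gemma3ForCausalLM, otherwise Gemma3LanguageModel); MLC LLM's own dispatch code is not reproduced here, and select_gemma3_model_class is not part of its API.

```python
from typing import Any, Dict


def select_gemma3_model_class(config: Dict[str, Any]) -> str:
    """Hypothetical helper restating the selection rule for Gemma3 configs."""
    # A nested text_config marks a multimodal config.
    if config.get("text_config") is not None:
        return "Gemma3ForCausalLM"
    # Without a text_config, the config is treated as text-only.
    return "Gemma3LanguageModel"


print(select_gemma3_model_class({"text_config": {"hidden_size": 1152}}))  # Gemma3ForCausalLM
print(select_gemma3_model_class({"hidden_size": 1152}))                   # Gemma3LanguageModel
```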

Code Reference

Source Location

Signature

import dataclasses
from typing import Optional

from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import Tensor

from mlc_llm.nn import PagedKVCache
from mlc_llm.support.config import ConfigBase

@dataclasses.dataclass
class Gemma3TextConfig(ConfigBase):
    hidden_size: int
    intermediate_size: int
    num_hidden_layers: int
    attention_bias: bool = False
    num_attention_heads: int = 8
    num_key_value_heads: int = 4
    head_dim: int = 256
    rms_norm_eps: float = 1e-6
    hidden_activation: Optional[str] = "gelu_pytorch_tanh"
    position_embedding_base: int = 1_000_000
    sliding_window_size: Optional[int] = None
    sliding_window_pattern = 6
    ...  # remaining fields elided

@dataclasses.dataclass
class Gemma3Config(ConfigBase):
    text_config: Optional[Gemma3TextConfig] = None
    vocab_size: int = 262_208
    tensor_parallel_shards: int = 1
    max_batch_size: int = 1
    ...  # remaining fields elided

class Gemma3ForCausalLM(nn.Module):
    def __init__(self, config: Gemma3Config): ...
    def embed(self, input_ids: Tensor): ...
    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): ...
    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): ...
    def batch_prefill(self, input_embeds, logit_positions, paged_kv_cache): ...
    def batch_decode(self, input_embeds, paged_kv_cache): ...
    def batch_verify(self, input_embeds, paged_kv_cache): ...
    def create_paged_kv_cache(self, *args, **kwargs): ...  # concrete parameter list elided in this excerpt
    def get_default_spec(self): ...

Import

from mlc_llm.model.gemma3.gemma3_model import Gemma3Config, Gemma3ForCausalLM, Gemma3LanguageModel

I/O Contract

Primary Classes

| Class | Role | Key Parameters |
|---|---|---|
| Gemma3TextConfig | Nested config for text model | hidden_size, intermediate_size, num_hidden_layers, head_dim, sliding_window_pattern |
| Gemma3Config | Top-level model config | text_config, vocab_size, tensor_parallel_shards |
| Gemma3MLP | Feed-forward network | Gated GeLU activation with gate_up_proj and down_proj |
| Gemma3Attention | Multi-head attention | Q/K RMSNorm, query_pre_attn_scalar scaling |
| Gemma3DecoderLayer | Transformer block | 4 RMSNorm layers, attention + MLP with post-matmul norms |
| Gemma3TextModel | Stacked decoder layers | embed_tokens, layers, norm |
| Gemma3LanguageModel | Text-only top-level model | model, PagedKVCache with sliding window |
| Gemma3ForCausalLM | Multimodal top-level model | language_model, PagedKVCache with sliding window |
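
The Gemma3Attention row combines two ideas: RMSNorm applied separately to query and key vectors, and a softmax scale of query_pre_attn_scalar**-0.5 rather than head_dim**-0.5. The following is a minimal single-head sketch, assuming a plain RMSNorm (x / sqrt(mean(x²) + eps) * weight) and ignoring any checkpoint-specific weight parameterization; rms_norm and scaled_qk_score are hypothetical names.

```python
import math
from typing import List

Vector = List[float]


def rms_norm(x: Vector, weight: Vector, eps: float = 1e-6) -> Vector:
    """Plain RMSNorm: scale x to unit root-mean-square, then multiply by weight."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def scaled_qk_score(
    q: Vector,
    k: Vector,
    q_norm_weight: Vector,
    k_norm_weight: Vector,
    query_pre_attn_scalar: float,
    eps: float = 1e-6,
) -> float:
    """Pre-softmax attention score for one query/key pair in a single head."""
    q_normed = rms_norm(q, q_norm_weight, eps)  # plays the role of q_norm
    k_normed = rms_norm(k, k_norm_weight, eps)  # plays the role of k_norm
    dot = sum(a * b for a, b in zip(q_normed, k_normed))
    # Scale by query_pre_attn_scalar**-0.5 instead of head_dim**-0.5.
    return dot * query_pre_attn_scalar ** -0.5


head_dim = 4
ones = [1.0] * head_dim
score = scaled_qk_score([1.0] * head_dim, [2.0] * head_dim, ones, ones, query_pre_attn_scalar=256)
print(round(score, 4))  # 0.25
```

Because q and k are normalized first, their raw magnitudes no longer affect the score (the key [2, 2, 2, 2] behaves like [1, 1, 1, 1] here), and the final scale comes from query_pre_attn_scalar alone.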

Forward Methods

| Method | Input | Output |
|---|---|---|
| embed | Tensor[seq_len] (int32) | Tensor[1, seq_len, hidden_size] |
| prefill | Tensor[1, seq_len, hidden_size], PagedKVCache | (Tensor[1, 1, vocab_size], PagedKVCache) |
| decode | Tensor[1, 1, hidden_size], PagedKVCache | (Tensor[1, 1, vocab_size], PagedKVCache) |
| batch_prefill | Tensor[1, seq_len, hidden_size], Tensor[batch_size], PagedKVCache | (Tensor[batch_size, vocab_size], PagedKVCache) |
| batch_decode | Tensor[batch_size, 1, hidden_size], PagedKVCache | (Tensor[batch_size, vocab_size], PagedKVCache) |
| batch_verify | Tensor[1, seq_len, hidden_size], PagedKVCache | (Tensor[1, seq_len, vocab_size], PagedKVCache) |
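
The prefill row returns logits of shape [1, 1, vocab_size], i.e. logits for a single position, which is assumed here to be the last token of the prompt, as is usual for next-token prediction. Combined with the tied lm_head described earlier, this can be sketched as a dot product between the last hidden state and every row of the embedding matrix. The sketch assumes hidden_states has already passed through the final norm and drops the batch dimension; tied_lm_head_logits is a hypothetical helper, not a method of Gemma3ForCausalLM.

```python
from typing import List

Matrix = List[List[float]]


def tied_lm_head_logits(hidden_states: Matrix, embedding: Matrix) -> List[float]:
    """Next-token logits from the last position, using the embedding as lm_head.

    hidden_states: [seq_len][hidden_size] for one sequence (batch dim dropped),
                   assumed to be the output of the final norm.
    embedding:     [vocab_size][hidden_size]; logits = h_last @ embedding^T.
    """
    h_last = hidden_states[-1]
    return [sum(h * e for h, e in zip(h_last, row)) for row in embedding]


embedding = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # vocab_size=3, hidden_size=2
hidden_states = [[5.0, 5.0], [2.0, 3.0]]           # seq_len=2
logits = tied_lm_head_logits(hidden_states, embedding)
print(logits)  # [2.0, 3.0, 5.0]
```

Only one row of logits is produced regardless of seq_len, which matches the vocab_size-sized output of prefill once the singleton dimensions are dropped.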

Usage Examples

from mlc_llm.model.gemma3.gemma3_model import Gemma3Config, Gemma3ForCausalLM

# Creating a Gemma3 config from a HuggingFace-style dictionary
config_dict = {
    "text_config": {
        "hidden_size": 1152,
        "intermediate_size": 6912,
        "num_hidden_layers": 26,
        "num_attention_heads": 4,
        "num_key_value_heads": 1,
        "head_dim": 256,
        "rms_norm_eps": 1e-6,
        "hidden_activation": "gelu_pytorch_tanh",
        "sliding_window": 512,
        "sliding_window_pattern": 6,
    },
    "vocab_size": 262144,
    "tensor_parallel_shards": 1,
}

# Build the config object, then instantiate the model from it
config = Gemma3Config.from_dict(config_dict)
model = Gemma3ForCausalLM(config)
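
As a sanity check on the example config, the projection sizes implied by its text_config can be derived directly. Note that num_attention_heads * head_dim (4 * 256 = 1024) differs from hidden_size (1152), so the attention output projection maps 1024 features back to 1152. The sketch treats q/k/v as separate projections and gate_up_proj as a fused output of 2 * intermediate_size; the actual weight layout in MLC LLM may fuse projections differently, and projection_sizes is a hypothetical helper.

```python
from typing import Dict

# Subset of the text_config from the example above.
text_config = {
    "hidden_size": 1152,
    "intermediate_size": 6912,
    "num_attention_heads": 4,
    "num_key_value_heads": 1,
    "head_dim": 256,
}


def projection_sizes(cfg: Dict[str, int]) -> Dict[str, int]:
    """Feature counts implied by a Gemma3-style text config."""
    head_dim = cfg["head_dim"]
    q_heads = cfg["num_attention_heads"]
    kv_heads = cfg["num_key_value_heads"]
    if q_heads % kv_heads != 0:
        raise ValueError("num_attention_heads must be a multiple of num_key_value_heads")
    return {
        "q_proj_out": q_heads * head_dim,
        "k_proj_out": kv_heads * head_dim,
        "v_proj_out": kv_heads * head_dim,
        "o_proj_in": q_heads * head_dim,
        "o_proj_out": cfg["hidden_size"],
        "queries_per_kv_head": q_heads // kv_heads,  # grouped-query attention ratio
        "gate_up_proj_out": 2 * cfg["intermediate_size"],
        "down_proj_in": cfg["intermediate_size"],
    }


sizes = projection_sizes(text_config)
for name, size in sizes.items():
    print(f"{name}: {size}")
```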
