

Implementation:Mlc ai Mlc llm Gemma2 Model



Knowledge Sources
Domains: Model Architecture, Gemma2, Transformer
Last Updated: 2026-02-09 19:00 GMT

Overview

Model architecture implementation for Google's Gemma2 language model in MLC LLM, extending the base Gemma architecture with additional layer norms, modified attention scaling, and logit soft-capping.

Description

The gemma2_model module implements the Gemma2 architecture by subclassing the existing Gemma model classes. It reuses most of the Gemma infrastructure and overrides only the components listed below.

Configuration:

  • Gemma2Config: Extends GemmaConfig with additional fields:
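    • attn_logit_softcapping: Attention logit soft-capping value. It is read from the config but currently ignored during inference, following the Gemma 2 team's observation that removing it has only a minor impact on quality.
    • final_logit_softcapping: Cap applied to the final output logits via tanh-based soft-capping.
    • query_pre_attn_scalar: Sets the attention scaling factor to (head_dim / query_pre_attn_scalar)^0.5. This factor multiplies the standard 1/sqrt(head_dim) softmax scale, so the effective scale is 1/sqrt(query_pre_attn_scalar).
    • sliding_window: Replaces the context window size. Gemma 2 alternates sliding-window (local) and global attention from layer to layer; because this interleaving is not yet supported, the context window is capped at sliding_window tokens, a length at which local and global attention coincide.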
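The configuration override can be pictured with a small stand-in class. This is a hypothetical sketch: `Gemma2ConfigSketch` is not an MLC LLM name, it mirrors only the Gemma2-specific fields listed above, and the `context_window_size` field it writes to is an assumption about where the override lands. The real Gemma2Config also inherits the full GemmaConfig field set.

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass
class Gemma2ConfigSketch:
    """Hypothetical stand-in for Gemma2Config (not the MLC LLM class)."""

    context_window_size: int = 0  # assumed name of the field being overridden
    attn_logit_softcapping: Optional[float] = None  # stored, ignored at inference
    final_logit_softcapping: Optional[float] = None
    query_pre_attn_scalar: Optional[int] = None
    sliding_window: Optional[int] = None

    def __post_init__(self) -> None:
        # Interleaved local/global attention is not supported, so the model is
        # served with the sliding-window length as its context window.
        if self.sliding_window is not None:
            self.context_window_size = self.sliding_window


cfg = Gemma2ConfigSketch(
    context_window_size=8192,
    sliding_window=4096,
    final_logit_softcapping=30.0,
)
print(cfg.context_window_size)  # 4096
```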

Model Components:

  • Gemma2Attention: Extends GemmaAttention to override the attention scaling factor using the query_pre_attn_scalar configuration value.
  • Gemma2DecoderLayer: A custom decoder layer that differs from the Gemma1 decoder layer by having four RMS layer norms instead of two:
    • input_layernorm: Pre-attention normalization
    • post_attention_layernorm: Post-attention normalization, applied to the attention output before it is added to the residual stream
    • pre_feedforward_layernorm: Pre-MLP normalization
    • post_feedforward_layernorm: Post-MLP normalization, applied to the MLP output before it is added to the residual stream
    • The layer also configures tensor-parallel sharding strategies for the QKV projection, output projection, gate/up projection, and down projection.
    • The _apply_post_matmul_norm helper performs the tensor-parallel allreduce before applying the post-attention and post-feedforward norms: RMSNorm is nonlinear, so it must see the fully reduced output rather than per-shard partial sums.
  • Gemma2Model: Extends GemmaModel, replacing the layer list with Gemma2DecoderLayer instances.
  • Gemma2ForCausalLM: Extends GemmaForCausalLM to:
    • Use Gemma2Model as the backbone
    • Apply final logit soft-capping when final_logit_softcapping is set: logits = tanh(logits / cap) * cap, where cap = final_logit_softcapping
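The four-norm residual layout of Gemma2DecoderLayer can be sketched in NumPy. This is an illustrative model of the data flow only, assuming a plain RMSNorm with a multiplicative weight (Gemma-specific weight conventions are not modeled) and treating attention and the MLP as stand-in callables; it is not the MLC LLM forward pass, which runs attention through the PagedKVCache on TVM tensors. `rms_norm` and `gemma2_decoder_layer` are hypothetical names.

```python
import numpy as np


def rms_norm(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    # Plain RMSNorm over the hidden axis.
    variance = np.mean(x * x, axis=-1, keepdims=True)
    return x / np.sqrt(variance + eps) * weight


def gemma2_decoder_layer(x, attn, mlp, norms):
    """Hypothetical data-flow sketch of a decoder layer with four norms."""
    # Attention sub-block: pre-norm -> attention -> post-norm -> residual add
    residual = x
    h = rms_norm(x, norms["input_layernorm"])
    h = attn(h)
    h = rms_norm(h, norms["post_attention_layernorm"])
    x = residual + h

    # MLP sub-block: pre-norm -> MLP -> post-norm -> residual add
    residual = x
    h = rms_norm(x, norms["pre_feedforward_layernorm"])
    h = mlp(h)
    h = rms_norm(h, norms["post_feedforward_layernorm"])
    return residual + h


hidden = 8
rng = np.random.default_rng(0)
names = ("input_layernorm", "post_attention_layernorm",
         "pre_feedforward_layernorm", "post_feedforward_layernorm")
norms = {name: np.ones(hidden) for name in names}
w_attn = rng.standard_normal((hidden, hidden))  # stand-in for attention
w_mlp = rng.standard_normal((hidden, hidden))   # stand-in for the MLP
x = rng.standard_normal((2, 3, hidden))         # (batch, seq_len, hidden)
y = gemma2_decoder_layer(x, lambda h: h @ w_attn, lambda h: h @ w_mlp, norms)
print(y.shape)  # (2, 3, 8)
```

A quick sanity check follows from the layout: if attention and the MLP both output zeros, the post-norms also output zeros and the layer reduces to the identity.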

Usage

Use this module when loading and serving Gemma2 models in MLC LLM. The architecture is automatically selected based on the model configuration. The model integrates with the standard MLC LLM pipeline including PagedKVCache for efficient attention, tensor parallelism for multi-GPU inference, and all supported quantization methods.

Code Reference

Source Location

Signature

@dataclasses.dataclass
class Gemma2Config(GemmaConfig):
    attn_logit_softcapping: float = None
    final_logit_softcapping: float = None
    query_pre_attn_scalar: int = None
    sliding_window: int = None

class Gemma2Attention(GemmaAttention):
    def __init__(self, config: Gemma2Config)

class Gemma2DecoderLayer(nn.Module):
    def __init__(self, config: Gemma2Config)
    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int) -> Tensor

class Gemma2Model(GemmaModel):
    def __init__(self, config: Gemma2Config)

class Gemma2ForCausalLM(GemmaForCausalLM):
    def __init__(self, config: Gemma2Config)
    def get_logits(self, hidden_states: Tensor) -> Tensor

Import

from mlc_llm.model.gemma2.gemma2_model import Gemma2Config, Gemma2ForCausalLM

I/O Contract

Gemma2DecoderLayer.forward

Parameter | Type | Description
hidden_states | Tensor | Input hidden states from the previous layer
paged_kv_cache | PagedKVCache | Paged key-value cache for attention
layer_id | int | Index of the current layer

Return | Type | Description
hidden_states | Tensor | Output hidden states

Gemma2ForCausalLM.get_logits

Parameter | Type | Description
hidden_states | Tensor | Final-layer hidden states

Return | Type | Description
logits | Tensor | Vocabulary logits, optionally soft-capped
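The soft-capping step in get_logits is easy to check numerically. A minimal NumPy sketch, where `soft_cap` is a hypothetical helper name and the cap of 30.0 is just an example value, applying logits = tanh(logits / cap) * cap and passing logits through unchanged when no cap is configured:

```python
from typing import Optional

import numpy as np


def soft_cap(logits: np.ndarray, cap: Optional[float]) -> np.ndarray:
    """Hypothetical helper: tanh-based soft-capping, tanh(logits / cap) * cap.

    With cap=None the logits are returned unchanged (no capping configured).
    """
    if cap is None:
        return logits
    return np.tanh(logits / cap) * cap


logits = np.array([-100.0, -1.0, 0.0, 1.0, 100.0])
capped = soft_cap(logits, 30.0)
print(capped)  # large magnitudes are squashed into (-30, 30); small ones barely change
```

Because tanh is strictly increasing, capping never changes which token has the highest logit, so greedy decoding is unaffected; it only compresses large logits, which reshapes the softmax distribution used for sampling.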

Architecture Differences from Gemma1

Feature | Gemma1 | Gemma2
Layer norms per decoder layer | 2 (input, post-attention) | 4 (input, post-attention, pre-feedforward, post-feedforward)
Attention scaling | 1 / sqrt(head_dim) | (head_dim / query_pre_attn_scalar)^0.5 * 1 / sqrt(head_dim) = 1 / sqrt(query_pre_attn_scalar)
Final logit capping | None | tanh(logits / cap) * cap
Attention logit capping | None | Configured but currently ignored at inference
Context window | From config | Overridden with the sliding_window value
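To see what the attention-scaling row means numerically, the sketch below computes both scales. It assumes that the Gemma2 factor (head_dim / query_pre_attn_scalar)^0.5 multiplies the default 1/sqrt(head_dim) softmax scale rather than replacing it; under that reading the effective Gemma2 scale is 1/sqrt(query_pre_attn_scalar). The function names are hypothetical, and the (head_dim, query_pre_attn_scalar) pairs (256, 256) and (128, 144) are those of the public Gemma 2 9B and 27B configurations.

```python
import math


def gemma1_sm_scale(head_dim: int) -> float:
    # Standard scaled dot-product attention scale.
    return 1.0 / math.sqrt(head_dim)


def gemma2_sm_scale(head_dim: int, query_pre_attn_scalar: int) -> float:
    # Assumed: the Gemma2 factor multiplies the standard scale.
    scaling_factor = (head_dim / query_pre_attn_scalar) ** 0.5
    return scaling_factor * gemma1_sm_scale(head_dim)


for head_dim, qpas in [(256, 256), (128, 144)]:
    g1 = gemma1_sm_scale(head_dim)
    g2 = gemma2_sm_scale(head_dim, qpas)
    print(f"head_dim={head_dim}, query_pre_attn_scalar={qpas}: "
          f"Gemma1 {g1:.5f}, Gemma2 {g2:.5f}, 1/sqrt(qpas) {1 / math.sqrt(qpas):.5f}")
```

When query_pre_attn_scalar equals head_dim the two scales coincide, so the override only changes behavior for configurations where they differ.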

Tensor Parallel Sharding

The decoder layer configures the following sharding strategies:

Layer | Strategy | Details
self_attn.qkv_proj | ShardSingleDim | dim=0, segments=[q_heads*hd, kv_heads*hd, kv_heads*hd], where hd = head_dim
self_attn.o_proj | ShardSingleDim | dim=1
mlp.gate_up_proj | ShardSingleDim | dim=0, segments=[intermediate, intermediate]
mlp.down_proj | ShardSingleDim | dim=1
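The segment-aware dim=0 split and the dim=1 split can be illustrated with NumPy. In this hypothetical sketch (`shard_dim0_segments` and `rms_norm` are illustrative names, not the TVM or MLC LLM sharding API), each segment of a fused QKV weight is divided evenly across shards so every shard keeps a matching slice of Q, K and V. A dim=1 split produces per-shard partial sums that must be added together (the allreduce) before a norm is applied. Weights are assumed to be laid out as (out_features, in_features).

```python
import numpy as np


def shard_dim0_segments(weight, segments, num_shards, shard_id):
    """Hypothetical helper: split a fused weight along dim 0, segment by segment."""
    pieces, start = [], 0
    for seg in segments:
        if seg % num_shards:
            raise ValueError("each segment must divide evenly across shards")
        step = seg // num_shards
        lo = start + shard_id * step
        pieces.append(weight[lo:lo + step])  # this shard's slice of the segment
        start += seg
    return np.concatenate(pieces, axis=0)


def rms_norm(x, eps=1e-6):
    # Weightless RMSNorm, enough to show why normalization must follow the allreduce.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)


# dim=0 split of a fused qkv_proj: 4 query heads, 2 KV heads, head_dim 3, 2 shards.
q_heads, kv_heads, hd, hidden, shards = 4, 2, 3, 6, 2
segments = [q_heads * hd, kv_heads * hd, kv_heads * hd]  # [12, 6, 6]
rows = sum(segments)
qkv_weight = np.arange(rows * hidden, dtype=np.float64).reshape(rows, hidden)
shard0 = shard_dim0_segments(qkv_weight, segments, shards, shard_id=0)
print(shard0.shape)  # (12, 6): 6 Q rows, 3 K rows, 3 V rows

# dim=1 split of a projection such as down_proj: each shard yields a partial sum.
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8))  # activations, split by feature across shards
w = rng.standard_normal((6, 8))  # (out_features, in_features)
half = 8 // shards
partials = [x[:, i * half:(i + 1) * half] @ w[:, i * half:(i + 1) * half].T
            for i in range(shards)]
full = x @ w.T
reduced = sum(partials)  # what an allreduce(sum) produces
print(np.allclose(reduced, full))  # True
print(np.allclose(rms_norm(reduced), rms_norm(full)))  # True: norm after the allreduce
```

Normalizing each partial sum before adding them would not reproduce rms_norm(full), which is why the post-matmul norm has to wait for the allreduce.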

Usage Examples

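from mlc_llm.model.gemma2.gemma2_model import Gemma2Config, Gemma2ForCausalLM

# Create a Gemma2 config (dimensions of the public Gemma 2 9B checkpoint)
config = Gemma2Config(
    vocab_size=256000,
    hidden_size=3584,
    intermediate_size=14336,
    num_hidden_layers=42,
    num_attention_heads=16,
    num_key_value_heads=8,
    head_dim=256,
    # Base GemmaConfig fields, assumed to be required by its constructor/validation
    attention_bias=False,
    rms_norm_eps=1e-6,
    hidden_activation="gelu_pytorch_tanh",
    context_window_size=8192,  # overridden by sliding_window
    # Gemma2-specific fields
    query_pre_attn_scalar=256,
    attn_logit_softcapping=50.0,  # accepted but currently ignored at inference
    final_logit_softcapping=30.0,
    sliding_window=4096,
)

# Instantiate the model
model = Gemma2ForCausalLM(config)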
