Implementation:Mlc ai Mlc llm Gemma2 Model
| Knowledge Sources | |
|---|---|
| Domains | Model Architecture, Gemma2, Transformer |
| Last Updated | 2026-02-09 19:00 GMT |
Overview
Model architecture implementation for Google's Gemma2 language model in MLC LLM, extending the base Gemma architecture with additional layer norms, modified attention scaling, and logit soft-capping.
Description
The gemma2_model module implements the Gemma2 architecture by extending the existing Gemma model classes with Gemma2-specific modifications. The implementation reuses most of the Gemma infrastructure and overrides specific components.
Configuration:
- Gemma2Config: Extends GemmaConfig with additional fields:
  - attn_logit_softcapping: Attention logit soft-capping value (currently ignored during inference, per the Gemma 2 team's recommendation that removing it has only a minor impact)
  - final_logit_softcapping: Cap applied to the final output logits via tanh-based capping
  - query_pre_attn_scalar: Used to compute the attention scaling factor as (head_dim / query_pre_attn_scalar)^0.5 instead of the standard 1/sqrt(head_dim)
  - sliding_window: Overrides the context window size, because sliding window attention on every other layer is not yet supported
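The scaling rule driven by query_pre_attn_scalar can be sketched in plain Python. The function names here are hypothetical and not part of MLC LLM. The second function additionally assumes that the factor is multiplied onto the standard 1/sqrt(head_dim) softmax scale, which the description above does not state explicitly.

```python
import math


def gemma2_scaling_factor(head_dim: int, query_pre_attn_scalar: int) -> float:
    """Factor described above: (head_dim / query_pre_attn_scalar) ** 0.5."""
    return (head_dim / query_pre_attn_scalar) ** 0.5


def effective_softmax_scale(head_dim: int, query_pre_attn_scalar: int) -> float:
    """ASSUMPTION: the factor multiplies the standard 1/sqrt(head_dim) scale.

    Under that assumption the product simplifies to 1/sqrt(query_pre_attn_scalar).
    """
    standard_scale = 1.0 / math.sqrt(head_dim)
    return gemma2_scaling_factor(head_dim, query_pre_attn_scalar) * standard_scale


if __name__ == "__main__":
    # When query_pre_attn_scalar equals head_dim, the factor is exactly 1.0.
    print(gemma2_scaling_factor(256, 256))
    # When they differ, the factor rescales the attention scores.
    print(gemma2_scaling_factor(128, 256))
    print(effective_softmax_scale(128, 256))
```

When query_pre_attn_scalar equals head_dim the factor is 1.0, so the override only changes behavior for configurations where the two values differ.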
Model Components:
- Gemma2Attention: Extends GemmaAttention to override the attention scaling factor using the query_pre_attn_scalar configuration value.
- Gemma2DecoderLayer: A custom decoder layer that differs from the Gemma1 decoder layer by having four RMS layer norms instead of two:
  - input_layernorm: Pre-attention normalization
  - post_attention_layernorm: Post-attention normalization (applied to the attention output, before the residual addition)
  - pre_feedforward_layernorm: Pre-MLP normalization
  - post_feedforward_layernorm: Post-MLP normalization (applied to the MLP output, before the residual addition)
  - The layer also configures tensor parallel sharding strategies for the QKV projection, output projection, gate/up projection, and down projection.
  - The _apply_post_matmul_norm helper handles the allreduce+norm ordering for tensor parallelism.
- Gemma2Model: Extends GemmaModel, replacing the layer list with Gemma2DecoderLayer instances.
- Gemma2ForCausalLM: Extends GemmaForCausalLM to:
  - Use Gemma2Model as the backbone
  - Apply final logit soft-capping: logits = tanh(logits / cap) * cap
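The ordering of the four norms in Gemma2DecoderLayer can be illustrated with a minimal pure-Python sketch. It is not the MLC LLM implementation: the helper names are hypothetical, the RMS norm omits learned weights, and the attention and MLP blocks are passed in as plain callables. It only shows where each norm sits relative to the residual additions, as described above.

```python
import math
from typing import Callable, List

Vector = List[float]


def rms_norm(x: Vector, eps: float = 1e-6) -> Vector:
    """Weight-free RMS norm (learned scale omitted for brevity)."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms for v in x]


def add(a: Vector, b: Vector) -> Vector:
    """Element-wise residual addition."""
    return [x + y for x, y in zip(a, b)]


def decoder_layer(
    hidden: Vector,
    attn: Callable[[Vector], Vector],
    mlp: Callable[[Vector], Vector],
) -> Vector:
    # Attention sub-block: norm -> attention -> norm -> residual add.
    out = attn(rms_norm(hidden))      # input_layernorm
    out = rms_norm(out)               # post_attention_layernorm
    hidden = add(hidden, out)
    # Feed-forward sub-block: norm -> MLP -> norm -> residual add.
    out = mlp(rms_norm(hidden))       # pre_feedforward_layernorm
    out = rms_norm(out)               # post_feedforward_layernorm
    return add(hidden, out)


if __name__ == "__main__":
    identity = lambda v: v
    print(decoder_layer([3.0, 4.0], identity, identity))
```

Because each post-norm is applied before the residual addition, the residual stream itself is never normalized in place; only the branch outputs are.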
Usage
Use this module when loading and serving Gemma2 models in MLC LLM. The architecture is automatically selected based on the model configuration. The model integrates with the standard MLC LLM pipeline including PagedKVCache for efficient attention, tensor parallelism for multi-GPU inference, and all supported quantization methods.
Code Reference
Source Location
- Repository: Mlc_ai_Mlc_llm
- File: python/mlc_llm/model/gemma2/gemma2_model.py
Signature
@dataclasses.dataclass
class Gemma2Config(GemmaConfig):
    attn_logit_softcapping: float = None
    final_logit_softcapping: float = None
    query_pre_attn_scalar: int = None
    sliding_window: int = None

class Gemma2Attention(GemmaAttention):
    def __init__(self, config: Gemma2Config)

class Gemma2DecoderLayer(nn.Module):
    def __init__(self, config: Gemma2Config)
    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int) -> Tensor

class Gemma2Model(GemmaModel):
    def __init__(self, config: Gemma2Config)

class Gemma2ForCausalLM(GemmaForCausalLM):
    def __init__(self, config: Gemma2Config)
    def get_logits(self, hidden_states: Tensor) -> Tensor
Import
from mlc_llm.model.gemma2.gemma2_model import Gemma2Config, Gemma2ForCausalLM
I/O Contract
Gemma2DecoderLayer.forward
| Parameter | Type | Description |
|---|---|---|
| hidden_states | Tensor | Input hidden states from previous layer |
| paged_kv_cache | PagedKVCache | Paged key-value cache for attention |
| layer_id | int | The index of the current layer |
| Return | Type | Description |
|---|---|---|
| hidden_states | Tensor | Output hidden states |
Gemma2ForCausalLM.get_logits
| Parameter | Type | Description |
|---|---|---|
| hidden_states | Tensor | Final layer hidden states |
| Return | Type | Description |
|---|---|---|
| logits | Tensor | Vocabulary logits, optionally soft-capped |
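The tanh-based soft-capping applied in get_logits can be sketched on a plain list of floats. The function name soft_cap is hypothetical, and treating a cap of None as "no capping" is an assumption based on the "optionally soft-capped" wording in the table above.

```python
import math
from typing import List, Optional


def soft_cap(logits: List[float], cap: Optional[float]) -> List[float]:
    """Apply logits = tanh(logits / cap) * cap element-wise.

    ASSUMPTION: cap=None means soft-capping is disabled and the
    logits are returned unchanged.
    """
    if cap is None:
        return list(logits)
    return [math.tanh(x / cap) * cap for x in logits]


if __name__ == "__main__":
    # Small logits pass through almost unchanged; large ones saturate near +/- cap.
    print(soft_cap([0.3, 15.0, 1000.0, -1000.0], cap=30.0))
```

The transform is smooth and monotonic, so it preserves the ranking of logits while bounding their magnitude by cap.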
Architecture Differences from Gemma1
| Feature | Gemma1 | Gemma2 |
|---|---|---|
| Layer norms per decoder layer | 2 (input, post-attention) | 4 (input, post-attention, pre-feedforward, post-feedforward) |
| Attention scaling | 1 / sqrt(head_dim) | (head_dim / query_pre_attn_scalar) ^ 0.5 |
| Final logit capping | None | tanh(logits / cap) * cap |
| Attention logit capping | None | Configured but currently disabled |
| Context window | From config | Overridden to sliding_window value |
Tensor Parallel Sharding
The decoder layer configures the following sharding strategies:
| Layer | Strategy | Details |
|---|---|---|
| self_attn.qkv_proj | ShardSingleDim | dim=0, segments=[q_heads*hd, kv_heads*hd, kv_heads*hd] |
| self_attn.o_proj | ShardSingleDim | dim=1 |
| mlp.gate_up_proj | ShardSingleDim | dim=0, segments=[intermediate, intermediate] |
| mlp.down_proj | ShardSingleDim | dim=1 |
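To illustrate what a segmented shard along dim 0 means for a fused projection such as qkv_proj, the sketch below computes which rows each shard would own. It is a pure-Python illustration and does not use the MLC LLM sharding API. The function name shard_rows is hypothetical, and the assumption is that every segment is split evenly across shards so that each shard receives its own slice of Q, K, and V.

```python
from typing import List


def shard_rows(segments: List[int], num_shards: int) -> List[List[int]]:
    """Return, per shard, the row indices it owns along the sharded dim.

    ASSUMPTION: each segment (e.g. the Q, K, and V blocks of a fused
    projection) is divided evenly across shards.
    """
    shards: List[List[int]] = [[] for _ in range(num_shards)]
    offset = 0
    for seg in segments:
        if seg % num_shards != 0:
            raise ValueError("segment size must be divisible by num_shards")
        chunk = seg // num_shards
        for shard_id in range(num_shards):
            start = offset + shard_id * chunk
            shards[shard_id].extend(range(start, start + chunk))
        offset += seg
    return shards


if __name__ == "__main__":
    # Toy fused QKV weight: 4 query rows, 2 key rows, 2 value rows, 2 shards.
    for shard_id, rows in enumerate(shard_rows([4, 2, 2], num_shards=2)):
        print(shard_id, rows)
```

Splitting per segment, rather than cutting the fused weight into contiguous halves, keeps a matching portion of the query, key, and value rows on every shard.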
Usage Examples
from mlc_llm.model.gemma2.gemma2_model import Gemma2Config, Gemma2ForCausalLM

# Create Gemma2 config
config = Gemma2Config(
    vocab_size=256000,
    hidden_size=3072,
    intermediate_size=24576,
    num_hidden_layers=28,
    num_attention_heads=16,
    num_key_value_heads=16,
    head_dim=256,
    query_pre_attn_scalar=256,
    final_logit_softcapping=30.0,
    sliding_window=4096,
)

# Instantiate model
model = Gemma2ForCausalLM(config)