Implementation:Mlc ai Mlc llm Gemma2 Loader
| Knowledge Sources | |
|---|---|
| Domains | Model Loading, Parameter Mapping, Gemma2 |
| Last Updated | 2026-02-09 19:00 GMT |
Overview
Parameter mapping and weight loading logic for the Gemma2 model architecture, converting HuggingFace PyTorch parameter names and formats to MLC LLM's internal representation.
Description
The gemma2_loader module provides the huggingface function that builds a parameter mapping (ExternMapping) between HuggingFace Gemma2 model weights and MLC LLM's internal parameter format. This mapping is essential for loading pre-trained Gemma2 weights into the MLC LLM model architecture.
Key Transformations:
The loader handles several weight transformations specific to the Gemma2 architecture:
- QKV Projection Fusion: Separate q_proj, k_proj, and v_proj weight matrices from HuggingFace are concatenated along axis 0 into a single qkv_proj.weight tensor for MLC LLM's fused QKV projection.
- Gate/Up Projection Fusion: Separate gate_proj and up_proj weight matrices from HuggingFace are concatenated along axis 0 into a single gate_up_proj.weight tensor.
- RMS LayerNorm Weight Offset: Gemma models use a variant of RMS normalization where the weights have an implicit +1 offset. The loader adds 1 to all RMS norm weights during loading so this addition does not need to happen at runtime. This applies to all four per-layer norms (input_layernorm, post_attention_layernorm, pre_feedforward_layernorm, post_feedforward_layernorm) as well as the final model.norm.
- Dtype Casting: All parameters are cast to their target MLC parameter dtype during the mapping transformation.
- Identity Mapping: Any remaining parameters not explicitly handled (e.g., embedding weights, output projection) receive a simple identity mapping with dtype casting.
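The transformations listed above can be sketched with plain NumPy. This is a minimal illustration, not the loader's actual code: the helper names `fuse_rows` and `offset_norm`, the toy shapes, and the epsilon value are all assumptions made for the example.

```python
import numpy as np


def fuse_rows(*weights, dtype):
    """Concatenate weight matrices along axis 0, then cast to the target dtype."""
    return np.concatenate(weights, axis=0).astype(dtype)


def offset_norm(weight, dtype):
    """Fold the implicit +1 of Gemma's RMS norm into the stored weight."""
    return (weight + 1).astype(dtype)


# Hypothetical toy shapes: hidden size 8, query rows 8, key/value rows 4 each.
rng = np.random.default_rng(0)
q = rng.standard_normal((8, 8)).astype(np.float32)
k = rng.standard_normal((4, 8)).astype(np.float32)
v = rng.standard_normal((4, 8)).astype(np.float32)
qkv = fuse_rows(q, k, v, dtype="float16")  # rows are stacked in q, k, v order

gate = rng.standard_normal((16, 8)).astype(np.float32)
up = rng.standard_normal((16, 8)).astype(np.float32)
gate_up = fuse_rows(gate, up, dtype="float16")

norm_w = rng.standard_normal(8).astype(np.float32)
folded = offset_norm(norm_w, dtype="float32")

# Scaling by (1 + w) at runtime matches scaling by the pre-offset weight.
x = rng.standard_normal(8).astype(np.float32)
normed = x / np.sqrt(np.mean(x**2) + 1e-6)  # epsilon chosen for illustration
assert np.allclose(normed * (1 + norm_w), normed * folded)
```

Because the fused tensors simply stack rows, the runtime can recover each projection by slicing the fused output at the corresponding row offsets.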
Implementation Approach:
The function instantiates the full Gemma2ForCausalLM model, exports it to TVM to obtain the named parameter list, and then iterates through the layers to build the mappings. It uses functools.partial with lambda functions to capture the target dtype for each transformation.
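The dtype-capturing pattern described above might look like the following sketch. The dictionaries `target_dtypes` and `toy_transforms` are hypothetical stand-ins; in the real loader the target dtype comes from the exported MLC parameters and the transforms are registered on the ExternMapping.

```python
import functools

import numpy as np

# Hypothetical target dtypes keyed by MLC parameter name.
target_dtypes = {
    "model.norm.weight": "float32",
    "model.layers.0.input_layernorm.weight": "float16",
}

# Hypothetical registry: MLC parameter name -> transform function.
toy_transforms = {}

for mlc_name, dtype in target_dtypes.items():
    # functools.partial binds the current dtype when the transform is created.
    # A bare lambda that read `dtype` from the enclosing scope would look it up
    # at call time and see only the last value assigned by the loop.
    toy_transforms[mlc_name] = functools.partial(
        lambda x, dtype: (x + 1).astype(dtype),
        dtype=dtype,
    )

hf_weight = np.zeros(4, dtype=np.float32)
results = {name: fn(hf_weight) for name, fn in toy_transforms.items()}
```

Each transform keeps its own dtype, so the two results end up with different dtypes even though they were built by the same lambda body.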
Usage
Use this module when loading Gemma2 model weights from HuggingFace format into MLC LLM. The huggingface function is called by the model compilation pipeline to produce the parameter mapping needed by the weight loader. It supports both quantized and non-quantized loading paths.
Code Reference
Source Location
- Repository: Mlc_ai_Mlc_llm
- File: python/mlc_llm/model/gemma2/gemma2_loader.py
Signature
def huggingface(
    model_config: Gemma2Config,
    quantization: Quantization,
) -> ExternMapping
Import
from mlc_llm.model.gemma2.gemma2_loader import huggingface
I/O Contract
huggingface
| Parameter | Type | Description |
|---|---|---|
| model_config | Gemma2Config | The Gemma2 model configuration |
| quantization | Quantization | The quantization configuration (can be None for unquantized) |
| Return | Type | Description |
|---|---|---|
| param_map | ExternMapping | Parameter name/transform mapping from HuggingFace to MLC LLM format |
Parameter Mapping Table
The following table summarizes the parameter name transformations per decoder layer (layer index i):
| MLC LLM Parameter | HuggingFace Source(s) | Transformation |
|---|---|---|
| model.layers.{i}.self_attn.qkv_proj.weight | q_proj.weight, k_proj.weight, v_proj.weight | np.concatenate([q, k, v], axis=0) |
| model.layers.{i}.mlp.gate_up_proj.weight | gate_proj.weight, up_proj.weight | np.concatenate([gate, up], axis=0) |
| model.layers.{i}.input_layernorm.weight | input_layernorm.weight | x + 1 |
| model.layers.{i}.post_attention_layernorm.weight | post_attention_layernorm.weight | x + 1 |
| model.layers.{i}.pre_feedforward_layernorm.weight | pre_feedforward_layernorm.weight | x + 1 |
| model.layers.{i}.post_feedforward_layernorm.weight | post_feedforward_layernorm.weight | x + 1 |
| model.norm.weight | model.norm.weight | x + 1 |
| (all others) | (same name) | Identity with dtype cast |
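As a rough illustration of the per-layer rows in the table, the sketch below builds the MLC-to-HuggingFace name mapping for one decoder layer. The function `layer_sources` is hypothetical, and it assumes the HuggingFace source names carry the same `model.layers.{i}` prefix that the table abbreviates.

```python
from typing import Dict, List

NORM_NAMES = (
    "input_layernorm",
    "post_attention_layernorm",
    "pre_feedforward_layernorm",
    "post_feedforward_layernorm",
)


def layer_sources(i: int) -> Dict[str, List[str]]:
    """Return MLC parameter name -> HuggingFace source names for layer i."""
    prefix = f"model.layers.{i}"
    attn = f"{prefix}.self_attn"
    mlp = f"{prefix}.mlp"
    sources = {
        # Fused projections draw on several HuggingFace tensors each.
        f"{attn}.qkv_proj.weight": [
            f"{attn}.{p}_proj.weight" for p in ("q", "k", "v")
        ],
        f"{mlp}.gate_up_proj.weight": [
            f"{mlp}.{p}_proj.weight" for p in ("gate", "up")
        ],
    }
    # Norm weights keep their name; only the values change (x + 1).
    for norm in NORM_NAMES:
        name = f"{prefix}.{norm}.weight"
        sources[name] = [name]
    return sources


layer3 = layer_sources(3)
```

The order of the source names matters for the fused entries, since it determines the row order of the concatenated tensor.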
Usage Examples
from mlc_llm.model.gemma2.gemma2_model import Gemma2Config
from mlc_llm.model.gemma2.gemma2_loader import huggingface
# Create model config (illustrative values)
config = Gemma2Config(
    vocab_size=256000,
    hidden_size=3072,
    intermediate_size=24576,
    num_hidden_layers=28,
    num_attention_heads=16,
    num_key_value_heads=16,
    head_dim=256,
    query_pre_attn_scalar=256,
    final_logit_softcapping=30.0,
    sliding_window=4096,
)
# Build parameter mapping for HuggingFace weights
param_mapping = huggingface(model_config=config, quantization=None)
# The mapping can then be used by the weight loader to convert
# HuggingFace checkpoint files to MLC LLM format.
# param_map holds the HuggingFace source names for each MLC parameter;
# map_func holds the matching transform under the same key.
for mlc_name, hf_names in param_mapping.param_map.items():
    transform_fn = param_mapping.map_func[mlc_name]
    # Load HF tensors by hf_names, apply transform_fn, store as mlc_name