Principle:NVIDIA TransformerEngine Gemma Weight Loading
Overview
Loading and mapping pretrained HuggingFace Gemma weights into TE model format with optional FP8 calibrated weights.
Description
Weight loading for the TE-Gemma model supports two distinct paths:
- Standard HF checkpoint loading: Loads weights from a HuggingFace Gemma checkpoint (safetensors format), maps the weight names from the HF convention to the TE convention, and copies them into the TE model's state dict. This path handles Gemma-specific features such as the configurable fuse_qkv_params setting and the interleaved QKV weight layout.
- FP8-calibrated weight loading: Loads a state dict that includes both the BF16 model weights and FP8 quantization metadata (amax/scale values) computed by a prior calibration procedure. This path uses load_state_dict(strict=False) because the wrapper architecture creates multiple pointers to the same underlying weights.
The weight mapping between HF and TE follows this structure:
| HF Weight Name | TE Weight Name |
|---|---|
| input_layernorm.weight | self_attention.layernorm_qkv.layer_norm_weight |
| self_attn.q_proj.weight | self_attention.layernorm_qkv.query_weight (unfused) or part of self_attention.layernorm_qkv.weight (fused) |
| self_attn.k_proj.weight | self_attention.layernorm_qkv.key_weight (unfused) or part of self_attention.layernorm_qkv.weight (fused) |
| self_attn.v_proj.weight | self_attention.layernorm_qkv.value_weight (unfused) or part of self_attention.layernorm_qkv.weight (fused) |
| self_attn.o_proj.weight | self_attention.proj.weight |
| post_attention_layernorm.weight | layernorm_mlp.layer_norm_weight |
| mlp.gate_proj.weight | layernorm_mlp.fc1_weight (first half) |
| mlp.up_proj.weight | layernorm_mlp.fc1_weight (second half) |
| mlp.down_proj.weight | layernorm_mlp.fc2_weight |
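The table can also be expressed as a per-layer lookup. This is a minimal sketch, not the adapter's actual code: hf_to_te_names is a hypothetical helper, and the part labels ("query", "first_half", and so on) are illustrative tags recording which slice of a shared TE tensor an HF weight fills.

```python
from typing import Dict, Optional, Tuple

def hf_to_te_names(fuse_qkv_params: bool) -> Dict[str, Tuple[str, Optional[str]]]:
    """Hypothetical helper: map per-layer HF weight names to (TE name, part).

    `part` is None for one-to-one copies; otherwise it names the slice of a
    shared TE tensor that the HF weight fills (illustrative labels only).
    """
    mapping: Dict[str, Tuple[str, Optional[str]]] = {
        "input_layernorm.weight": ("self_attention.layernorm_qkv.layer_norm_weight", None),
        "self_attn.o_proj.weight": ("self_attention.proj.weight", None),
        "post_attention_layernorm.weight": ("layernorm_mlp.layer_norm_weight", None),
        # GeGLU: gate_proj and up_proj share one fc1 tensor.
        "mlp.gate_proj.weight": ("layernorm_mlp.fc1_weight", "first_half"),
        "mlp.up_proj.weight": ("layernorm_mlp.fc1_weight", "second_half"),
        "mlp.down_proj.weight": ("layernorm_mlp.fc2_weight", None),
    }
    for proj, te_part in (("q", "query"), ("k", "key"), ("v", "value")):
        hf_name = f"self_attn.{proj}_proj.weight"
        if fuse_qkv_params:
            # Fused: all three projections land in one (interleaved) tensor.
            mapping[hf_name] = ("self_attention.layernorm_qkv.weight", te_part)
        else:
            # Unfused: separate query/key/value tensors, one-to-one copy.
            mapping[hf_name] = (f"self_attention.layernorm_qkv.{te_part}_weight", None)
    return mapping
```

Keeping the fused/unfused switch in one place makes it explicit that only the QKV rows of the table depend on fuse_qkv_params.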
Theoretical Basis
The weight mapping follows the same general principle as the Llama TE adapter but with Gemma-specific handling:
QKV Interleaving (when fuse_qkv_params=True):
When QKV weights are fused, TE stores them in a single tensor, interleaved head by head, to improve memory access patterns in the fused attention kernels. For a model with num_heads attention heads, each of dimension head_dim, the fused weight tensor is arranged as:
[q_head1, k_head1, v_head1, q_head2, k_head2, v_head2, ..., q_headN, k_headN, v_headN]
Each segment spans head_dim rows, so head i's [q, k, v] block starts at row 3 * i * head_dim, and the total fused tensor has shape [3 * num_heads * head_dim, hidden_size]. The interleaving copies each head's head_dim-row slice from the HF q_proj, k_proj, and v_proj weights into the corresponding offset of the fused tensor.
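A minimal PyTorch sketch of the interleaving copy follows. interleave_qkv is a hypothetical helper, and it assumes the number of key/value heads equals num_heads (no grouped- or multi-query attention); the real loader may organize the copy differently.

```python
import torch

def interleave_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                   num_heads: int, head_dim: int) -> torch.Tensor:
    """Pack HF q/k/v weights ([num_heads * head_dim, hidden_size] each) into one
    [3 * num_heads * head_dim, hidden_size] tensor laid out as
    [q_head1, k_head1, v_head1, q_head2, ...].

    Assumes num_key_value_heads == num_heads (hypothetical sketch).
    """
    hidden_size = q.shape[1]
    fused = torch.empty(3 * num_heads * head_dim, hidden_size,
                        dtype=q.dtype, device=q.device)
    for head in range(num_heads):
        src = slice(head * head_dim, (head + 1) * head_dim)  # this head's rows in HF
        base = head * 3 * head_dim                             # start of this head's block
        for idx, weight in enumerate((q, k, v)):               # idx 0=q, 1=k, 2=v
            dst = slice(base + idx * head_dim, base + (idx + 1) * head_dim)
            fused[dst] = weight[src]
    return fused
```

The per-head loop mirrors the description above; the same layout can be produced in one step with torch.stack([q.view(n, d, h), k.view(n, d, h), v.view(n, d, h)], dim=1).reshape(3 * n * d, h), where n, d, h stand for num_heads, head_dim, and hidden_size.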
Non-Fused QKV:
When fuse_qkv_params=False, the Q, K, and V weights remain as separate tensors in TE, and the mapping is a straightforward one-to-one copy.
GeGLU FFN Weights:
The Gemma FFN uses GeGLU, which requires two "up-projection" weights: gate_proj and up_proj. In TE, these are concatenated into a single fc1_weight tensor, with gate_proj occupying the first intermediate_size rows and up_proj occupying the remaining rows.
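A sketch of the fc1 packing, plus a GeGLU forward on the packed tensor. pack_fc1 and geglu_from_fc1 are hypothetical helpers; the tanh-approximated GELU and the assumption that the activation is applied to the first (gate_proj) half are stated assumptions for illustration, not a description of TE's kernel.

```python
import torch
import torch.nn.functional as F

def pack_fc1(gate_proj: torch.Tensor, up_proj: torch.Tensor) -> torch.Tensor:
    """Stack HF gate_proj and up_proj ([intermediate_size, hidden_size] each):
    rows [0, intermediate_size) hold gate_proj, the remaining rows hold up_proj."""
    assert gate_proj.shape == up_proj.shape
    return torch.cat([gate_proj, up_proj], dim=0)

def geglu_from_fc1(x: torch.Tensor, fc1_weight: torch.Tensor,
                   intermediate_size: int) -> torch.Tensor:
    """Apply the packed fc1 projection, then GeGLU: GELU(gate half) * up half.
    Assumes tanh-approximated GELU acting on the first half (sketch only)."""
    h = x @ fc1_weight.t()
    gate, up = h[..., :intermediate_size], h[..., intermediate_size:]
    return F.gelu(gate, approximate="tanh") * up
```

Because the two halves are just row blocks of one matrix, a single fc1 matmul replaces the two separate HF projections while producing the same gate and up activations.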
FP8 Metadata:
FP8-calibrated weights contain additional _extra_state entries per module with amax history and scaling factors. These are loaded alongside the BF16 weights and used by TE's FP8 autocast to skip the initial calibration period. The core_attention._extra_state entries are explicitly filtered out as they are not needed.
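The filtering step and the non-strict load can be sketched as follows. load_fp8_calibrated is a hypothetical helper, and matching on the "core_attention._extra_state" key suffix is an assumption about how those entries are named in the saved state dict.

```python
from typing import Any, Dict

import torch

def load_fp8_calibrated(model: torch.nn.Module, state_dict: Dict[str, Any]):
    """Load BF16 weights plus FP8 metadata (_extra_state entries) into `model`.

    Hypothetical sketch: core_attention._extra_state entries are dropped first,
    and strict=False tolerates the remaining key mismatches. Returns the
    (missing_keys, unexpected_keys) result of torch.nn.Module.load_state_dict.
    """
    filtered = {
        name: value
        for name, value in state_dict.items()
        if not name.endswith("core_attention._extra_state")
    }
    return model.load_state_dict(filtered, strict=False)
```

Inspecting the returned missing_keys and unexpected_keys after a non-strict load is a cheap way to confirm that only the expected entries were skipped.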
Usage
Use this principle when loading pretrained Gemma weights into a TEGemmaForCausalLM model. The typical workflow:
- Call load_te_model(TEGemmaForCausalLM, config).
- The function creates the model inside a quantized_model_init context (for FP8 weight support).
- Depending on whether config.fp8_model_weights_filename is set, it either:
  - loads the FP8-calibrated weights directly, or
  - loads the HF weights and maps them via replace_params().
- Non-layer parameters (e.g., token embeddings, lm_head) are loaded via load_state_dict(strict=False).
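The workflow above can be sketched as plain control flow. Every collaborator is passed in as a parameter because the names here (load_te_model_sketch, build_model, load_fp8_weights, load_hf_shards, init_context) are hypothetical stand-ins, not the TE or Transformers API; in the real loader, init_context would play the role of quantized_model_init, and the order of the two per-shard calls is an assumption.

```python
import contextlib
from typing import Any, Callable, Dict, Iterable

def load_te_model_sketch(
    build_model: Callable[[Any], Any],
    config: Any,
    load_fp8_weights: Callable[[Any, str], None],
    load_hf_shards: Callable[[Any], Iterable[Dict[str, Any]]],
    replace_params: Callable[[Dict[str, Any], Dict[str, Any], Any], None],
    init_context: Callable[[], Any] = contextlib.nullcontext,
):
    """Hypothetical sketch of the loading workflow with injected collaborators."""
    # 1. Build the model inside the init context (stand-in for quantized_model_init).
    with init_context():
        model = build_model(config)

    fp8_file = getattr(config, "fp8_model_weights_filename", None)
    if fp8_file is not None:
        # 2a. FP8 path: weights and FP8 metadata come from one calibrated state dict.
        load_fp8_weights(model, fp8_file)
    else:
        # 2b. HF path: per checkpoint shard, load non-layer parameters by name
        # (e.g. token embeddings), then map the decoder-layer weights into TE names.
        for shard in load_hf_shards(config):
            model.load_state_dict(shard, strict=False)
            replace_params(shard, model.state_dict(), config)
    return model
```

Injecting the collaborators keeps the branching logic testable without a GPU or a real checkpoint.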