

Principle:NVIDIA TransformerEngine Gemma Weight Loading

From Leeroopedia


Overview

Loading pretrained HuggingFace Gemma weights and mapping them into the TransformerEngine (TE) model format, optionally using FP8-calibrated weights.

Description

Weight loading for the TE-Gemma model supports two distinct paths:

  1. Standard HF checkpoint loading: Loads weights from a HuggingFace Gemma checkpoint (safetensors format), maps the weight names from HF convention to TE convention, and copies them into the TE model's state dict. This path handles Gemma-specific features like configurable fuse_qkv_params and interleaved QKV weight layout.
  2. FP8-calibrated weight loading: Loads a state dict that includes both the BF16 model weights and FP8 quantization metadata (amax/scale values) computed by a prior calibration procedure. This path uses load_state_dict(strict=False) because the wrapper architecture creates multiple pointers to the same underlying weights.

The weight mapping between HF and TE follows this structure:

| HF weight name | TE weight name |
|---|---|
| input_layernorm.weight | self_attention.layernorm_qkv.layer_norm_weight |
| self_attn.q_proj.weight | self_attention.layernorm_qkv.query_weight (unfused) or part of self_attention.layernorm_qkv.weight (fused) |
| self_attn.k_proj.weight | self_attention.layernorm_qkv.key_weight (unfused) or part of self_attention.layernorm_qkv.weight (fused) |
| self_attn.v_proj.weight | self_attention.layernorm_qkv.value_weight (unfused) or part of self_attention.layernorm_qkv.weight (fused) |
| self_attn.o_proj.weight | self_attention.proj.weight |
| post_attention_layernorm.weight | layernorm_mlp.layer_norm_weight |
| mlp.gate_proj.weight | layernorm_mlp.fc1_weight (first half) |
| mlp.up_proj.weight | layernorm_mlp.fc1_weight (second half) |
| mlp.down_proj.weight | layernorm_mlp.fc2_weight |
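The one-to-one rows of the mapping table can be expressed as a lookup keyed on the per-layer parameter name. The sketch below assumes HF keys look like "model.layers.<i>.<name>" and that the TE model reuses the same "model.layers.<i>." prefix; both the prefix convention and the helper name hf_to_te_name are illustrative assumptions, not TE's actual API. The gate_proj/up_proj rows and the fused-QKV case are deliberately excluded because they need partial (sliced) copies rather than a rename.

```python
import re

# One-to-one HF -> TE renames from the mapping table (unfused QKV case).
# gate_proj/up_proj (and fused QKV) need sliced copies, so they are not listed.
HF_TO_TE = {
    "input_layernorm.weight": "self_attention.layernorm_qkv.layer_norm_weight",
    "self_attn.q_proj.weight": "self_attention.layernorm_qkv.query_weight",
    "self_attn.k_proj.weight": "self_attention.layernorm_qkv.key_weight",
    "self_attn.v_proj.weight": "self_attention.layernorm_qkv.value_weight",
    "self_attn.o_proj.weight": "self_attention.proj.weight",
    "post_attention_layernorm.weight": "layernorm_mlp.layer_norm_weight",
    "mlp.down_proj.weight": "layernorm_mlp.fc2_weight",
}

# Assumed key layout: "model.layers.<layer index>.<parameter name>"
LAYER_KEY = re.compile(r"^(model\.layers\.\d+\.)(.+)$")


def hf_to_te_name(hf_key):
    """Hypothetical helper: TE key for a per-layer HF key, or None if there is
    no plain rename (non-layer params, or entries that need sliced copies)."""
    match = LAYER_KEY.match(hf_key)
    if match is None:
        return None  # e.g. token embeddings live outside model.layers.*
    prefix, name = match.groups()
    te_name = HF_TO_TE.get(name)
    return None if te_name is None else prefix + te_name
```

For example, "model.layers.3.self_attn.o_proj.weight" maps to "model.layers.3.self_attention.proj.weight", while "model.embed_tokens.weight" returns None and is left to the later load_state_dict(strict=False) step.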

Theoretical Basis

The weight mapping follows the same general principle as the Llama TE adapter but with Gemma-specific handling:

QKV Interleaving (when fuse_qkv_params=True):

When QKV weights are fused, TE stores them interleaved per head, a layout intended to give the fused attention kernels efficient memory access. For a model with num_heads attention heads, each of size head_dim, the fused weight tensor is arranged as:

[q_head1, k_head1, v_head1, q_head2, k_head2, v_head2, ..., q_headN, k_headN, v_headN]

Each segment spans head_dim rows, so the fused tensor has shape [3 * num_heads * head_dim, hidden_size] (this layout assumes K and V have as many heads as Q). The interleaving copies each head's rows from the HF q_proj, k_proj, and v_proj weights into the matching offsets of the fused tensor.
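The per-head interleaving can be sketched as follows. NumPy arrays stand in for PyTorch tensors here (row slicing and assignment behave the same way); the function name interleave_qkv is hypothetical. Head h's group of three segments starts at row 3 * h * head_dim, and within that group Q, K, and V occupy consecutive head_dim-row blocks.

```python
import numpy as np


def interleave_qkv(q, k, v, num_heads, head_dim):
    """Fuse HF q/k/v weights (each [num_heads * head_dim, hidden]) into the
    per-head interleaved layout [q_h1, k_h1, v_h1, q_h2, k_h2, v_h2, ...]."""
    hidden = q.shape[1]
    fused = np.empty((3 * num_heads * head_dim, hidden), dtype=q.dtype)
    for head in range(num_heads):
        src = slice(head * head_dim, (head + 1) * head_dim)  # this head's rows in HF
        base = head * 3 * head_dim  # first row of this head's [q, k, v] group
        for idx, weight in enumerate((q, k, v)):  # idx 0 = Q, 1 = K, 2 = V
            dst = slice(base + idx * head_dim, base + (idx + 1) * head_dim)
            fused[dst] = weight[src]
    return fused
```

The explicit loop mirrors the offset arithmetic described above; the same result can be obtained with a reshape to [num_heads, head_dim, hidden], a stack of Q, K, V along axis 1, and a reshape back to [3 * num_heads * head_dim, hidden].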

Non-Fused QKV:

When fuse_qkv_params=False, the Q, K, and V weights remain as separate tensors in TE, and the mapping is a straightforward one-to-one copy.

GeGLU FFN Weights:

The Gemma FFN uses GeGLU, which requires two "up-projection" weights: gate_proj and up_proj. In TE, these are concatenated into a single fc1_weight tensor, with gate_proj occupying the first intermediate_size rows and up_proj occupying the remaining rows.
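A minimal sketch of the fc1 packing, again with NumPy standing in for PyTorch tensors. It also checks why the order matters: assuming the fused GeGLU applies the activation to the first half of the fc1 output (which is where gate_proj is placed) and multiplies by the second half, one GEMM with the packed weight reproduces GELU(x·gate_projᵀ) ⊙ (x·up_projᵀ). The function names and the tanh-approximate GELU are illustrative assumptions.

```python
import numpy as np


def build_fc1(gate_proj, up_proj, intermediate_size):
    """Pack HF gate_proj and up_proj (each [intermediate_size, hidden]) into a
    single fc1 weight of shape [2 * intermediate_size, hidden]."""
    fc1 = np.empty((2 * intermediate_size, gate_proj.shape[1]), dtype=gate_proj.dtype)
    fc1[:intermediate_size] = gate_proj  # first intermediate_size rows: gate_proj
    fc1[intermediate_size:] = up_proj    # remaining rows: up_proj
    return fc1


def gelu_tanh(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def geglu_from_fc1(x, fc1, intermediate_size):
    """One GEMM with the packed weight, then GELU(first half) * second half."""
    h = x @ fc1.T
    return gelu_tanh(h[..., :intermediate_size]) * h[..., intermediate_size:]
```

Swapping the two halves would silently apply the activation to up_proj instead of gate_proj, so the copy offsets have to match the order the fused kernel expects.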

FP8 Metadata:

FP8-calibrated weights contain additional _extra_state entries per module with amax history and scaling factors. These are loaded alongside the BF16 weights and used by TE's FP8 autocast to skip the initial calibration period. The core_attention._extra_state entries are explicitly filtered out as they are not needed.
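The filtering step can be sketched as a dictionary comprehension over the state dict. In PyTorch the input would typically come from torch.load and the result would be passed to model.load_state_dict(..., strict=False); here a plain dict with illustrative key names stands in, and matching on the "core_attention._extra_state" suffix is an assumption about how the entries are identified.

```python
def drop_core_attention_extra_state(state_dict):
    """Hypothetical helper: copy of state_dict without core_attention._extra_state
    entries. Weights and the other modules' FP8 _extra_state are kept."""
    return {
        name: value
        for name, value in state_dict.items()
        if not name.endswith("core_attention._extra_state")
    }
```

Per-module entries such as self_attention.layernorm_qkv._extra_state survive the filter, so the calibrated amax history and scaling factors for the GEMM modules are still loaded.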

Usage

Use this principle when loading pretrained Gemma weights into a TEGemmaForCausalLM model. The typical workflow:

  1. Call load_te_model(TEGemmaForCausalLM, config)
  2. The function creates the model inside a quantized_model_init context (for FP8 weight support)
  3. Depending on whether config.fp8_model_weights_filename is set, it either:
    • Loads FP8-calibrated weights directly
    • Loads HF weights and maps them via replace_params()
  4. Non-layer parameters (e.g., token embeddings, lm_head) are loaded via load_state_dict(strict=False)
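The control flow of this workflow can be sketched with a dict-backed stand-in for the model. Everything here is hypothetical: DictModel mimics only the state_dict/load_state_dict(strict=...) behaviour of a PyTorch module, the loaders and replace_params are injected callables, and the quantized_model_init context from step 2 is omitted. One real difference is noted in a comment: PyTorch's state_dict() returns tensors that share storage with the parameters, so in-place copies write through, whereas this stub needs the mapped dict loaded back explicitly.

```python
class DictModel:
    """Hypothetical stand-in for an nn.Module whose parameters live in a dict."""

    def __init__(self, names):
        self.params = {name: None for name in names}

    def state_dict(self):
        # A copy; PyTorch instead returns tensors sharing the parameters' storage.
        return dict(self.params)

    def load_state_dict(self, state_dict, strict=True):
        missing = sorted(set(self.params) - set(state_dict))
        unexpected = sorted(set(state_dict) - set(self.params))
        if strict and (missing or unexpected):
            raise RuntimeError(f"missing={missing} unexpected={unexpected}")
        for name, value in state_dict.items():
            if name in self.params:  # strict=False: unknown keys are skipped
                self.params[name] = value
        return missing, unexpected


def load_te_model_sketch(model, config, load_fp8_file, load_hf_shards, replace_params):
    """Step 3/4 control flow only; model construction (steps 1-2) is omitted."""
    if config.fp8_model_weights_filename is not None:
        # FP8 path: BF16 weights plus calibrated FP8 metadata in one state dict.
        model.load_state_dict(load_fp8_file(config.fp8_model_weights_filename), strict=False)
        return model
    for hf_sd in load_hf_shards():  # HF path, one checkpoint shard at a time
        te_sd = model.state_dict()
        replace_params(hf_sd, te_sd)                # per-layer HF -> TE copies
        model.load_state_dict(te_sd, strict=False)  # write mapped layer weights back
        model.load_state_dict(hf_sd, strict=False)  # same-named non-layer params
    return model
```

The final strict=False call is what lets same-named non-layer tensors (for example, token embeddings) load from the HF dict while the HF per-layer keys, which have no counterpart in the TE model, are ignored.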
