Implementation:NVIDIA TransformerEngine Gemma Load TE Model
Overview
Weight loading and mapping functions for TE-accelerated Gemma models.
Description
The weight loading module provides two key functions:
load_te_model(cls, config): Creates a TEGemmaForCausalLM (or subclass) with weights loaded from either a HuggingFace checkpoint or an FP8-calibrated weights file. Uses the quantized_model_init context manager for FP8 weight support and torch.no_grad() during initialization.
replace_params(hf_state_dict, te_state_dict, config, qkv_fused_and_interleaved=False): Handles the HF-to-TE weight name mapping. Iterates over all layer prefixes matching model.layers.\d+. and copies each HF weight into the corresponding TE parameter slot. Supports both fused/interleaved and non-fused QKV weight layouts.
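The one-to-one part of the mapping performed by replace_params can be written as a suffix table applied under the shared "model.layers.N." prefix. A minimal sketch, assuming plain string keys; HF_TO_TE_SUFFIX and te_key_for are hypothetical names, not part of the example file. The gate/up and Q/K/V weights are left out because they land in slices of a fused tensor or in per-projection tensors rather than in a single whole parameter.

```python
import re

# Hypothetical table: HF per-layer suffix -> TE per-layer suffix
# (only the weights that map one-to-one onto a whole TE parameter).
HF_TO_TE_SUFFIX = {
    "input_layernorm.weight": "self_attention.layernorm_qkv.layer_norm_weight",
    "self_attn.o_proj.weight": "self_attention.proj.weight",
    "post_attention_layernorm.weight": "layernorm_mlp.layer_norm_weight",
    "mlp.down_proj.weight": "layernorm_mlp.fc2_weight",
}

# Splits "model.layers.<N>.<rest>" into the layer prefix and the suffix.
_LAYER_KEY = re.compile(r"^(model\.layers\.\d+\.)(.+)$")


def te_key_for(hf_key):
    """Return the TE key for a one-to-one HF layer key, or None if there is none."""
    match = _LAYER_KEY.match(hf_key)
    if match is None:
        return None  # not a per-layer key (e.g. embeddings, final norm, lm_head)
    prefix, suffix = match.groups()
    te_suffix = HF_TO_TE_SUFFIX.get(suffix)
    return None if te_suffix is None else prefix + te_suffix
```

For example, te_key_for("model.layers.3.mlp.down_proj.weight") yields "model.layers.3.layernorm_mlp.fc2_weight", while sliced or fused targets such as mlp.gate_proj.weight return None and need dedicated handling.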
Additional internal functions:
_load_weights_for_fp8_model(vanilla_model, hyperparams): Loads the FP8-calibrated state dict, filters out core_attention._extra_state entries, and loads it with strict=False.
_load_weights_for_standard_model(vanilla_model, config): Loads the HF sharded safetensors, calls replace_params(), then loads the remaining parameters.
_get_all_layer_prefixes_to_update(hf_state_dict): Extracts all unique "model.layers.N." prefixes from the state dict.
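Both the prefix extraction and the FP8 _extra_state filter reduce to key manipulation on the state dict. A minimal sketch over a plain dict, assuming the model.layers.N. naming used throughout this page; get_all_layer_prefixes and drop_core_attention_extra_state are hypothetical stand-ins, and the exact matching rules in the source file may differ.

```python
import re

# Matches the leading "model.layers.<N>." part of a per-layer key.
_PREFIX = re.compile(r"^model\.layers\.\d+\.")


def get_all_layer_prefixes(state_dict):
    """Collect the unique 'model.layers.N.' prefixes present in a state dict."""
    prefixes = set()
    for key in state_dict:
        match = _PREFIX.match(key)
        if match:
            prefixes.add(match.group(0))
    return prefixes


def drop_core_attention_extra_state(state_dict):
    """Return a new dict without any '...core_attention._extra_state' entries."""
    return {
        key: value
        for key, value in state_dict.items()
        if "core_attention._extra_state" not in key
    }
```

Returning a set matches the documented output of replace_params; callers that need layers in numeric order should sort on the integer N, since plain string sorting puts "model.layers.10." before "model.layers.2.".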
This is a Pattern Doc.
Source
docs/examples/te_gemma/te_gemma_loading_weights.py, load_te_model at lines 75-109, replace_params at lines 127-189.
Signature
import torch

# quantized_model_init and the _load_weights_* helpers are defined or imported
# elsewhere in the example file; they are omitted from this excerpt.


def load_te_model(cls, config):
    """
    Loads the TE model with proper weights.
    1. Sets default dtype to bfloat16
    2. Creates model under quantized_model_init and torch.no_grad contexts
    3. Loads weights via FP8 path or standard HF path
    4. Restores original dtype
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.bfloat16)
    config.use_cache = False  # TE handles KV caching through InferenceParams
    with torch.no_grad(), quantized_model_init(config.quantized_model_init):
        vanilla_model = cls(config).cuda()
        if config.fp8_model_weights_filename is not None:
            # FP8 path: load the calibrated state dict directly
            _load_weights_for_fp8_model(vanilla_model, config)
        else:
            # Standard path: map HF checkpoint weights into TE parameters
            _load_weights_for_standard_model(vanilla_model, config)
    torch.set_default_dtype(old_dtype)
    return vanilla_model
def replace_params(hf_state_dict, te_state_dict, config, qkv_fused_and_interleaved=False):
    """
    Replaces params from TE TransformerLayer state_dict with corresponding
    parameters from HuggingFace GemmaModel state_dict.
    Weight mapping for each layer prefix:
    - input_layernorm.weight -> self_attention.layernorm_qkv.layer_norm_weight
    - self_attn.o_proj.weight -> self_attention.proj.weight
    - post_attention_layernorm.weight -> layernorm_mlp.layer_norm_weight
    - mlp.down_proj.weight -> layernorm_mlp.fc2_weight
    - mlp.gate_proj.weight -> layernorm_mlp.fc1_weight[:intermediate_size]
    - mlp.up_proj.weight -> layernorm_mlp.fc1_weight[intermediate_size:]
    QKV mapping depends on qkv_fused_and_interleaved:
    - If True: interleaves q,k,v per-head into single weight tensor
    - If False: copies q,k,v into separate weight tensors
    """
    all_layer_prefixes = _get_all_layer_prefixes_to_update(hf_state_dict)
    for layer_prefix in all_layer_prefixes:
        # ... copy operations per layer ...
        if qkv_fused_and_interleaved:
            # Interleaved: [q1 k1 v1 q2 k2 v2 ...] layout
            ...
        else:
            # Separate: query_weight, key_weight, value_weight
            ...
    return all_layer_prefixes
I/O
load_te_model Input:
- cls: Model class (e.g., TEGemmaForCausalLM or TEGemmaForCausalLMCudaGraphs)
- config: GemmaConfig -- Extended config with:
  - config.weights_cache_dir: str -- Path to HF checkpoint directory
  - config.fp8_model_weights_filename: Optional[str] -- Path to FP8 calibrated weights, or None
  - config.quantized_model_init: Quantization configuration for the quantized_model_init context
  - config.fuse_qkv_params: bool -- Whether to use fused QKV weights
load_te_model Output:
- Loaded and initialized model instance on CUDA
replace_params Input:
- hf_state_dict: dict -- HuggingFace model state dict
- te_state_dict: dict -- TE model state dict (modified in place)
- config: GemmaConfig -- Model configuration
- qkv_fused_and_interleaved: bool -- Whether QKV weights should be fused and interleaved (default: False)
replace_params Output:
- set[str] -- Set of all updated layer prefixes (e.g., {"model.layers.0.", "model.layers.1.", ...})
Weight Mapping Details
FFN Weights
The GeGLU activation requires two "up-projection" matrices. In HuggingFace, these are separate (gate_proj and up_proj). In TE, they are concatenated along the output (row) dimension into a single fc1_weight:
# gate_proj -> first half of fc1_weight
copy_from_ht_to_te("layernorm_mlp.fc1_weight", "mlp.gate_proj.weight", end=config.intermediate_size)
# up_proj -> second half of fc1_weight
copy_from_ht_to_te("layernorm_mlp.fc1_weight", "mlp.up_proj.weight", start=config.intermediate_size)
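The start/end arguments select a row range of fc1_weight, so gate_proj fills rows [0, intermediate_size) and up_proj fills rows [intermediate_size, 2 * intermediate_size). A minimal sketch of that layout, assuming weights are stored as lists of rows with output features first (as in torch.nn.Linear); copy_rows and build_fc1 are hypothetical stand-ins for copy_from_ht_to_te, which in the source operates on tensors.

```python
def copy_rows(dst, src, start=None, end=None):
    """Overwrite dst[start:end] in place with the rows of src (lengths must match)."""
    target = range(len(dst))[start:end]  # row indices selected by the slice
    if len(target) != len(src):
        raise ValueError(f"slice holds {len(target)} rows, source has {len(src)}")
    for dst_idx, row in zip(target, src):
        dst[dst_idx] = list(row)


def build_fc1(gate_proj, up_proj, intermediate_size, hidden_size):
    """Concatenate gate and up projections into one fused fc1 weight."""
    fc1 = [[0.0] * hidden_size for _ in range(2 * intermediate_size)]
    copy_rows(fc1, gate_proj, end=intermediate_size)  # first half: gate_proj
    copy_rows(fc1, up_proj, start=intermediate_size)  # second half: up_proj
    return fc1
```

The explicit length check mirrors what a tensor copy_ would enforce through a shape mismatch: if intermediate_size does not match the HF checkpoint, the copy fails loudly instead of silently misaligning the two halves.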
QKV Interleaving
When qkv_fused_and_interleaved=True, the Q, K, V weights are interleaved per-head for optimal fused kernel performance:
# Layout: [q_head0, k_head0, v_head0, q_head1, k_head1, v_head1, ...]
for head_nr in range(config.num_attention_heads):
    dst_offset = head_nr * config.head_dim * 3
    # Q at dst_offset + 0*head_dim : dst_offset + 1*head_dim
    # K at dst_offset + 1*head_dim : dst_offset + 2*head_dim
    # V at dst_offset + 2*head_dim : dst_offset + 3*head_dim
    ...
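Filling in the offsets from the loop above gives the full per-head copy. A minimal sketch on row lists rather than tensors, assuming Q, K and V each have num_attention_heads * head_dim rows: the per-head loop as written only lines up when K and V have as many heads as Q. interleave_qkv is a hypothetical name, not the helper used in the source file.

```python
def interleave_qkv(q, k, v, num_heads, head_dim):
    """Stack Q, K, V row blocks as [q_h0, k_h0, v_h0, q_h1, k_h1, v_h1, ...]."""
    for name, w in (("q", q), ("k", k), ("v", v)):
        if len(w) != num_heads * head_dim:
            raise ValueError(f"{name} has {len(w)} rows, expected {num_heads * head_dim}")
    fused = [None] * (3 * num_heads * head_dim)
    for head_nr in range(num_heads):
        dst_offset = head_nr * head_dim * 3
        # Rows belonging to this head in each separate projection.
        src_rows = slice(head_nr * head_dim, (head_nr + 1) * head_dim)
        for idx, w in enumerate((q, k, v)):  # idx 0 = Q, 1 = K, 2 = V
            start = dst_offset + idx * head_dim
            fused[start:start + head_dim] = w[src_rows]
    return fused
```

With two heads of size two, Q rows ["q0a", "q0b", "q1a", "q1b"] and matching K and V rows come out as q0a q0b k0a k0b v0a v0b q1a q1b k1a k1b v1a v1b, which is the per-head layout described above.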
Non-Interleaved QKV
When qkv_fused_and_interleaved=False, Q, K, V are stored as separate weight tensors:
copy_from_ht_to_te("self_attention.layernorm_qkv.query_weight", "self_attn.q_proj.weight")
copy_from_ht_to_te("self_attention.layernorm_qkv.key_weight", "self_attn.k_proj.weight")
copy_from_ht_to_te("self_attention.layernorm_qkv.value_weight", "self_attn.v_proj.weight")
Notes
- load_te_model sets the default dtype to bfloat16 during model creation and restores the original dtype afterward.
- config.use_cache = False is set to make TransformerLayer compatible with GemmaModel, because TE handles caching through InferenceParams rather than through HF's cache mechanism.
- strict=False in load_state_dict() is necessary because the wrapper architecture (_model_context_phase.model and _model_generation_phase.model) creates multiple references to the same weights.
- Remaining non-layer parameters (token embeddings, lm_head, final layer norm) are loaded by the load_state_dict(strict=False) call after replace_params() has handled the layer-specific mapping.
- Memory is explicitly freed after loading standard weights with del total_dict and gc.collect().
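The division of labor in the notes above, where replace_params covers per-layer weights and load_state_dict(strict=False) picks up everything else, can be checked by partitioning the checkpoint keys. A minimal sketch over key names only, assuming the model.layers.N. prefix convention; split_layer_keys is a hypothetical helper and the key names in the test are illustrative rather than read from a real checkpoint.

```python
import re

# Per-layer keys start with "model.layers.<N>."; everything else is non-layer.
_LAYER = re.compile(r"^model\.layers\.\d+\.")


def split_layer_keys(keys):
    """Split keys into (per-layer keys, remaining non-layer keys), preserving order."""
    layer_keys, other_keys = [], []
    for key in keys:
        (layer_keys if _LAYER.match(key) else other_keys).append(key)
    return layer_keys, other_keys
```

Keys in the second list are the ones that must reach the model through the final load_state_dict call; if that list is unexpectedly empty or contains per-layer names, the prefix convention assumed here does not match the checkpoint.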