

Implementation:NVIDIA TransformerEngine Gemma Load TE Model

From Leeroopedia
Revision as of 15:58, 16 February 2026 by Admin (auto-imported from implementations/NVIDIA_TransformerEngine_Gemma_Load_TE_Model.md)


Overview

Functions for loading weights into TransformerEngine (TE)-accelerated Gemma models and for mapping HuggingFace (HF) weight names onto TE parameters.

Description

The weight loading module provides two key functions:

  • load_te_model(cls, config): Instantiates TEGemmaForCausalLM (or a subclass) and loads its weights from either a HuggingFace checkpoint or an FP8-calibrated weights file. The model is constructed under torch.no_grad() inside the quantized_model_init context manager, which provides FP8 weight support.
  • replace_params(hf_state_dict, te_state_dict, config, qkv_fused_and_interleaved=False): Maps HF weight names to TE parameter names. For every layer prefix matching model.layers.\d+., it copies each HF weight into the corresponding TE parameter. Both the fused-and-interleaved and the separate (non-fused) QKV layouts are supported.
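The prefix discovery that drives replace_params can be sketched with the standard re module. This is a hypothetical re-implementation, not the source's _get_all_layer_prefixes_to_update; the function name get_layer_prefixes is illustrative. The pattern quoted above leaves the dots unescaped; this sketch escapes them so "." only matches a literal dot.

```python
import re

# Hypothetical stand-in for _get_all_layer_prefixes_to_update.
# Dots are escaped, so the pattern only matches literal "model.layers.<N>." prefixes.
LAYER_PREFIX_RE = re.compile(r"model\.layers\.\d+\.")


def get_layer_prefixes(state_dict_keys):
    """Return the set of unique 'model.layers.N.' prefixes found in the keys."""
    prefixes = set()
    for key in state_dict_keys:
        m = LAYER_PREFIX_RE.match(key)  # match() anchors at the start of the key
        if m is not None:
            prefixes.add(m.group())
    return prefixes


keys = [
    "model.embed_tokens.weight",
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.mlp.up_proj.weight",
    "model.layers.11.input_layernorm.weight",
    "model.norm.weight",
]
print(sorted(get_layer_prefixes(keys)))  # ['model.layers.0.', 'model.layers.11.']
```

Non-layer keys such as the embeddings and the final norm produce no prefix, which is why they are left for the later load_state_dict call.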

Additional internal functions:

  • _load_weights_for_fp8_model(vanilla_model, hyperparams): Loads the FP8-calibrated state dict, drops the core_attention._extra_state entries, and loads the result with strict=False.
  • _load_weights_for_standard_model(vanilla_model, config): Loads the sharded HF safetensors checkpoint, calls replace_params() to map the per-layer weights, then loads the remaining parameters.
  • _get_all_layer_prefixes_to_update(hf_state_dict): Extracts all unique "model.layers.N." prefixes from the state dict.
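The filter-then-load step of the FP8 path can be sketched in plain PyTorch. This is a minimal sketch under stated assumptions: a toy nn.Linear stands in for the TE model, the checkpoint dict is made up, and drop_core_attention_extra_state is a hypothetical helper name. It shows that load_state_dict(strict=False) tolerates keys the checkpoint lacks and reports them instead of raising.

```python
import torch
from torch import nn


def drop_core_attention_extra_state(state_dict):
    """Return a copy of state_dict without any '...core_attention._extra_state' entries."""
    return {k: v for k, v in state_dict.items() if "core_attention._extra_state" not in k}


# Hypothetical toy checkpoint standing in for an FP8-calibrated weights file.
calibrated = {
    "weight": torch.ones(2, 2),
    "layers.0.self_attention.core_attention._extra_state": torch.tensor([]),
}
filtered = drop_core_attention_extra_state(calibrated)

model = nn.Linear(2, 2)  # toy model; its "bias" has no entry in the checkpoint
result = model.load_state_dict(filtered, strict=False)
print(result.missing_keys, result.unexpected_keys)  # ['bias'] []
```

With strict=True the same call would raise because "bias" is missing; strict=False loads what matches and returns the mismatches for inspection.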

This is a Pattern Doc.

Source

docs/examples/te_gemma/te_gemma_loading_weights.py, load_te_model at lines 75-109, replace_params at lines 127-189.

Signature

def load_te_model(cls, config):
    """
    Loads the TE model with proper weights.

    1. Sets default dtype to bfloat16
    2. Creates model under quantized_model_init and torch.no_grad contexts
    3. Loads weights via FP8 path or standard HF path
    4. Restores original dtype
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.bfloat16)
    config.use_cache = False

    with torch.no_grad(), quantized_model_init(config.quantized_model_init):
        vanilla_model = cls(config).cuda()

    if config.fp8_model_weights_filename is not None:
        _load_weights_for_fp8_model(vanilla_model, config)
    else:
        _load_weights_for_standard_model(vanilla_model, config)

    torch.set_default_dtype(old_dtype)
    return vanilla_model


def replace_params(hf_state_dict, te_state_dict, config, qkv_fused_and_interleaved=False):
    """
    Replaces params from TE TransformerLayer state_dict with corresponding
    parameters from HuggingFace GemmaModel state_dict.

    Weight mapping for each layer prefix:
    - input_layernorm.weight -> self_attention.layernorm_qkv.layer_norm_weight
    - self_attn.o_proj.weight -> self_attention.proj.weight
    - post_attention_layernorm.weight -> layernorm_mlp.layer_norm_weight
    - mlp.down_proj.weight -> layernorm_mlp.fc2_weight
    - mlp.gate_proj.weight -> layernorm_mlp.fc1_weight[:intermediate_size]
    - mlp.up_proj.weight -> layernorm_mlp.fc1_weight[intermediate_size:]

    QKV mapping depends on qkv_fused_and_interleaved:
    - If True: interleaves q,k,v per-head into single weight tensor
    - If False: copies q,k,v into separate weight tensors
    """
    all_layer_prefixes = _get_all_layer_prefixes_to_update(hf_state_dict)
    for layer_prefix in all_layer_prefixes:
        # ... copy operations per layer ...
        if qkv_fused_and_interleaved:
            # Interleaved: [q1 k1 v1 q2 k2 v2 ...] layout
            ...
        else:
            # Separate: query_weight, key_weight, value_weight
            ...
    return all_layer_prefixes
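The non-QKV part of the docstring's mapping is a fixed suffix-to-suffix rename applied under each layer prefix. A minimal sketch, assuming plain dicts of tensors with toy shapes (hidden size 4, intermediate size 8); copy_renamed_params and HF_TO_TE are hypothetical names, and skipping absent keys is an assumption of this sketch rather than documented behavior.

```python
import torch

# Non-QKV renames listed in the replace_params docstring (HF suffix -> TE suffix).
HF_TO_TE = {
    "input_layernorm.weight": "self_attention.layernorm_qkv.layer_norm_weight",
    "self_attn.o_proj.weight": "self_attention.proj.weight",
    "post_attention_layernorm.weight": "layernorm_mlp.layer_norm_weight",
    "mlp.down_proj.weight": "layernorm_mlp.fc2_weight",
}


@torch.no_grad()
def copy_renamed_params(hf_state_dict, te_state_dict, layer_prefix):
    """Copy one layer's mapped HF tensors into the matching TE tensors, in place."""
    for hf_suffix, te_suffix in HF_TO_TE.items():
        hf_key = layer_prefix + hf_suffix
        if hf_key in hf_state_dict:  # assumption of this sketch: skip absent keys
            te_state_dict[layer_prefix + te_suffix].copy_(hf_state_dict[hf_key])


# Toy shapes: hidden size 4, intermediate size 8.
torch.manual_seed(0)
prefix = "model.layers.0."
shapes = {
    "input_layernorm.weight": (4,),
    "self_attn.o_proj.weight": (4, 4),
    "post_attention_layernorm.weight": (4,),
    "mlp.down_proj.weight": (4, 8),
}
hf_sd = {prefix + name: torch.randn(shape) for name, shape in shapes.items()}
te_sd = {prefix + HF_TO_TE[name]: torch.zeros(shape) for name, shape in shapes.items()}
copy_renamed_params(hf_sd, te_sd, prefix)
```

Copying with copy_ writes into the TE tensors in place, matching the "modified in-place" contract for te_state_dict.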

I/O

load_te_model Input:

  • cls: Model class (e.g., TEGemmaForCausalLM or TEGemmaForCausalLMCudaGraphs)
  • config: GemmaConfig -- Extended config with:
    • config.weights_cache_dir: str -- Path to HF checkpoint directory
    • config.fp8_model_weights_filename: Optional[str] -- Path to FP8 calibrated weights, or None
    • config.quantized_model_init: Quantization configuration for quantized_model_init context
    • config.fuse_qkv_params: bool -- Whether to use fused QKV weights

load_te_model Output:

  • Loaded and initialized model instance on CUDA

replace_params Input:

  • hf_state_dict: dict -- HuggingFace model state dict
  • te_state_dict: dict -- TE model state dict (modified in-place)
  • config: GemmaConfig -- Model configuration
  • qkv_fused_and_interleaved: bool -- Whether QKV weights should be fused and interleaved (default: False)

replace_params Output:

  • set[str] -- Set of all updated layer prefixes (e.g., {"model.layers.0.", "model.layers.1.", ...})

Weight Mapping Details

FFN Weights

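Gemma's GeGLU MLP applies two input projections to the same hidden state: a gate projection and an up projection. HuggingFace stores them as separate matrices (gate_proj and up_proj); TE concatenates them along the row (output) dimension into a single fc1_weight, with the gate rows first: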

# gate_proj -> first half of fc1_weight
copy_from_ht_to_te("layernorm_mlp.fc1_weight", "mlp.gate_proj.weight", end=config.intermediate_size)
# up_proj -> second half of fc1_weight
copy_from_ht_to_te("layernorm_mlp.fc1_weight", "mlp.up_proj.weight", start=config.intermediate_size)
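The start/end slicing above can be sketched with a small helper. copy_into_rows is a hypothetical stand-in for copy_from_ht_to_te (whose body is not shown here) that takes the two state dicts explicitly; toy sizes are intermediate_size=3 and hidden_size=2. The check at the end confirms the result equals a row-wise torch.cat of gate and up.

```python
import torch


@torch.no_grad()
def copy_into_rows(te_state_dict, hf_state_dict, te_name, hf_name, start=None, end=None):
    """Copy hf_state_dict[hf_name] into rows [start:end] of te_state_dict[te_name].

    Hypothetical stand-in for copy_from_ht_to_te. None for start or end means
    "from the first row" or "to the last row", as in ordinary Python slicing.
    """
    te_state_dict[te_name][start:end].copy_(hf_state_dict[hf_name])  # slice is a view


intermediate_size, hidden_size = 3, 2
gate = torch.arange(6.0).reshape(intermediate_size, hidden_size)
up = torch.arange(6.0, 12.0).reshape(intermediate_size, hidden_size)
hf_sd = {"mlp.gate_proj.weight": gate, "mlp.up_proj.weight": up}
te_sd = {"layernorm_mlp.fc1_weight": torch.zeros(2 * intermediate_size, hidden_size)}

copy_into_rows(te_sd, hf_sd, "layernorm_mlp.fc1_weight", "mlp.gate_proj.weight", end=intermediate_size)
copy_into_rows(te_sd, hf_sd, "layernorm_mlp.fc1_weight", "mlp.up_proj.weight", start=intermediate_size)

# fc1_weight now holds [gate; up] stacked along the rows.
assert torch.equal(te_sd["layernorm_mlp.fc1_weight"], torch.cat([gate, up], dim=0))
```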

QKV Interleaving

When qkv_fused_and_interleaved=True, the Q, K and V weights are written into a single fused tensor, interleaved per head (each head's Q, K and V rows are contiguous) for use by the fused QKV kernel:

# Layout: [q_head0, k_head0, v_head0, q_head1, k_head1, v_head1, ...]
for head_nr in range(config.num_attention_heads):
    dst_offset = head_nr * config.head_dim * 3
    # Q rows: dst_offset + 0*head_dim : dst_offset + 1*head_dim
    # K rows: dst_offset + 1*head_dim : dst_offset + 2*head_dim
    # V rows: dst_offset + 2*head_dim : dst_offset + 3*head_dim
    ...
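The per-head offsets can be exercised on toy tensors. This sketch assumes q, k and v each have num_heads * head_dim rows (one K/V head per Q head, as the loop over config.num_attention_heads implies); interleave_qkv is a hypothetical helper, not the source's code. It also shows an equivalent loop-free formulation via stack and reshape.

```python
import torch


def interleave_qkv(q, k, v, num_heads, head_dim):
    """Pack Q, K, V rows per head as [q_h0, k_h0, v_h0, q_h1, k_h1, v_h1, ...].

    Assumes q, k and v each have num_heads * head_dim rows.
    """
    fused = torch.empty(3 * num_heads * head_dim, q.shape[1], dtype=q.dtype)
    for head_nr in range(num_heads):
        src = slice(head_nr * head_dim, (head_nr + 1) * head_dim)  # this head's rows in q/k/v
        dst_offset = head_nr * head_dim * 3
        fused[dst_offset + 0 * head_dim : dst_offset + 1 * head_dim] = q[src]
        fused[dst_offset + 1 * head_dim : dst_offset + 2 * head_dim] = k[src]
        fused[dst_offset + 2 * head_dim : dst_offset + 3 * head_dim] = v[src]
    return fused


num_heads, head_dim, hidden_size = 2, 2, 3
q = torch.arange(12.0).reshape(num_heads * head_dim, hidden_size)
k = q + 100
v = q + 200
fused = interleave_qkv(q, k, v, num_heads, head_dim)

# Same layout without a Python loop: stack to (heads, 3, head_dim, hidden), then flatten.
vectorized = torch.stack(
    [t.view(num_heads, head_dim, hidden_size) for t in (q, k, v)], dim=1
).reshape(-1, hidden_size)
assert torch.equal(fused, vectorized)
```

The stacked form makes the layout explicit: the outer index is the head, the middle index selects Q, K or V, and the inner index walks that head's rows.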

Non-Interleaved QKV

When qkv_fused_and_interleaved=False, Q, K, V are stored as separate weight tensors:

copy_from_ht_to_te("self_attention.layernorm_qkv.query_weight", "self_attn.q_proj.weight")
copy_from_ht_to_te("self_attention.layernorm_qkv.key_weight", "self_attn.k_proj.weight")
copy_from_ht_to_te("self_attention.layernorm_qkv.value_weight", "self_attn.v_proj.weight")
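The three separate copies can be sketched as a loop over name pairs with an explicit shape check. This is a hypothetical helper (copy_separate_qkv), not the source's implementation. The guard is a design choice of this sketch: Tensor.copy_ broadcasts its source, so a mismatched but broadcastable tensor would otherwise be copied silently.

```python
import torch

# (TE suffix, HF suffix) pairs from the non-interleaved mapping above.
QKV_PAIRS = [
    ("self_attention.layernorm_qkv.query_weight", "self_attn.q_proj.weight"),
    ("self_attention.layernorm_qkv.key_weight", "self_attn.k_proj.weight"),
    ("self_attention.layernorm_qkv.value_weight", "self_attn.v_proj.weight"),
]


@torch.no_grad()
def copy_separate_qkv(te_state_dict, hf_state_dict, layer_prefix):
    """Copy HF q/k/v projections into separate TE tensors, rejecting shape mismatches."""
    for te_suffix, hf_suffix in QKV_PAIRS:
        dst = te_state_dict[layer_prefix + te_suffix]
        src = hf_state_dict[layer_prefix + hf_suffix]
        if dst.shape != src.shape:  # copy_ would broadcast instead of failing
            raise ValueError(f"{te_suffix}: TE shape {tuple(dst.shape)} != HF shape {tuple(src.shape)}")
        dst.copy_(src)


# Toy layer: 4 query rows, 2 key/value rows, hidden size 3.
p = "model.layers.0."
hf_sd = {
    p + "self_attn.q_proj.weight": torch.randn(4, 3),
    p + "self_attn.k_proj.weight": torch.randn(2, 3),
    p + "self_attn.v_proj.weight": torch.randn(2, 3),
}
te_sd = {p + te: torch.zeros_like(hf_sd[p + hf]) for te, hf in QKV_PAIRS}
copy_separate_qkv(te_sd, hf_sd, p)
```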

Notes

  • load_te_model sets the default dtype to bfloat16 during model creation and restores the original dtype afterward.
  • config.use_cache = False is set to make TransformerLayer compatible with GemmaModel, as TE handles caching through InferenceParams rather than HF's cache mechanism.
  • The strict=False in load_state_dict() is necessary because the wrapper architecture (_model_context_phase.model and _model_generation_phase.model) creates multiple references to the same weights.
  • Remaining non-layer parameters (token embeddings, lm_head, final layer norm) are loaded via the load_state_dict(strict=False) call after replace_params() handles the layer-specific mapping.
  • After the standard (HF) weights are loaded, memory is freed explicitly with del total_dict followed by gc.collect().
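The dtype handling in the first note can be packaged as a context manager. In the Signature listing, the original dtype is restored only on the normal return path; the hypothetical default_dtype helper below (not part of the source) restores it in a finally block, so it is also restored if construction or loading raises.

```python
import contextlib

import torch


@contextlib.contextmanager
def default_dtype(dtype):
    """Temporarily set torch's default floating-point dtype, restoring it on exit."""
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(old_dtype)  # runs even if the body raises


original = torch.get_default_dtype()
with default_dtype(torch.bfloat16):
    inner = torch.zeros(2)  # created while bfloat16 is the default
print(inner.dtype)  # torch.bfloat16
assert torch.get_default_dtype() == original
```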
