Implementation:NVIDIA TransformerEngine Llama Replace Params

From Leeroopedia


Overview

Weight-mapping function that copies a HuggingFace LLaMA state dict into the parameter layout of TransformerEngine's TransformerLayer.

Doc Type

Pattern Doc -- This function implements the weight mapping pattern for translating HF parameter names and layouts into TE conventions.

Description

replace_params(hf_state_dict, te_state_dict, config) copies parameters from a HuggingFace LLaMA checkpoint state dict (keys of the form model.layers.N.<param>) into the state dict of the TE model, whose per-layer keys use the same model.layers.N. prefix. The function handles the key differences between the HF and TE parameter layouts:

  • LayerNorm + QKV: Maps HF's separate q_proj, k_proj and v_proj weights into TE's fused layernorm_qkv sub-module, which keeps them as separate query_weight, key_weight and value_weight parameters.
  • Fused MLP: Writes HF's gate_proj and up_proj weights into TE's single fc1_weight tensor: gate occupies the first config.intermediate_size rows and up occupies the remaining config.intermediate_size rows.
  • Layer norms: Maps input_layernorm and post_attention_layernorm to the layer_norm_weight parameters of layernorm_qkv and layernorm_mlp, respectively.
  • Output projections: Copies o_proj and down_proj directly to self_attention.proj.weight and layernorm_mlp.fc2_weight.
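
The fc1_weight layout can be checked with small tensors. The following is a minimal sketch assuming PyTorch is installed; the toy sizes and variable names are hypothetical, and only the two slice assignments mirror replace_params. It shows that writing gate_proj and up_proj into the two row blocks is equivalent to concatenating them along dim 0.

```python
import torch

# Toy sizes (hypothetical); real LLaMA configs use much larger values.
hidden_size = 4
intermediate_size = 3

# HF stores the two MLP input projections as separate [intermediate_size, hidden_size] weights.
gate_proj = torch.randn(intermediate_size, hidden_size)
up_proj = torch.randn(intermediate_size, hidden_size)

# TE's fused fc1_weight holds both: [2 * intermediate_size, hidden_size].
fc1_weight = torch.empty(2 * intermediate_size, hidden_size)

# Same slice assignments as replace_params: gate -> first row block, up -> second row block.
fc1_weight.data[:intermediate_size] = gate_proj
fc1_weight.data[intermediate_size:] = up_proj

# The result is identical to concatenating the two HF tensors row-wise.
assert torch.equal(fc1_weight, torch.cat([gate_proj, up_proj], dim=0))

# Splitting along dim 0 recovers the original HF tensors (illustration only).
gate_back, up_back = torch.split(fc1_weight, intermediate_size, dim=0)
assert torch.equal(gate_back, gate_proj) and torch.equal(up_back, up_proj)
```

The torch.split call is only there to show how the two halves could be recovered; replace_params itself performs only the HF-to-TE direction.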

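The function discovers all layer prefixes (e.g., model.layers.0., model.layers.1.) by regex-matching the keys of the HF state dict. It therefore does not need to know the number of layers in advance, and on a sharded checkpoint it processes only the layers that actually appear in the given shard. Keys outside model.layers. (for example the token embeddings, the final norm, and lm_head) are not matched and are not touched by this function.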
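
Prefix discovery on a single shard can be sketched as follows. The key list is a hypothetical sample, not the contents of a real checkpoint, and the helper name layer_prefixes is illustrative. The sketch escapes the dots in the pattern so each one matches only a literal "."; for real HF key names this finds the same prefixes as an unescaped pattern.

```python
import re

# Hypothetical key names from one shard of an HF LLaMA checkpoint (not a real file listing).
hf_keys = [
    "model.embed_tokens.weight",
    "model.layers.0.input_layernorm.weight",
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.1.mlp.gate_proj.weight",
    "model.layers.10.mlp.up_proj.weight",
    "model.norm.weight",
    "lm_head.weight",
]

# Escaped dots: each "\." matches only a literal ".".
LAYER_PREFIX = re.compile(r"model\.layers\.\d+\.")


def layer_prefixes(keys):
    """Collect the distinct "model.layers.N." prefixes in an iterable of keys."""
    prefixes = set()
    for key in keys:
        m = LAYER_PREFIX.match(key)  # match() anchors at the start of the key
        if m is not None:
            prefixes.add(m.group())
    return prefixes


found = layer_prefixes(hf_keys)
# Sort by numeric layer index for display (a plain string sort would put "10" before "2").
print(sorted(found, key=lambda p: int(p.split(".")[2])))
# ['model.layers.0.', 'model.layers.1.', 'model.layers.10.']
```

Embedding, final-norm and lm_head keys produce no prefix, which is why replace_params never maps them.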

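Each copy is guarded by an existence check (if layer_prefix + "param_name" in hf_state_dict), so the function can be called once per checkpoint shard with a partial state dict. This matters most for fc1_weight: its gate and up halves are written by independent slice assignments, so they are filled correctly even when gate_proj and up_proj for the same layer are stored in different shard files. The TE keys, by contrast, are looked up without a check and must already exist in te_state_dict.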
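
The guarded, independent writes into the two halves of fc1_weight are what make shard-by-shard loading work. A minimal sketch, assuming PyTorch; copy_fc1 is a hypothetical helper that repeats only the gate/up branch of replace_params, and the two shards are toy dictionaries in which one layer's gate and up weights sit in different shards.

```python
import torch

intermediate_size, hidden_size = 3, 4  # toy sizes (hypothetical)
prefix = "model.layers.0."

# Two hypothetical shards: this layer's gate_proj and up_proj are stored in different files.
shard_a = {prefix + "mlp.gate_proj.weight": torch.ones(intermediate_size, hidden_size)}
shard_b = {prefix + "mlp.up_proj.weight": torch.full((intermediate_size, hidden_size), 2.0)}

# Stand-in for the TE state dict entry, initialized to zeros.
te_sd = {prefix + "layernorm_mlp.fc1_weight": torch.zeros(2 * intermediate_size, hidden_size)}


def copy_fc1(hf_sd, te_sd, prefix, intermediate_size):
    """Hypothetical helper: the gate/up branch of replace_params, with the same existence checks."""
    fc1 = te_sd[prefix + "layernorm_mlp.fc1_weight"]
    if prefix + "mlp.gate_proj.weight" in hf_sd:
        fc1.data[:intermediate_size] = hf_sd[prefix + "mlp.gate_proj.weight"].data
    if prefix + "mlp.up_proj.weight" in hf_sd:
        fc1.data[intermediate_size:] = hf_sd[prefix + "mlp.up_proj.weight"].data


# Each call writes only the half whose source key is present.
copy_fc1(shard_a, te_sd, prefix, intermediate_size)
# After shard_a only the gate rows are set; the up rows are still zero.
assert te_sd[prefix + "layernorm_mlp.fc1_weight"][intermediate_size:].eq(0).all()
copy_fc1(shard_b, te_sd, prefix, intermediate_size)

fc1 = te_sd[prefix + "layernorm_mlp.fc1_weight"]
print(fc1[:intermediate_size].unique().tolist(), fc1[intermediate_size:].unique().tolist())
# [1.0] [2.0]
```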

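All copies are in-place slice assignments: .data[:] = ... for whole tensors, and a row slice (.data[:config.intermediate_size] or .data[config.intermediate_size:]) for the two halves of fc1_weight. Because the tensors returned by model.state_dict() are detached tensors that share storage with the model's parameters, these writes update the TE model's weights directly. No new tensors are created, and the model's existing parameter objects are preserved (important for optimizers and gradient tracking).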
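
The in-place behavior can be observed on any PyTorch module. This sketch uses a plain torch.nn.Linear as a stand-in for a TE module (an assumption made to keep the example self-contained); it checks that writing through a state_dict() entry updates the parameter, keeps the same Parameter object, and preserves the destination dtype, while merely rebinding the dict entry has no effect on the model.

```python
import torch

# Stand-in for a TE module (assumption: any nn.Module behaves the same way here).
layer = torch.nn.Linear(4, 3, bias=False)
param_before = layer.weight  # keep a handle on the Parameter object

# state_dict() (keep_vars=False by default) returns detached tensors sharing storage with the params.
sd = layer.state_dict()
assert sd["weight"].data_ptr() == layer.weight.data_ptr()

new_weight = torch.arange(12, dtype=torch.float64).reshape(3, 4)

# In-place slice assignment writes into the shared storage, casting float64 -> float32.
sd["weight"].data[:] = new_weight

assert layer.weight is param_before          # same Parameter object: optimizer references stay valid
assert layer.weight.dtype == torch.float32   # the destination keeps its dtype
assert torch.equal(layer.weight.detach(), new_weight.float())

# By contrast, rebinding the dict entry does not touch the model at all.
sd["weight"] = torch.zeros(3, 4)
assert torch.equal(layer.weight.detach(), new_weight.float())
```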

Source

  • File: docs/examples/te_llama/te_llama.py
  • Function: replace_params
  • Lines: L166-224

Signature

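import re


def replace_params(hf_state_dict, te_state_dict, config):
    # collect all layer prefixes ("model.layers.N.") present in this, possibly partial, state dict
    all_layer_prefixes = set()
    # dots are escaped so each matches only a literal "."; real HF key names yield the same prefixes
    layer_prefix_pat = r"model\.layers\.\d+\."
    for param_key in hf_state_dict.keys():
        m = re.match(layer_prefix_pat, param_key)
        if m is not None:
            all_layer_prefixes.add(m.group())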

    for layer_prefix in all_layer_prefixes:
        # input_layernorm.weight -> self_attention.layernorm_qkv.layer_norm_weight
        if layer_prefix + "input_layernorm.weight" in hf_state_dict:
            te_state_dict[
                layer_prefix + "self_attention.layernorm_qkv.layer_norm_weight"
            ].data[:] = hf_state_dict[layer_prefix + "input_layernorm.weight"].data[:]

        # self_attn.q_proj.weight -> self_attention.layernorm_qkv.query_weight
        if layer_prefix + "self_attn.q_proj.weight" in hf_state_dict:
            te_state_dict[
                layer_prefix + "self_attention.layernorm_qkv.query_weight"
            ].data[:] = hf_state_dict[layer_prefix + "self_attn.q_proj.weight"].data[:]

        # self_attn.k_proj.weight -> self_attention.layernorm_qkv.key_weight
        if layer_prefix + "self_attn.k_proj.weight" in hf_state_dict:
            te_state_dict[
                layer_prefix + "self_attention.layernorm_qkv.key_weight"
            ].data[:] = hf_state_dict[layer_prefix + "self_attn.k_proj.weight"].data[:]

        # self_attn.v_proj.weight -> self_attention.layernorm_qkv.value_weight
        if layer_prefix + "self_attn.v_proj.weight" in hf_state_dict:
            te_state_dict[
                layer_prefix + "self_attention.layernorm_qkv.value_weight"
            ].data[:] = hf_state_dict[layer_prefix + "self_attn.v_proj.weight"].data[:]

        # self_attn.o_proj.weight -> self_attention.proj.weight
        if layer_prefix + "self_attn.o_proj.weight" in hf_state_dict:
            te_state_dict[
                layer_prefix + "self_attention.proj.weight"
            ].data[:] = hf_state_dict[layer_prefix + "self_attn.o_proj.weight"].data[:]

        # post_attention_layernorm.weight -> layernorm_mlp.layer_norm_weight
        if layer_prefix + "post_attention_layernorm.weight" in hf_state_dict:
            te_state_dict[
                layer_prefix + "layernorm_mlp.layer_norm_weight"
            ].data[:] = hf_state_dict[
                layer_prefix + "post_attention_layernorm.weight"
            ].data[:]

        # mlp.gate_proj.weight -> layernorm_mlp.fc1_weight[:intermediate_size]
        if layer_prefix + "mlp.gate_proj.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[
                : config.intermediate_size
            ] = hf_state_dict[layer_prefix + "mlp.gate_proj.weight"].data

        # mlp.up_proj.weight -> layernorm_mlp.fc1_weight[intermediate_size:]
        if layer_prefix + "mlp.up_proj.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[
                config.intermediate_size :
            ] = hf_state_dict[layer_prefix + "mlp.up_proj.weight"].data

        # mlp.down_proj.weight -> layernorm_mlp.fc2_weight
        if layer_prefix + "mlp.down_proj.weight" in hf_state_dict:
            te_state_dict[
                layer_prefix + "layernorm_mlp.fc2_weight"
            ].data[:] = hf_state_dict[layer_prefix + "mlp.down_proj.weight"].data[:]

    return all_layer_prefixes

I/O

| Direction | Name | Type | Description |
|---|---|---|---|
| Input | hf_state_dict | dict[str, torch.Tensor] | State dict loaded from a HuggingFace checkpoint shard. Contains HF-named parameter tensors. |
| Input | te_state_dict | dict[str, torch.Tensor] | State dict from the TE model (obtained via model.state_dict()). Modified in place. |
| Input | config | LlamaConfig | HuggingFace LLaMA configuration. Used for config.intermediate_size to determine the split point in fc1_weight. |
| Output | return value | set[str] | Set of layer prefixes that were processed (e.g., {"model.layers.0.", "model.layers.1.", ...}). |
| Side effect | te_state_dict | in-place modification | TE parameters are updated in place via .data[:] = slice assignment. |
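
Because each call returns only the prefixes present in the shard it was given, a caller can union the results over all shards to check that every decoder layer was visited. A sketch with hypothetical return values; missing_layers is an illustrative helper, and in practice num_hidden_layers would come from the LlamaConfig. A prefix in the returned set only means that at least one key with that prefix was present, not that all nine parameters of that layer were copied.

```python
# Hypothetical return values of replace_params for a 4-layer model split across two shards.
returned_per_shard = [
    {"model.layers.0.", "model.layers.1.", "model.layers.2."},
    {"model.layers.2.", "model.layers.3."},  # layer 2's keys are split between both shards
]


def missing_layers(returned_per_shard, num_hidden_layers):
    """Return the layer prefixes that no shard reported, sorted by layer index."""
    covered = set().union(*returned_per_shard)
    expected = {f"model.layers.{i}." for i in range(num_hidden_layers)}
    return sorted(expected - covered, key=lambda p: int(p.split(".")[2]))


# num_hidden_layers is an assumption here; it would normally be config.num_hidden_layers.
print(missing_layers(returned_per_shard, 4))  # []
print(missing_layers(returned_per_shard, 5))  # ['model.layers.4.']
```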

Weight Mapping Table

For each layer prefix model.layers.N.:

| HF Key Suffix | TE Key Suffix | Copy Method |
|---|---|---|
| input_layernorm.weight | self_attention.layernorm_qkv.layer_norm_weight | Direct copy (.data[:] =) |
| self_attn.q_proj.weight | self_attention.layernorm_qkv.query_weight | Direct copy |
| self_attn.k_proj.weight | self_attention.layernorm_qkv.key_weight | Direct copy |
| self_attn.v_proj.weight | self_attention.layernorm_qkv.value_weight | Direct copy |
| self_attn.o_proj.weight | self_attention.proj.weight | Direct copy |
| post_attention_layernorm.weight | layernorm_mlp.layer_norm_weight | Direct copy |
| mlp.gate_proj.weight | layernorm_mlp.fc1_weight[:intermediate_size] | Slice assignment (first half) |
| mlp.up_proj.weight | layernorm_mlp.fc1_weight[intermediate_size:] | Slice assignment (second half) |
| mlp.down_proj.weight | layernorm_mlp.fc2_weight | Direct copy |
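
The mapping table can also be expressed as data, which makes it easy to look up where a given HF key lands or to spot keys with no TE counterpart. A hypothetical sketch in plain Python: HF_TO_TE and te_target are illustrative names, not part of te_llama.py, and 11008 is used only as an example intermediate_size.

```python
import re

# Hypothetical encoding of the table: HF suffix -> (TE suffix, part of the TE tensor written).
HF_TO_TE = {
    "input_layernorm.weight": ("self_attention.layernorm_qkv.layer_norm_weight", "full"),
    "self_attn.q_proj.weight": ("self_attention.layernorm_qkv.query_weight", "full"),
    "self_attn.k_proj.weight": ("self_attention.layernorm_qkv.key_weight", "full"),
    "self_attn.v_proj.weight": ("self_attention.layernorm_qkv.value_weight", "full"),
    "self_attn.o_proj.weight": ("self_attention.proj.weight", "full"),
    "post_attention_layernorm.weight": ("layernorm_mlp.layer_norm_weight", "full"),
    "mlp.gate_proj.weight": ("layernorm_mlp.fc1_weight", "first"),
    "mlp.up_proj.weight": ("layernorm_mlp.fc1_weight", "second"),
    "mlp.down_proj.weight": ("layernorm_mlp.fc2_weight", "full"),
}

KEY_PAT = re.compile(r"(model\.layers\.\d+\.)(.+)")


def te_target(hf_key, intermediate_size):
    """Return (TE key, row slice) for a per-layer HF key, or None if the table has no entry."""
    m = KEY_PAT.fullmatch(hf_key)
    if m is None or m.group(2) not in HF_TO_TE:
        return None
    prefix, suffix = m.groups()
    te_suffix, part = HF_TO_TE[suffix]
    rows = {
        "full": slice(None),                      # whole tensor, like .data[:]
        "first": slice(None, intermediate_size),  # gate_proj rows of fc1_weight
        "second": slice(intermediate_size, None), # up_proj rows of fc1_weight
    }[part]
    return prefix + te_suffix, rows


print(te_target("model.layers.3.mlp.up_proj.weight", 11008))
# ('model.layers.3.layernorm_mlp.fc1_weight', slice(11008, None, None))
print(te_target("model.norm.weight", 11008))
# None
```

Keys for which te_target returns None are exactly the ones this mapping leaves untouched.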
