Implementation:NVIDIA TransformerEngine Llama Replace Params
Overview
Weight mapping function for converting HuggingFace LLaMA state dict to TransformerEngine TransformerLayer format.
Doc Type
Pattern Doc -- This function implements the weight mapping pattern for translating between HF and TE parameter naming and layout conventions.
Description
replace_params(hf_state_dict, te_state_dict, config) copies parameters from a HuggingFace LlamaModel state dict into a TE TransformerLayer state dict. The function handles the key differences between HF and TE parameter layouts:
- Fused QKV: Maps separate q_proj, k_proj, v_proj weights to TE's layernorm_qkv sub-module with separate query_weight, key_weight, value_weight parameters.
- Fused MLP: Concatenates HF's gate_proj and up_proj into TE's single fc1_weight tensor, where gate occupies the first config.intermediate_size rows and up occupies the remainder.
- Layer Norm: Maps input_layernorm and post_attention_layernorm to their TE equivalents within the fused modules.
The function discovers all layer prefixes (e.g., model.layers.0., model.layers.1., etc.) by regex-matching keys in the HF state dict. This makes it robust to models with different numbers of layers and to sharded checkpoints where not all layers are present in every shard.
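The prefix discovery described above can be sketched in isolation. The helper name collect_layer_prefixes and the shard keys are hypothetical, introduced only for illustration; the pattern string is the one used in replace_params.

```python
import re

# Pattern as written in replace_params. The unescaped dots match any
# character, which is harmless for standard HF LLaMA key names.
LAYER_PREFIX_PAT = r"model.layers.\d+."


def collect_layer_prefixes(keys):
    """Hypothetical helper mirroring the prefix-discovery loop."""
    prefixes = set()
    for key in keys:
        m = re.match(LAYER_PREFIX_PAT, key)
        if m is not None:
            prefixes.add(m.group())
    return prefixes


# A made-up shard that holds parts of layers 0 and 11 plus a non-layer key.
shard_keys = [
    "model.embed_tokens.weight",
    "model.layers.0.input_layernorm.weight",
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.11.mlp.down_proj.weight",
]
prefixes = collect_layer_prefixes(shard_keys)
print(sorted(prefixes))
```

Because re.match anchors at the start of the key, non-layer entries such as the embedding weight are skipped, and each layer contributes one prefix no matter how many of its parameters appear in the shard.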
Each parameter is copied with an existence check (if layer_prefix + "param_name" in hf_state_dict), allowing the function to handle partial state dicts from sharded checkpoints where gate and up projection weights may be in different files.
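As a minimal sketch of why the existence checks matter, the following reduces the gate/up branch to a hypothetical helper, copy_fc1_parts, and applies it to two made-up shards. The toy sizes, shard contents, and helper name are assumptions for illustration, not part of the original example.

```python
import torch

intermediate_size, hidden_size = 3, 2  # hypothetical toy sizes
prefix = "model.layers.0."

# Stand-in for the TE state dict entry that holds the fused gate/up weight.
te_sd = {
    prefix + "layernorm_mlp.fc1_weight": torch.zeros(2 * intermediate_size, hidden_size)
}

# Two made-up shards: gate_proj lives in one file, up_proj in the other.
shard_a = {prefix + "mlp.gate_proj.weight": torch.ones(intermediate_size, hidden_size)}
shard_b = {prefix + "mlp.up_proj.weight": torch.full((intermediate_size, hidden_size), 2.0)}


def copy_fc1_parts(hf_sd, te_sd):
    """Hypothetical, reduced version of the gate/up branch of replace_params."""
    fc1 = te_sd[prefix + "layernorm_mlp.fc1_weight"]
    if prefix + "mlp.gate_proj.weight" in hf_sd:
        fc1.data[:intermediate_size] = hf_sd[prefix + "mlp.gate_proj.weight"].data
    if prefix + "mlp.up_proj.weight" in hf_sd:
        fc1.data[intermediate_size:] = hf_sd[prefix + "mlp.up_proj.weight"].data


# Each shard fills only the rows it has data for; the other rows are left alone.
for shard in (shard_a, shard_b):
    copy_fc1_parts(shard, te_sd)
```

After both shards are processed, the first intermediate_size rows hold the gate weights and the remaining rows hold the up weights, regardless of the order in which the shards were visited.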
All copies are performed in-place using .data[:] = slice assignment, which avoids creating new tensors and preserves the TE model's existing parameter objects (important for optimizers and gradient tracking).
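The difference between an in-place copy and rebinding a state dict entry can be illustrated with a plain torch.nn.Linear standing in for a TE module. This is a sketch under that assumption; the variable names are illustrative only.

```python
import torch

linear = torch.nn.Linear(3, 2, bias=False)  # stand-in for a TE module
param = linear.weight                        # the object an optimizer would hold
state = linear.state_dict()                  # entries share storage with the parameters
new_weight = torch.arange(6, dtype=torch.float32).reshape(2, 3)

# In-place copy: writes into the existing storage, so the module sees the values.
state["weight"].data[:] = new_weight

same_object = linear.weight is param
values_updated = torch.equal(linear.weight.detach(), new_weight)

# Rebinding the dict entry, by contrast, does not touch the module's parameter.
state["weight"] = torch.zeros(2, 3)
still_updated = torch.equal(linear.weight.detach(), new_weight)
```

The parameter object is the same before and after the copy, which is why references held elsewhere (for example by an optimizer) remain valid.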
Source
- File: docs/examples/te_llama/te_llama.py
- Function: replace_params
- Lines: L166-224
Signature
import re  # needed for re.match below


def replace_params(hf_state_dict, te_state_dict, config):
    # collect all layer prefixes to update
    all_layer_prefixes = set()
    for param_key in hf_state_dict.keys():
        layer_prefix_pat = r"model.layers.\d+."
        m = re.match(layer_prefix_pat, param_key)
        if m is not None:
            all_layer_prefixes.add(m.group())

    for layer_prefix in all_layer_prefixes:
        # input_layernorm.weight -> self_attention.layernorm_qkv.layer_norm_weight
        if layer_prefix + "input_layernorm.weight" in hf_state_dict:
            te_state_dict[
                layer_prefix + "self_attention.layernorm_qkv.layer_norm_weight"
            ].data[:] = hf_state_dict[layer_prefix + "input_layernorm.weight"].data[:]

        # self_attn.q_proj.weight -> self_attention.layernorm_qkv.query_weight
        if layer_prefix + "self_attn.q_proj.weight" in hf_state_dict:
            te_state_dict[
                layer_prefix + "self_attention.layernorm_qkv.query_weight"
            ].data[:] = hf_state_dict[layer_prefix + "self_attn.q_proj.weight"].data[:]

        # self_attn.k_proj.weight -> self_attention.layernorm_qkv.key_weight
        if layer_prefix + "self_attn.k_proj.weight" in hf_state_dict:
            te_state_dict[
                layer_prefix + "self_attention.layernorm_qkv.key_weight"
            ].data[:] = hf_state_dict[layer_prefix + "self_attn.k_proj.weight"].data[:]

        # self_attn.v_proj.weight -> self_attention.layernorm_qkv.value_weight
        if layer_prefix + "self_attn.v_proj.weight" in hf_state_dict:
            te_state_dict[
                layer_prefix + "self_attention.layernorm_qkv.value_weight"
            ].data[:] = hf_state_dict[layer_prefix + "self_attn.v_proj.weight"].data[:]

        # self_attn.o_proj.weight -> self_attention.proj.weight
        if layer_prefix + "self_attn.o_proj.weight" in hf_state_dict:
            te_state_dict[
                layer_prefix + "self_attention.proj.weight"
            ].data[:] = hf_state_dict[layer_prefix + "self_attn.o_proj.weight"].data[:]

        # post_attention_layernorm.weight -> layernorm_mlp.layer_norm_weight
        if layer_prefix + "post_attention_layernorm.weight" in hf_state_dict:
            te_state_dict[
                layer_prefix + "layernorm_mlp.layer_norm_weight"
            ].data[:] = hf_state_dict[
                layer_prefix + "post_attention_layernorm.weight"
            ].data[:]

        # mlp.gate_proj.weight -> layernorm_mlp.fc1_weight[:intermediate_size]
        if layer_prefix + "mlp.gate_proj.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[
                : config.intermediate_size
            ] = hf_state_dict[layer_prefix + "mlp.gate_proj.weight"].data

        # mlp.up_proj.weight -> layernorm_mlp.fc1_weight[intermediate_size:]
        if layer_prefix + "mlp.up_proj.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[
                config.intermediate_size :
            ] = hf_state_dict[layer_prefix + "mlp.up_proj.weight"].data

        # mlp.down_proj.weight -> layernorm_mlp.fc2_weight
        if layer_prefix + "mlp.down_proj.weight" in hf_state_dict:
            te_state_dict[
                layer_prefix + "layernorm_mlp.fc2_weight"
            ].data[:] = hf_state_dict[layer_prefix + "mlp.down_proj.weight"].data[:]

    return all_layer_prefixes
I/O
| Direction | Name | Type | Description |
|---|---|---|---|
| Input | hf_state_dict | dict[str, torch.Tensor] | State dict loaded from a HuggingFace checkpoint shard. Contains HF-named parameter tensors. |
| Input | te_state_dict | dict[str, torch.Tensor] | State dict from the TE model (obtained via model.state_dict()). Modified in-place. |
| Input | config | LlamaConfig | HuggingFace LLaMA configuration. Used for config.intermediate_size to determine the split point in fc1_weight. |
| Output | return value | set[str] | Set of layer prefixes that were processed (e.g., {"model.layers.0.", "model.layers.1.", ...}). |
| Side Effect | te_state_dict modification | in-place | TE parameters are updated in-place via .data[:] = slice assignment. |
Weight Mapping Table
For each layer prefix model.layers.N.:
| HF Key Suffix | TE Key Suffix | Copy Method |
|---|---|---|
| input_layernorm.weight | self_attention.layernorm_qkv.layer_norm_weight | Direct copy (.data[:] =) |
| self_attn.q_proj.weight | self_attention.layernorm_qkv.query_weight | Direct copy |
| self_attn.k_proj.weight | self_attention.layernorm_qkv.key_weight | Direct copy |
| self_attn.v_proj.weight | self_attention.layernorm_qkv.value_weight | Direct copy |
| self_attn.o_proj.weight | self_attention.proj.weight | Direct copy |
| post_attention_layernorm.weight | layernorm_mlp.layer_norm_weight | Direct copy |
| mlp.gate_proj.weight | layernorm_mlp.fc1_weight[:intermediate_size] | Slice assignment (first half) |
| mlp.up_proj.weight | layernorm_mlp.fc1_weight[intermediate_size:] | Slice assignment (second half) |
| mlp.down_proj.weight | layernorm_mlp.fc2_weight | Direct copy |