Workflow:NVIDIA TransformerEngine Accelerate HF Llama With TE
| Knowledge Sources | |
|---|---|
| Domains | LLMs, FP8_Training, HuggingFace_Integration |
| Last Updated | 2026-02-07 21:00 GMT |
Overview
End-to-end process for accelerating HuggingFace LLaMA models by replacing their decoder layers with Transformer Engine's optimized TransformerLayer, enabling FP8 training and inference.
Description
This workflow demonstrates how to integrate NVIDIA Transformer Engine with HuggingFace's LLaMA model implementation. Rather than rewriting the entire model, it uses a monkey-patching strategy to replace the standard LlamaDecoderLayer with a TE-backed TELlamaDecoderLayer that inherits from te.TransformerLayer. This approach preserves full compatibility with the HuggingFace ecosystem (tokenizers, training loops, checkpointing) while gaining FP8 acceleration and fused kernel optimizations from TE.
Key outputs:
- A HuggingFace-compatible LLaMA model with TE-accelerated decoder layers
- Ability to load pretrained weights from HuggingFace checkpoints into the TE model
- FP8 training support with minimal code changes
Usage
Execute this workflow when you have a pretrained LLaMA model (LLaMA-2, LLaMA-3, or compatible variants) and want to fine-tune or continue training it with FP8 precision on NVIDIA Hopper/Ada/Blackwell GPUs, while maintaining compatibility with the HuggingFace Transformers library for data loading, tokenization, and model management.
Execution Steps
Step 1: Define TE Decoder Layer Wrapper
Create a TELlamaDecoderLayer class that inherits from te.TransformerLayer (te being the transformer_engine.pytorch module). Map the LLaMA model configuration parameters (hidden_size, intermediate_size, num_attention_heads, num_key_value_heads, rms_norm_eps) to their TE equivalents. Configure TE-specific options such as RMSNorm normalization, SwiGLU activation, BSHD attention format, and the number of Grouped Query Attention groups.
Key considerations:
- Set normalization to "RMSNorm" to match LLaMA architecture
- Set activation to "swiglu" for the gated MLP
- Set fuse_qkv_params to False since LLaMA uses separate Q, K, V projections in its checkpoints
- Configure num_gqa_groups for models with Grouped Query Attention
- Initialize RotaryPositionEmbedding with the correct head dimension
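The parameter mapping above can be sketched as a plain dictionary of constructor arguments. This is a minimal sketch, assuming the keyword names accepted by TransformerLayer (hidden_size, ffn_hidden_size, num_gqa_groups, layernorm_epsilon, and so on). The helper names te_layer_kwargs and rope_head_dim are hypothetical, bias=False is an assumption reflecting that LLaMA's linear layers carry no bias, and the configuration values are illustrative only.

```python
from types import SimpleNamespace


def te_layer_kwargs(config):
    """Translate a LlamaConfig-like object into TransformerLayer keyword arguments."""
    return {
        "hidden_size": config.hidden_size,
        "ffn_hidden_size": config.intermediate_size,
        "num_attention_heads": config.num_attention_heads,
        # Fewer key/value heads than query heads means Grouped Query Attention.
        "num_gqa_groups": config.num_key_value_heads,
        "layernorm_epsilon": config.rms_norm_eps,
        "normalization": "RMSNorm",
        "activation": "swiglu",
        "attn_input_format": "bshd",
        # LLaMA checkpoints store Q, K, and V as separate tensors.
        "fuse_qkv_params": False,
        # Assumption: LLaMA linear layers have no bias terms.
        "bias": False,
    }


def rope_head_dim(config):
    """Per-head dimension, the size the rotary position embedding should use."""
    return config.hidden_size // config.num_attention_heads


# Illustrative values in the shape of a 7B-class LLaMA configuration.
example_config = SimpleNamespace(
    hidden_size=4096,
    intermediate_size=11008,
    num_attention_heads=32,
    num_key_value_heads=32,
    rms_norm_eps=1e-5,
)
example_kwargs = te_layer_kwargs(example_config)
example_head_dim = rope_head_dim(example_config)
```

In a real wrapper, the dictionary would be unpacked into the parent constructor call inside TELlamaDecoderLayer.__init__, and the head dimension passed to RotaryPositionEmbedding.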
Step 2: Create Context Manager for Layer Replacement
Implement a context manager that temporarily monkey-patches the LlamaDecoderLayer class in the HuggingFace transformers module with the TE wrapper. This allows the standard HuggingFace model initialization code to create TE-backed layers instead of the original layers, without modifying the HuggingFace library itself.
What happens:
- The original LlamaDecoderLayer class reference is saved
- The class is replaced with TELlamaDecoderLayer in the transformers module
- After model initialization, the original class is restored
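The save, replace, and restore sequence above can be sketched with a generic context manager. This is a minimal sketch: replace_attr is a hypothetical helper name, and the namespace object below stands in for the real transformers.models.llama.modeling_llama module so the example runs without HuggingFace installed.

```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def replace_attr(module, name, replacement):
    """Temporarily rebind module.<name> to replacement, restoring it on exit."""
    original = getattr(module, name)
    setattr(module, name, replacement)
    try:
        yield
    finally:
        # Restore even if model construction raises inside the with-block.
        setattr(module, name, original)


# Stand-ins for the HuggingFace layer class and the TE-backed wrapper.
class OriginalLayer:
    pass


class ReplacementLayer:
    pass


# Stand-in for the modeling module that defines LlamaDecoderLayer.
fake_modeling = SimpleNamespace(LlamaDecoderLayer=OriginalLayer)

with replace_attr(fake_modeling, "LlamaDecoderLayer", ReplacementLayer):
    seen_inside = fake_modeling.LlamaDecoderLayer

seen_after = fake_modeling.LlamaDecoderLayer
```

The try/finally block is the important design choice: without it, an exception during model construction would leave the patched class in place for every later model created in the same process.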
Step 3: Initialize Model With TE Layers
Use the context manager to wrap the creation of LlamaForCausalLM. When the HuggingFace model initialization code creates decoder layers, it will instantiate TELlamaDecoderLayer instead of the standard LlamaDecoderLayer. This produces a complete LLaMA model with TE-optimized transformer blocks.
Key considerations:
- The model retains all non-decoder components (embeddings, final norm, LM head) from HuggingFace
- Only the transformer decoder layers are replaced with TE equivalents
- The model is fully compatible with HuggingFace's training and inference APIs
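One way to confirm that only the decoder layers were swapped is to inspect the layer classes after construction. The sketch below assumes the model exposes its decoder stack as model.model.layers, as LlamaForCausalLM does. decoder_layer_types is a hypothetical helper, and the stub objects stand in for a real model so the example runs on its own.

```python
from collections import Counter
from types import SimpleNamespace


def decoder_layer_types(model):
    """Count decoder layers by class name for a LlamaForCausalLM-like model."""
    return Counter(type(layer).__name__ for layer in model.model.layers)


class TELlamaDecoderLayer:
    """Stand-in for the TE-backed wrapper class."""


# Stub with the same attribute path as a real model: model.model.layers.
stub_model = SimpleNamespace(
    model=SimpleNamespace(layers=[TELlamaDecoderLayer() for _ in range(4)])
)
layer_counts = decoder_layer_types(stub_model)
```

On a real model built inside the replacement context, a single entry for the TE wrapper class would suggest that every decoder layer was replaced.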
Step 4: Load Pretrained Weights
Map and transfer pretrained weights from the HuggingFace checkpoint format to the TE layer format. This involves renaming and reorganizing weight tensors since TE's fused layers use different parameter names than HuggingFace's separate modules. Handle sharded checkpoints by iterating through each shard file and loading weights incrementally.
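For sharded checkpoints, the shard list can be read from the checkpoint's index file. The sketch below assumes the HuggingFace index layout, a JSON object whose weight_map entry maps each parameter name to the shard file that stores it. list_shard_files is a hypothetical helper, and the file names and index contents are made up for the demonstration.

```python
import json
import os
import tempfile


def list_shard_files(index_path):
    """Return the distinct shard file names referenced by a checkpoint index."""
    with open(index_path, "r", encoding="utf-8") as f:
        index = json.load(f)
    return sorted(set(index["weight_map"].values()))


# Made-up index in the HuggingFace layout, written to a temporary directory.
demo_index = {
    "metadata": {},
    "weight_map": {
        "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
        "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
        "model.layers.1.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
        "lm_head.weight": "model-00002-of-00002.safetensors",
    },
}

with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "model.safetensors.index.json")
    with open(demo_path, "w", encoding="utf-8") as f:
        json.dump(demo_index, f)
    shard_names = list_shard_files(demo_path)
```

Each returned shard can then be loaded in turn and its tensors copied into the TE model, so that only one shard needs to be held in memory at a time.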
Weight mapping:
- input_layernorm.weight maps to self_attention.layernorm_qkv.layer_norm_weight
- self_attn.q_proj.weight, self_attn.k_proj.weight, and self_attn.v_proj.weight map to self_attention.layernorm_qkv.query_weight, key_weight, and value_weight respectively
- self_attn.o_proj.weight maps to self_attention.proj.weight
- post_attention_layernorm.weight maps to layernorm_mlp.layer_norm_weight
- mlp.gate_proj.weight and mlp.up_proj.weight are concatenated into layernorm_mlp.fc1_weight, and mlp.down_proj.weight maps to layernorm_mlp.fc2_weight
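The renaming can be sketched as a pure function over parameter names. This is a minimal sketch, assuming per-layer keys of the form model.layers.<index>.<suffix> and assuming that gate_proj and up_proj share the single fc1_weight tensor. map_hf_key is a hypothetical helper that only computes names; copying or concatenating the tensors themselves is left to the caller.

```python
import re

# Matches per-layer keys such as "model.layers.3.self_attn.q_proj.weight".
_LAYER_KEY = re.compile(r"^(model\.layers\.\d+\.)(.+)$")

# One-to-one renames from HuggingFace suffixes to TE suffixes.
_RENAMES = {
    "input_layernorm.weight": "self_attention.layernorm_qkv.layer_norm_weight",
    "self_attn.q_proj.weight": "self_attention.layernorm_qkv.query_weight",
    "self_attn.k_proj.weight": "self_attention.layernorm_qkv.key_weight",
    "self_attn.v_proj.weight": "self_attention.layernorm_qkv.value_weight",
    "self_attn.o_proj.weight": "self_attention.proj.weight",
    "post_attention_layernorm.weight": "layernorm_mlp.layer_norm_weight",
    "mlp.down_proj.weight": "layernorm_mlp.fc2_weight",
}

# Assumption: both tensors land in fc1_weight; the label says which part.
_FC1_PARTS = {"mlp.gate_proj.weight": "gate", "mlp.up_proj.weight": "up"}


def map_hf_key(hf_key):
    """Return (te_key, fc1_part); fc1_part is "gate", "up", or None."""
    match = _LAYER_KEY.match(hf_key)
    if match is None:
        # Embeddings, final norm, and LM head keep their HuggingFace names.
        return hf_key, None
    prefix, suffix = match.groups()
    if suffix in _FC1_PARTS:
        return prefix + "layernorm_mlp.fc1_weight", _FC1_PARTS[suffix]
    if suffix in _RENAMES:
        return prefix + _RENAMES[suffix], None
    # Unknown per-layer keys are passed through unchanged.
    return hf_key, None


q_mapping = map_hf_key("model.layers.3.self_attn.q_proj.weight")
gate_mapping = map_hf_key("model.layers.3.mlp.gate_proj.weight")
head_mapping = map_hf_key("lm_head.weight")
```

Returning the part label separately keeps the function free of tensor logic; the caller decides the order in which the gate and up tensors are written into fc1_weight.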
Step 5: Enable FP8 Training
Wrap the training forward pass with te.autocast (named fp8_autocast in earlier Transformer Engine releases) to enable FP8 precision. Configure the FP8 recipe (DelayedScaling or Float8CurrentScaling) and run the standard HuggingFace training loop. The TE layers automatically handle FP8 quantization and dequantization of activations and weights during the forward and backward passes.
Key considerations:
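- Use the HYBRID format (E4M3 in the forward pass, E5M2 in the backward pass) for training stability: E4M3 gives activations and weights more precision, while E5M2 gives gradients more dynamic range
- The autocast context must wrap only the forward pass; run the backward pass and the optimizer step outside it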
- Compatible with HuggingFace Trainer and custom training loops
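The ordering constraint above can be sketched as a training step that takes the autocast context as a parameter. This is a minimal sketch: train_step and the stub classes are hypothetical, and a recording context manager replaces the TE autocast so the example runs without a GPU. The commented lines show how a real recipe might be wired in; the keyword names passed to te.autocast there are an assumption.

```python
from contextlib import contextmanager


def train_step(model, batch, optimizer, autocast_factory):
    """Run one step with only the forward pass inside the autocast context."""
    optimizer.zero_grad()
    with autocast_factory():
        # HuggingFace causal-LM models return an output carrying .loss
        # when the batch includes labels.
        loss = model(**batch).loss
    # Backward pass and optimizer step stay outside the context.
    loss.backward()
    optimizer.step()
    return loss


# With Transformer Engine installed, the factory might look like this
# (keyword names for te.autocast are assumed, not confirmed by this page):
#   from transformer_engine.common.recipe import DelayedScaling, Format
#   recipe = DelayedScaling(fp8_format=Format.HYBRID)
#   autocast_factory = lambda: te.autocast(enabled=True, recipe=recipe)

events = []  # records the order of operations for the demonstration


@contextmanager
def recording_autocast():
    events.append("enter_autocast")
    try:
        yield
    finally:
        events.append("exit_autocast")


class _StubLoss:
    def backward(self):
        events.append("backward")


class _StubOutput:
    loss = _StubLoss()


class _StubModel:
    def __call__(self, **batch):
        events.append("forward")
        return _StubOutput()


class _StubOptimizer:
    def zero_grad(self):
        events.append("zero_grad")

    def step(self):
        events.append("optimizer_step")


train_step(_StubModel(), {"input_ids": [1, 2, 3]}, _StubOptimizer(), recording_autocast)
```

Passing the context as a factory keeps the step function independent of the FP8 recipe, so the same loop can be run with contextlib.nullcontext to compare against a non-FP8 baseline.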