Workflow: Accelerate HF Gemma With NVIDIA Transformer Engine (TE)
| Knowledge Sources | |
|---|---|
| Domains | LLMs, FP8_Inference, HuggingFace_Integration |
| Last Updated | 2026-02-07 21:00 GMT |
Overview
End-to-end process for accelerating HuggingFace Gemma models with Transformer Engine, enabling FP8 precision for both training and inference with KV cache support.
Description
This workflow demonstrates how to integrate NVIDIA Transformer Engine with HuggingFace's Gemma model for both training and autoregressive inference (text generation). It follows the same monkey-patching strategy as the LLaMA integration but adds inference-specific features: KV cache management through TE's InferenceParams, autoregressive token generation support, and compatibility with HuggingFace's generate() API. The implementation creates a TEGemmaDecoderLayer that inherits from te.TransformerLayer and a TEGemmaForCausalLM entry point for generation.
Key outputs:
- A HuggingFace-compatible Gemma model with TE-accelerated decoder layers
- FP8-accelerated autoregressive text generation with KV caching
- Weight loading from pretrained HuggingFace Gemma checkpoints
Usage
Execute this workflow when you need to accelerate HuggingFace Gemma model inference or training with FP8 precision, particularly when using the model.generate() API for text generation. This is suitable for both fine-tuning and serving Gemma models on NVIDIA GPUs with FP8 support.
Execution Steps
Step 1: Define TE Decoder Layer for Gemma
Create a TEGemmaDecoderLayer class that inherits from te.pytorch.TransformerLayer and maps Gemma's architecture configuration to TE parameters. Gemma uses RMSNorm, GeGLU activation, and has specific attention patterns that must be correctly configured. Initialize RotaryPositionEmbedding for positional encoding.
Key considerations:
- Gemma uses GeGLU activation (not SwiGLU) in its MLP
- Configure hidden_size, intermediate_size, and attention head counts from GemmaConfig
- Set up RoPE embeddings with the correct head dimension
- Handle Gemma-specific weight normalization in the embedding layer
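The configuration mapping can be sketched as a small translation function. This is an illustrative sketch, not the repository's actual code: the keyword names on the right-hand side are real `te.pytorch.TransformerLayer` constructor arguments, but the helper function, the attribute names read from the config object, and the example values (Gemma-7B-like) are assumptions for demonstration.

```python
from types import SimpleNamespace

def te_kwargs_from_gemma_config(config):
    """Translate HF GemmaConfig fields into te.pytorch.TransformerLayer kwargs."""
    return dict(
        hidden_size=config.hidden_size,
        ffn_hidden_size=config.intermediate_size,
        num_attention_heads=config.num_attention_heads,
        num_gqa_groups=config.num_key_value_heads,  # grouped-query attention
        kv_channels=config.head_dim,                # per-head dim, also used for RoPE
        layernorm_epsilon=config.rms_norm_eps,
        normalization="RMSNorm",                    # Gemma uses RMSNorm
        activation="geglu",                         # Gemma uses GeGLU, not SwiGLU
        self_attn_mask_type="causal",
    )

# Example with Gemma-7B-like values (illustrative):
cfg = SimpleNamespace(hidden_size=3072, intermediate_size=24576,
                      num_attention_heads=16, num_key_value_heads=16,
                      head_dim=256, rms_norm_eps=1e-6)
kwargs = te_kwargs_from_gemma_config(cfg)
```

The RoPE module would then be built once from the same head dimension (e.g. `te.pytorch.attention.RotaryPositionEmbedding(cfg.head_dim)`) and its frequencies passed to every layer's forward call.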
Step 2: Add Inference Support With KV Cache
Extend the decoder layer to support autoregressive inference by integrating TE's InferenceParams class. This manages the key-value cache across generation steps, storing computed K/V tensors from previous tokens to avoid redundant computation. Configure the cache for the prefill phase (processing the full prompt) and the decode phase (generating one token at a time).
What happens:
- During prefill: all prompt tokens are processed, K/V tensors are cached
- During decode: only the new token is processed, using cached K/V from previous steps
- InferenceParams tracks sequence offsets and cache allocation automatically
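The offset bookkeeping described above can be sketched in a few lines. This is a hypothetical stand-in class, not TE's `InferenceParams` (the real class also allocates and reuses the K/V cache tensors on the GPU); it only shows how the sequence offset distinguishes prefill from decode.

```python
# Hypothetical sketch of the bookkeeping InferenceParams performs.
class KVCacheState:
    def __init__(self, max_batch_size, max_sequence_length):
        self.max_batch_size = max_batch_size
        self.max_sequence_length = max_sequence_length
        self.sequence_len_offset = 0  # number of tokens already cached

    def step(self, num_new_tokens):
        """Advance the cache after processing num_new_tokens tokens."""
        assert self.sequence_len_offset + num_new_tokens <= self.max_sequence_length
        self.sequence_len_offset += num_new_tokens

cache = KVCacheState(max_batch_size=1, max_sequence_length=2048)
cache.step(17)  # prefill: the whole 17-token prompt in one pass
cache.step(1)   # decode: one new token per subsequent step
```

In the real integration, attention inside the TE layer reads the cached K/V up to `sequence_len_offset` and appends the newly computed K/V for the current tokens.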
Step 3: Create Model Entry Point for Generation
Build a TEGemmaForCausalLM class that wraps HuggingFace's GemmaForCausalLM and provides compatibility with the generate() API. This includes overriding the prepare_inputs_for_generation method to pass InferenceParams through the generation loop, managing cache initialization at the start of generation, and ensuring output format compatibility.
Key considerations:
- Override prepare_inputs_for_generation to inject InferenceParams
- Handle the transition from prefill to decode phase
- Maintain compatibility with HuggingFace's generation strategies (greedy, beam search, sampling)
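The core of the generate() glue is the input-slicing rule: once the cache holds tokens, only the newest token is fed forward. The sketch below is a simplified, hypothetical version; it uses nested lists in place of tensors (the real code would slice `input_ids[:, -1:]`), and the `inference_params` object is a stand-in with only the offset attribute.

```python
from types import SimpleNamespace

def prepare_inputs_for_generation(input_ids, inference_params):
    """input_ids has shape (batch, seq); lists stand in for tensors here."""
    if inference_params.sequence_len_offset > 0:
        # Decode phase: earlier tokens' K/V are cached, so pass only the
        # last token. (Tensor version: input_ids = input_ids[:, -1:])
        input_ids = [row[-1:] for row in input_ids]
    return {"input_ids": input_ids, "inference_params": inference_params}

# Prefill: offset 0, the full prompt goes through.
prefill = prepare_inputs_for_generation([[5, 7, 9]], SimpleNamespace(sequence_len_offset=0))
# Decode: 3 tokens cached, only the newest token is forwarded.
decode = prepare_inputs_for_generation([[5, 7, 9, 11]], SimpleNamespace(sequence_len_offset=3))
```

Returning `inference_params` in the model-kwargs dict is what threads the KV cache through HuggingFace's generation loop, since `generate()` forwards these kwargs to each model call.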
Step 4: Load Pretrained Gemma Weights
Map HuggingFace Gemma checkpoint weights to the TE layer parameter format. This involves translating parameter names between the HuggingFace naming convention and TE's fused layer naming convention, similar to the LLaMA weight mapping but adapted for Gemma-specific parameter organization.
Weight mapping:
- Gemma's attention Q/K/V projections map to TE's fused layernorm_qkv weights
- Gemma's MLP gate/up/down projections map to TE's fused layernorm_mlp weights
- Embedding layer weights may need special handling for Gemma's normalization factor
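A sketch of the name translation follows. The left-hand keys are the standard HF Gemma checkpoint names; the right-hand TE names follow the fused-layer pattern (input RMSNorm fused into `layernorm_qkv`, post-attention RMSNorm fused into `layernorm_mlp`), but the exact TE parameter names vary by TE version, so treat this table as illustrative.

```python
# Illustrative HF-suffix -> TE-suffix mapping; exact TE names are
# version-dependent assumptions, the fusion pattern is the point.
HF_TO_TE = {
    "input_layernorm.weight":          "self_attention.layernorm_qkv.layer_norm_weight",
    "self_attn.q_proj.weight":         "self_attention.layernorm_qkv.query_weight",
    "self_attn.k_proj.weight":         "self_attention.layernorm_qkv.key_weight",
    "self_attn.v_proj.weight":         "self_attention.layernorm_qkv.value_weight",
    "self_attn.o_proj.weight":         "self_attention.proj.weight",
    "post_attention_layernorm.weight": "layernorm_mlp.layer_norm_weight",
    "mlp.gate_proj.weight":            "layernorm_mlp.fc1_weight",  # gate half of GeGLU fc1
    "mlp.up_proj.weight":              "layernorm_mlp.fc1_weight",  # up half (concatenated)
    "mlp.down_proj.weight":            "layernorm_mlp.fc2_weight",
}

def translate_key(hf_key):
    """Map e.g. 'model.layers.0.self_attn.q_proj.weight' to its TE name."""
    for hf_suffix, te_suffix in HF_TO_TE.items():
        if hf_key.endswith(hf_suffix):
            return hf_key[: -len(hf_suffix)] + te_suffix
    return hf_key  # embeddings, final norm, lm_head pass through unchanged
```

Note that because GeGLU's gate and up projections share TE's `fc1_weight`, the actual copy step must concatenate the two HF tensors rather than assign them independently.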
Step 5: Run FP8 Inference or Training
Execute inference using model.generate() or run training with te.fp8_autocast for FP8 precision. For inference, the KV cache is automatically managed across generation steps. For training, the autocast context wraps the forward pass to enable FP8 computation on supported hardware.
Key considerations:
- For inference: KV cache is allocated once and reused across decode steps
- For training: use DelayedScaling or Float8CurrentScaling recipe
- The FP8 recipe configuration is the same as any other TE workflow
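A minimal usage fragment, assuming a `model` and `input_ids` already built per the earlier steps, an FP8-capable GPU (Hopper/Ada or newer), and Transformer Engine installed. The recipe values are illustrative defaults, not tuned settings.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Illustrative recipe: E4M3 forward / E5M2 backward with delayed scaling.
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID,
                            amax_history_len=16,
                            amax_compute_algo="max")

# Training: wrap the forward pass; backward runs outside the context.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    loss = model(input_ids, labels=input_ids).loss
loss.backward()

# Inference: generate() drives the TE layers; the KV cache held by
# InferenceParams is allocated once and reused across decode steps.
with torch.no_grad(), te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    output_ids = model.generate(input_ids, max_new_tokens=64)
```

Swapping `DelayedScaling` for `Float8CurrentScaling` changes only the recipe construction; the autocast wrapping is identical.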