

Workflow:NVIDIA TransformerEngine Accelerate HF Gemma With TE

From Leeroopedia


Knowledge Sources
Domains LLMs, FP8_Inference, HuggingFace_Integration
Last Updated 2026-02-07 21:00 GMT

Overview

End-to-end process for accelerating HuggingFace Gemma models with Transformer Engine, enabling FP8 precision for both training and inference with KV cache support.

Description

This workflow demonstrates how to integrate NVIDIA Transformer Engine with HuggingFace's Gemma model for both training and autoregressive inference (text generation). It follows the same monkey-patching strategy as the LLaMA integration but adds inference-specific features: KV cache management through TE's InferenceParams, autoregressive token generation support, and compatibility with HuggingFace's generate() API. The implementation creates a TEGemmaDecoderLayer that inherits from te.pytorch.TransformerLayer and a TEGemmaForCausalLM entry point for generation.

Key outputs:

  • A HuggingFace-compatible Gemma model with TE-accelerated decoder layers
  • FP8-accelerated autoregressive text generation with KV caching
  • Weight loading from pretrained HuggingFace Gemma checkpoints

Usage

Execute this workflow when you need to accelerate HuggingFace Gemma model inference or training with FP8 precision, particularly when using the model.generate() API for text generation. This is suitable for both fine-tuning and serving Gemma models on NVIDIA GPUs with FP8 support.

Execution Steps

Step 1: Define TE Decoder Layer for Gemma

Create a TEGemmaDecoderLayer class that inherits from te.pytorch.TransformerLayer and maps Gemma's architecture configuration to TE parameters. Gemma uses RMSNorm, GeGLU activation, and has specific attention patterns that must be correctly configured. Initialize RotaryPositionEmbedding for positional encoding.

Key considerations:

  • Gemma uses GeGLU activation (not SwiGLU) in its MLP
  • Configure hidden_size, intermediate_size, and attention head counts from GemmaConfig
  • Set up RoPE embeddings with the correct head dimension
  • Handle Gemma-specific weight normalization in the embedding layer
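The configuration mapping described above can be sketched as a plain function. This is a minimal, dependency-free illustration: the keyword names follow te.pytorch.TransformerLayer's signature as an assumption to check against your installed TE version, and the Gemma-2B-like numbers at the bottom are illustrative only.

```python
from types import SimpleNamespace

def gemma_config_to_te_kwargs(config):
    """Map HF GemmaConfig fields onto te.pytorch.TransformerLayer kwargs.

    Keyword names are an assumption based on TE's TransformerLayer API;
    verify them against the TE version you are using.
    """
    return {
        "hidden_size": config.hidden_size,
        "ffn_hidden_size": config.intermediate_size,
        "num_attention_heads": config.num_attention_heads,
        "num_gqa_groups": config.num_key_value_heads,  # Gemma uses grouped/multi-query attention
        "layernorm_epsilon": config.rms_norm_eps,
        "activation": "geglu",         # Gemma's MLP uses GeGLU, not SwiGLU
        "normalization": "RMSNorm",    # Gemma uses RMSNorm throughout
        "kv_channels": config.head_dim,  # RoPE is applied per head at this width
        "self_attn_mask_type": "causal",
    }

# Illustrative Gemma-2B-like configuration (not an official checkpoint config):
cfg = SimpleNamespace(hidden_size=2048, intermediate_size=16384,
                      num_attention_heads=8, num_key_value_heads=1,
                      rms_norm_eps=1e-6, head_dim=256)
kwargs = gemma_config_to_te_kwargs(cfg)
```

In the actual integration these kwargs would be passed to `te.pytorch.TransformerLayer.__init__` inside TEGemmaDecoderLayer, with a RotaryPositionEmbedding of dimension `config.head_dim` constructed alongside.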

Step 2: Add Inference Support With KV Cache

Extend the decoder layer to support autoregressive inference by integrating TE's InferenceParams class. This manages the key-value cache across generation steps, storing computed K/V tensors from previous tokens to avoid redundant computation. Configure the cache for the prefill phase (processing the full prompt) and the decode phase (generating one token at a time).

What happens:

  • During prefill: all prompt tokens are processed, K/V tensors are cached
  • During decode: only the new token is processed, using cached K/V from previous steps
  • InferenceParams tracks sequence offsets and cache allocation automatically
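The prefill/decode bookkeeping can be illustrated with a toy stand-in for InferenceParams. This sketch uses plain Python lists so it stays dependency-free; the real class preallocates GPU tensors sized by maximum batch and sequence length, but the offset arithmetic is the same idea.

```python
class SimpleKVCache:
    """Toy stand-in for TE's InferenceParams: a bounded K/V store plus a
    sequence offset, illustrating prefill vs. decode bookkeeping."""

    def __init__(self, max_seq_len):
        self.max_seq_len = max_seq_len
        self.k, self.v = [], []
        self.seq_offset = 0  # number of tokens already cached

    def append(self, k_new, v_new):
        # New entries must fit inside the preallocated window.
        assert self.seq_offset + len(k_new) <= self.max_seq_len, "cache overflow"
        self.k.extend(k_new)
        self.v.extend(v_new)
        self.seq_offset += len(k_new)

cache = SimpleKVCache(max_seq_len=16)
# Prefill: all 5 prompt tokens are processed in one pass and their K/V cached.
cache.append([f"k{i}" for i in range(5)], [f"v{i}" for i in range(5)])
# Decode: one new token per step, attending to the cached prefix.
cache.append(["k5"], ["v5"])
```

After these two calls the cache holds six K/V entries and `seq_offset` is 6, which is exactly the position the next decode step would write to.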

Step 3: Create Model Entry Point for Generation

Build a TEGemmaForCausalLM class that wraps HuggingFace's GemmaForCausalLM and provides compatibility with the generate() API. This includes overriding the prepare_inputs_for_generation method to pass InferenceParams through the generation loop, managing cache initialization at the start of generation, and ensuring output format compatibility.

Key considerations:

  • Override prepare_inputs_for_generation to inject InferenceParams
  • Handle the transition from prefill to decode phase
  • Maintain compatibility with HuggingFace's generation strategies (greedy, beam search, sampling)
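The core of the prepare_inputs_for_generation override is the prefill-to-decode transition. The sketch below shows only that logic, with a SimpleNamespace standing in for InferenceParams; the attribute name `sequence_len_offset` and the returned dict shape are assumptions to verify against TE and HuggingFace's GenerationMixin contract.

```python
from types import SimpleNamespace

def prepare_inputs_for_generation(input_ids, inference_params):
    """Hypothetical sketch of the override: once the cache already holds
    the prompt (offset > 0), only the newest token needs a forward pass."""
    if inference_params.sequence_len_offset > 0:  # decode phase
        input_ids = input_ids[-1:]
    # The cache object is threaded through so every decoder layer sees it.
    return {"input_ids": input_ids, "inference_params": inference_params}

# Prefill call: the full prompt goes through the model.
prefill = prepare_inputs_for_generation([5, 9, 2],
                                        SimpleNamespace(sequence_len_offset=0))
# Decode call: three prompt tokens are cached; only token 7 is new.
decode = prepare_inputs_for_generation([5, 9, 2, 7],
                                       SimpleNamespace(sequence_len_offset=3))
```

In the real TEGemmaForCausalLM the same dict would also carry attention masks and position information, and the cache would be allocated once at the start of generate().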

Step 4: Load Pretrained Gemma Weights

Map HuggingFace Gemma checkpoint weights to the TE layer parameter format. This involves translating parameter names between the HuggingFace naming convention and TE's fused layer naming convention, similar to the LLaMA weight mapping but adapted for Gemma-specific parameter organization.

Weight mapping:

  • Gemma's attention Q/K/V projections map to TE's fused layernorm_qkv weights
  • Gemma's MLP gate/up/down projections map to TE's fused layernorm_mlp weights
  • Embedding layer weights may need special handling for Gemma's normalization factor
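A name-translation table captures the mapping above. The TE-side parameter names below are hypothetical, modeled on TE's fused layernorm_qkv/layernorm_mlp naming; the exact strings depend on the TE version and on how the fused weights are laid out, so treat this as a template rather than a drop-in.

```python
# Hypothetical suffix map (HF Gemma name -> TE fused-layer name).
# Verify the right-hand names against your TE version's state_dict.
HF_TO_TE_SUFFIX = {
    "input_layernorm.weight": "self_attention.layernorm_qkv.layer_norm_weight",
    "self_attn.q_proj.weight": "self_attention.layernorm_qkv.query_weight",
    "self_attn.k_proj.weight": "self_attention.layernorm_qkv.key_weight",
    "self_attn.v_proj.weight": "self_attention.layernorm_qkv.value_weight",
    "self_attn.o_proj.weight": "self_attention.proj.weight",
    "post_attention_layernorm.weight": "layernorm_mlp.layer_norm_weight",
    "mlp.gate_proj.weight": "layernorm_mlp.fc1_weight",  # gate half of the GeGLU fc1
    "mlp.up_proj.weight": "layernorm_mlp.fc1_weight",    # up half (copied into the other slice)
    "mlp.down_proj.weight": "layernorm_mlp.fc2_weight",
}

def rename_param(hf_name):
    """Translate one HF checkpoint key to its TE counterpart."""
    for hf_suffix, te_suffix in HF_TO_TE_SUFFIX.items():
        if hf_name.endswith(hf_suffix):
            return hf_name[: -len(hf_suffix)] + te_suffix
    return hf_name  # embeddings, lm_head, final norm: pass through unchanged

te_name = rename_param("model.layers.0.self_attn.q_proj.weight")
```

Note that gate_proj and up_proj both target the fused fc1 weight: in the real loader they are copied into disjoint row slices of one tensor rather than renamed one-to-one, and Gemma's embedding normalization factor is applied separately.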

Step 5: Run FP8 Inference or Training

Execute inference using model.generate() or run training with te.pytorch.fp8_autocast for FP8 precision. For inference, the KV cache is automatically managed across generation steps. For training, the fp8_autocast context wraps the forward pass to enable FP8 computation on supported hardware.

Key considerations:

  • For inference: KV cache is allocated once and reused across decode steps
  • For training: use DelayedScaling or Float8CurrentScaling recipe
  • The FP8 recipe configuration is the same as any other TE workflow
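A minimal usage sketch of both modes follows. It assumes `model` is an already-constructed TEGemmaForCausalLM and `input_ids` a tokenized prompt on a GPU with FP8 support (Hopper or later); neither is defined here, and the recipe values shown are illustrative defaults, not tuned settings.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# FP8 recipe: configured the same way as in any other TE workflow.
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16)

# Inference: the KV cache is allocated once by the model entry point and
# reused across every decode step inside generate().
with torch.no_grad(), te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    output_ids = model.generate(input_ids, max_new_tokens=64)

# Training: wrap only the forward pass; backward runs outside the context.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    loss = model(input_ids, labels=input_ids).loss
loss.backward()
```

Swapping DelayedScaling for Float8CurrentScaling changes only the recipe construction; the autocast usage is identical.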

Execution Diagram

GitHub URL

Workflow Repository