Principle:NVIDIA TransformerEngine Gemma TE Model Entry Point
Overview
Creating a complete TE-accelerated Gemma model with FP8 training and autoregressive generation capabilities.
Description
The Gemma TE model entry point builds a full GemmaForCausalLM model where the standard HuggingFace decoder layers are replaced with TransformerEngine's TEGemmaDecoderLayer instances. The entry point adds several capabilities beyond what the base HF model provides:
- Integrated InferenceParams: Manages KV caches for efficient autoregressive generation
- RoPE Embeddings: A single RotaryPositionEmbedding instance shared across all layers, which is compatible with CUDA graph capture
- Custom generate() Method: Implements a two-phase generation pipeline (context/prefill phase and generation/decode phase) with explicit management of the InferenceParams lifecycle
- FP8 Recipe Selection: Hardware-adaptive FP8 recipe selection via get_default_fp8_recipe()
- CUDA Graph Support: An optional subclass (TEGemmaForCausalLMCudaGraphs) that captures the two generation phases as CUDA graph callables for reduced CPU overhead
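As a rough illustration of hardware-adaptive selection, the sketch below maps a GPU compute capability to the name of a TransformerEngine recipe class. The mapping is an assumption based on hardware FP8 support (FP8 tensor cores from compute capability 8.9, MXFP8 on Blackwell-class GPUs), not the actual logic of get_default_fp8_recipe(). select_fp8_recipe_name is a hypothetical helper that returns strings so it runs without a GPU.

```python
from typing import Tuple


def select_fp8_recipe_name(compute_capability: Tuple[int, int]) -> str:
    """Hypothetical hardware-adaptive FP8 recipe choice.

    Returns the name of a recipe class instead of constructing it, so the
    sketch runs anywhere. The thresholds are assumptions, not the real
    get_default_fp8_recipe() implementation.
    """
    major, minor = compute_capability
    if major >= 10:
        # Blackwell-class GPUs add hardware support for MXFP8 block scaling.
        return "MXFP8BlockScaling"
    if (major, minor) >= (8, 9):
        # Ada (8.9) and Hopper (9.0) support FP8 with per-tensor scaling.
        return "DelayedScaling"
    raise ValueError(f"FP8 is not supported on compute capability {major}.{minor}")


print(select_fp8_recipe_name((9, 0)))
```

In real code the capability tuple would come from torch.cuda.get_device_capability(), and the chosen name would correspond to a recipe object such as DelayedScaling or MXFP8BlockScaling from transformer_engine.common.recipe.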
The model uses a monkey-patching approach: before the HF GemmaForCausalLM is initialized, the GemmaDecoderLayer class is temporarily replaced with TEGemmaDecoderLayer via a context manager. This means the HF model construction code runs as normal, but each decoder layer instance is actually a TE TransformerLayer.
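The swap works because the HF modeling code looks up GemmaDecoderLayer in its module namespace when the model is constructed. A minimal sketch of the pattern follows; the layer classes, ToyGemmaModel, and the SimpleNamespace standing in for transformers.models.gemma.modeling_gemma are all hypothetical, so it runs without either library.

```python
import types
from contextlib import contextmanager

# Hypothetical stand-in for the HF modeling module.
modeling = types.SimpleNamespace()


class GemmaDecoderLayer:  # stand-in for the HF decoder layer
    def __init__(self, layer_idx: int):
        self.layer_idx = layer_idx


class TEGemmaDecoderLayer:  # stand-in for the TE replacement
    def __init__(self, layer_idx: int):
        self.layer_idx = layer_idx


modeling.GemmaDecoderLayer = GemmaDecoderLayer


class ToyGemmaModel:
    """Resolves the layer class on the module at construction time, as HF code does."""

    def __init__(self, num_layers: int):
        self.layers = [modeling.GemmaDecoderLayer(i) for i in range(num_layers)]


@contextmanager
def replace_decoder(module, new_cls):
    """Temporarily swap module.GemmaDecoderLayer for new_cls."""
    original = module.GemmaDecoderLayer
    module.GemmaDecoderLayer = new_cls
    try:
        yield
    finally:
        # Restore the original class even if model construction raises.
        module.GemmaDecoderLayer = original


with replace_decoder(modeling, TEGemmaDecoderLayer):
    patched = ToyGemmaModel(num_layers=2)    # built with TE layers
unpatched = ToyGemmaModel(num_layers=2)      # patch already undone
```

Restoring in a finally clause keeps the patch scoped to model construction, so other code that later instantiates the HF model still gets the original layer class.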
Two wrapper classes orchestrate the forward pass for inference:
- GemmaModelWrapper: Handles the context/prefill phase. It processes the full input sequence through all layers, applies the final layer norm, and produces logits
- GemmaGenerationWrapper: Handles the generation/decode phase. It processes a single token per sequence, selects the next token via argmax, and prepares the embedding for the next step
These wrappers own no parameters; they standardize buffer usage, attention masks, rotary embeddings, and KV cache flow for TE-optimized inference. Their separation enables independent CUDA graph capture.
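The division of labor between the two wrappers can be shown with a toy, library-free sketch of greedy generation. Everything in it is hypothetical: a one-dimensional "hidden state", a stand-in for the decoder stack and LM head that always favors token (t + 1) mod 5, and no KV cache. It only illustrates the control flow: one prefill call over the whole prompt, then repeated single-token steps that pick the next token by argmax and write its embedding into a preallocated buffer in place.

```python
from typing import List

VOCAB = 5  # toy vocabulary size (hypothetical)


def embed(token: int) -> List[float]:
    # Toy embedding: a one-dimensional hidden state equal to the token id.
    return [float(token)]


def toy_stack_and_head(hidden: List[float]) -> List[float]:
    # Stand-in for decoder layers + final norm + LM head: puts the highest
    # logit on (token + 1) % VOCAB and ignores all other context.
    target = (int(hidden[0]) + 1) % VOCAB
    return [1.0 if v == target else 0.0 for v in range(VOCAB)]


def argmax(xs: List[float]) -> int:
    return max(range(len(xs)), key=xs.__getitem__)


class ToyModelWrapper:
    """Prefill: run every prompt position, return the last position's logits."""

    def __call__(self, prompt_hidden: List[List[float]]) -> List[float]:
        return [toy_stack_and_head(h) for h in prompt_hidden][-1]


class ToyGenerationWrapper:
    """Decode: one token per step; the hidden buffer is reused in place."""

    def __init__(self, static_hidden: List[float]):
        self.static_hidden = static_hidden

    def __call__(self) -> int:
        next_token = argmax(toy_stack_and_head(self.static_hidden))
        self.static_hidden[:] = embed(next_token)  # same object, new contents
        return next_token


def generate(prompt: List[int], max_new_tokens: int) -> List[int]:
    if max_new_tokens <= 0:
        return []
    # Phase 1: context/prefill over the full prompt.
    first = argmax(ToyModelWrapper()([embed(t) for t in prompt]))
    out = [first]
    # Phase 2: generation/decode, one token per iteration.
    step = ToyGenerationWrapper(embed(first))
    for _ in range(max_new_tokens - 1):
        out.append(step())
    return out


print(generate([0, 3], max_new_tokens=4))
```

Keeping the decode step's input in one reused buffer is what lets that step be captured once as a graph and replayed.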
Theoretical Basis
The entry point extends HuggingFace's GemmaForCausalLM by replacing its model.layers with TE TransformerLayer instances and adding a custom generate() method that explicitly manages the InferenceParams lifecycle.
The generation process follows:
- Context/Prefill Phase: The full input prompt is processed in a single forward pass. K/V pairs for every prompt token are computed and cached in InferenceParams. The attention mask type is "padding_causal" to handle variable-length inputs within the batch.
- Generation/Decode Phase: Each iteration processes a single new token per sequence. Q, K, and V are computed only for that token; its K/V are appended to the cache, and attention is computed against the full cached K/V. The attention mask type changes to "padding": no causal mask is needed because every cached key precedes the new query, but padded positions must still be masked out.
The two-phase separation enables CUDA graph capture, which records the GPU kernel launch sequence once and replays it with minimal CPU overhead. This requires:
- Static buffer shapes: Hidden state buffers and generation buffers have fixed dimensions
- Static pointer addresses: Buffers are allocated once and reused via in-place operations (.copy_(), .data[:] =)
- Shared RoPE: A single RotaryPositionEmbedding instance avoids per-layer allocation
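A single shared RoPE instance amounts to computing the rotation angles once and giving every layer a reference to the same table. The sketch below is a hypothetical pure-Python stand-in for RotaryPositionEmbedding (SharedRoPE and ToyLayer are illustrative names). It assumes the "rotate half" pairing (element i with element i + dim/2) and the common base of 10000. The assertions check the property RoPE relies on: the query-key dot product depends only on the relative offset between positions.

```python
import math
from typing import List


class SharedRoPE:
    """Precomputes rotation angles once; every layer reuses the same table."""

    def __init__(self, dim: int, max_positions: int, base: float = 10000.0):
        assert dim % 2 == 0, "rotary dimension must be even"
        self.dim = dim
        inv_freq = [base ** (-2.0 * i / dim) for i in range(dim // 2)]
        # angles[p][i] = p * inv_freq[i]; built once, never reallocated.
        self.angles: List[List[float]] = [
            [p * f for f in inv_freq] for p in range(max_positions)
        ]

    def apply(self, x: List[float], pos: int) -> List[float]:
        # Rotate each pair (x[i], x[i + dim/2]) by the angle for this position.
        half = self.dim // 2
        out = [0.0] * self.dim
        for i, theta in enumerate(self.angles[pos]):
            c, s = math.cos(theta), math.sin(theta)
            a, b = x[i], x[i + half]
            out[i] = a * c - b * s
            out[i + half] = a * s + b * c
        return out


class ToyLayer:
    def __init__(self, rope: SharedRoPE):
        self.rope = rope  # a reference to the shared table, not a copy


rope = SharedRoPE(dim=4, max_positions=16)
layers = [ToyLayer(rope) for _ in range(3)]
```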
Usage
Use this as the main entry point for TE-accelerated Gemma training and inference:
- For training/fine-tuning: Create a TEGemmaForCausalLM instance and use its forward() method with te.pytorch.autocast for FP8 mixed precision
- For inference without CUDA graphs: Use TEGemmaForCausalLM.generate() with input token IDs and max_new_tokens
- For inference with CUDA graphs: Use TEGemmaForCausalLMCudaGraphs, call record() once to capture the graph, then use generate() for subsequent batches
The typical initialization flow:
- Load the model configuration with additional TE-specific fields (e.g., fp8, fuse_qkv_params, is_paged)
- Create the model with load_te_model(TEGemmaForCausalLM, config)
- For CUDA graphs: call model.record()
- Run generation with model.generate(input_ids, max_new_tokens=N)