Principle:NVIDIA TransformerEngine Gemma TE Model Entry Point

From Leeroopedia


Overview

This principle covers building a complete Gemma model accelerated with TransformerEngine (TE), with support for FP8 training and autoregressive generation.

Description

The Gemma TE model entry point builds a full GemmaForCausalLM model where the standard HuggingFace decoder layers are replaced with TransformerEngine's TEGemmaDecoderLayer instances. The entry point adds several capabilities beyond what the base HF model provides:

  • Integrated InferenceParams: Manages KV caches for efficient autoregressive generation
  • RoPE Embeddings: A single RotaryPositionEmbedding instance shared across all layers, which is compatible with CUDA graph capture
  • Custom generate() Method: Implements a two-phase generation pipeline (context/prefill phase and generation/decode phase) with explicit management of the InferenceParams lifecycle
  • FP8 Recipe Selection: Hardware-adaptive FP8 recipe selection via get_default_fp8_recipe()
  • CUDA Graph Support: An optional subclass (TEGemmaForCausalLMCudaGraphs) that captures the two generation phases as CUDA graph callables for reduced CPU overhead

The model uses a monkey-patching approach: before the HF GemmaForCausalLM is initialized, the GemmaDecoderLayer class is temporarily replaced with TEGemmaDecoderLayer via a context manager. This means the HF model construction code runs as normal, but each decoder layer instance is actually a TE TransformerLayer.
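The temporary class swap described above can be sketched in plain Python. This is a minimal illustration of the pattern only, not the TransformerEngine code: `replace_attr`, `modeling`, `StockDecoderLayer`, `AcceleratedDecoderLayer`, and `build_model` are hypothetical stand-ins for the HF modeling module, its decoder layer class, the TE layer class, and the unmodified construction code.

```python
import types
from contextlib import contextmanager


@contextmanager
def replace_attr(module, name, replacement):
    """Temporarily rebind module.<name> to `replacement`; restore it on exit."""
    original = getattr(module, name)
    setattr(module, name, replacement)
    try:
        yield
    finally:
        # Restore even if model construction raises.
        setattr(module, name, original)


# Hypothetical stand-ins for the stock and accelerated decoder layer classes.
class StockDecoderLayer:
    kind = "stock"


class AcceleratedDecoderLayer:
    kind = "accelerated"


# Hypothetical stand-in for the module that holds the decoder layer class.
modeling = types.SimpleNamespace(DecoderLayer=StockDecoderLayer)


def build_model(num_layers):
    # Unmodified construction code: it looks the class up through the module
    # at call time, so it instantiates whatever is bound at that moment.
    return [modeling.DecoderLayer() for _ in range(num_layers)]


with replace_attr(modeling, "DecoderLayer", AcceleratedDecoderLayer):
    patched_layers = build_model(2)

unpatched_layers = build_model(2)

print([layer.kind for layer in patched_layers])
print([layer.kind for layer in unpatched_layers])
```

The swap only takes effect because the construction code resolves the class name at call time; code that captured a direct reference to the original class before the patch would be unaffected.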

Two wrapper classes orchestrate the forward pass for inference:

  • GemmaModelWrapper: Handles the context/prefill phase -- processes the full input sequence through all layers, applies final layer norm, and produces logits
  • GemmaGenerationWrapper: Handles the generation/decode phase -- processes a single token per sequence, selects the next token via argmax, and prepares the embedding for the next step

These wrappers own no parameters; they standardize buffer usage, attention masks, rotary embeddings, and KV cache flow for TE-optimized inference. Their separation enables independent CUDA graph capture.

Theoretical Basis

The entry point extends HuggingFace's GemmaForCausalLM by replacing its model.layers with TE TransformerLayer instances and adding a custom generate() method that explicitly manages the InferenceParams lifecycle.

The generation process follows:

  1. Context/Prefill Phase: The full input prompt is processed in a single forward pass. All K/V pairs are computed and cached in InferenceParams. The attention mask type is "padding_causal" to handle variable-length inputs within the batch.
  2. Generation/Decode Phase: Each iteration processes a single new token per sequence. Q, K, and V are computed only for that token; the new K/V are appended to the cache, and the new Q attends over the full cached K/V. The attention mask type changes to "padding".
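The two phases can be illustrated with a toy single-sequence example in pure Python. It is a sketch under strong simplifying assumptions: `KVCache`, `project`, `attend`, `prefill`, and `decode_step` are hypothetical names, the "projections" are trivial stand-ins, batching and attention masks are not modeled, and the prefill loop fills the cache token by token where the real model uses a single forward pass.

```python
import math

VOCAB = 8  # toy vocabulary size (assumption)


class KVCache:
    """Toy stand-in for the cache role played by InferenceParams."""

    def __init__(self):
        self.keys = []
        self.values = []

    def append(self, k, v):
        self.keys.append(k)
        self.values.append(v)

    def __len__(self):
        return len(self.keys)


def project(token):
    # Trivial stand-in for the Q/K/V projections of one token.
    q = float(token + 1)
    k = float(token + 1)
    v = [1.0 if i == (token + 1) % VOCAB else 0.0 for i in range(VOCAB)]
    return q, k, v


def attend(q, cache):
    # Softmax attention of a single query over every cached key/value pair.
    scores = [q * k for k in cache.keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [
        sum(w / total * v[i] for w, v in zip(weights, cache.values))
        for i in range(VOCAB)
    ]


def argmax(xs):
    return max(range(len(xs)), key=xs.__getitem__)


def prefill(prompt, cache):
    # Context phase: cache K/V for every prompt token, then predict from the
    # last position, which attends over the whole prompt.
    projected = [project(token) for token in prompt]
    for _, k, v in projected:
        cache.append(k, v)
    last_q = projected[-1][0]
    return argmax(attend(last_q, cache))


def decode_step(token, cache):
    # Decode phase: project only the new token, append its K/V to the cache,
    # and attend over everything cached so far.
    q, k, v = project(token)
    cache.append(k, v)
    return argmax(attend(q, cache))


def generate(prompt, max_new_tokens):
    cache = KVCache()
    token = prefill(prompt, cache)
    generated = [token]
    for _ in range(max_new_tokens - 1):
        token = decode_step(token, cache)
        generated.append(token)
    return generated, cache


tokens, cache = generate([1, 2], max_new_tokens=3)
print(tokens, len(cache))
```

Each decode step adds exactly one entry to the cache, so the per-step cost of the projections stays constant while the attention span grows with the sequence.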

The two-phase separation enables CUDA graph capture, which records the GPU kernel launch sequence once and replays it with minimal CPU overhead. This requires:

  • Static buffer shapes: Hidden state buffers and generation buffers have fixed dimensions
  • Static pointer addresses: Buffers are allocated once and reused via in-place operations (.copy_(), .data[:] =)
  • Shared RoPE: A single RotaryPositionEmbedding instance avoids per-layer allocation
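The static-address requirement can be illustrated without a GPU. In the sketch below, plain Python lists stand in for preallocated tensors and the hypothetical `record()` helper stands in for graph capture; it is not actual CUDA graph code. Slice assignment plays the role of `.copy_()`: it changes a buffer's contents while keeping the same object.

```python
def record(fn, static_input, static_output):
    """Return a replay() bound to these exact buffer objects (toy 'capture')."""

    def replay():
        # In-place write: static_output keeps its identity across replays.
        static_output[:] = fn(static_input)

    return replay


def double(xs):
    return [2 * x for x in xs]


# Buffers are allocated once, with a fixed length.
static_in = [0, 0, 0, 0]
static_out = [0, 0, 0, 0]
replay = record(double, static_in, static_out)


def run(batch):
    # Static shape: new data must fit the preallocated buffer exactly.
    assert len(batch) == len(static_in)
    static_in[:] = batch  # in-place copy, analogous to tensor.copy_()
    replay()
    return list(static_out)


first = run([1, 2, 3, 4])
second = run([5, 6, 7, 8])

# Counter-example: rebinding the name creates a NEW object, so the recorded
# callable keeps reading the old buffer and never sees the new data.
stale_in = [1, 1, 1, 1]
stale_out = [0, 0, 0, 0]
stale_replay = record(double, stale_in, stale_out)
stale_in = [9, 9, 9, 9]
stale_replay()

print(first, second, stale_out)
```

The counter-example shows the failure mode that in-place updates avoid: the replayed work is tied to the buffers that existed at capture time, not to whatever a variable name points to later.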

Usage

Use this as the main entry point for TE-accelerated Gemma training and inference:

  • For training/fine-tuning: Create a TEGemmaForCausalLM instance and use its forward() method with te.pytorch.autocast for FP8 mixed precision
  • For inference without CUDA graphs: Use TEGemmaForCausalLM.generate() with input token IDs and max_new_tokens
  • For inference with CUDA graphs: Use TEGemmaForCausalLMCudaGraphs, call record() once to capture the graph, then use generate() for subsequent batches

The typical initialization flow:

  1. Load the model configuration with additional TE-specific fields (e.g., fp8, fuse_qkv_params, is_paged)
  2. Create the model with load_te_model(TEGemmaForCausalLM, config)
  3. For CUDA graphs: call model.record()
  4. Run generation with model.generate(input_ids, max_new_tokens=N)
