Principle: NVIDIA TransformerEngine Gemma TE Decoder Layer
Overview
Adapting TransformerEngine's TransformerLayer for the Gemma model architecture with GeGLU activation and inference support.
Description
Gemma models use a specific configuration within the standard Transformer architecture: RMSNorm with zero-centered gamma, GeGLU activation (Gated GELU), configurable head_dim, and no biases. The TE decoder layer adapter maps GemmaConfig fields to TE TransformerLayer parameters and adds inference support (KV caching) via InferenceParams.
Unlike the Llama TE adapter, the Gemma version also includes CUDA graph support for the generation phase. The adapter is designed as a drop-in replacement for HuggingFace's GemmaDecoderLayer, allowing the rest of the HF model infrastructure to remain unchanged while benefiting from TE's fused kernels and FP8 acceleration.
The key configuration mapping from Gemma to TE includes:
| Gemma Parameter | TE TransformerLayer Parameter | Value |
|---|---|---|
| `config.hidden_size` | `hidden_size` | Model hidden dimension |
| `config.intermediate_size` | `ffn_hidden_size` | FFN intermediate dimension |
| `config.num_attention_heads` | `num_attention_heads` | Number of query heads |
| `config.num_key_value_heads` | `num_gqa_groups` | Number of KV heads (GQA) |
| `config.head_dim` | `kv_channels` | Per-head dimension |
| `config.rms_norm_eps` | `layernorm_epsilon` | RMSNorm epsilon |
| (hardcoded) | `bias` | `False` |
| (hardcoded) | `normalization` | `"RMSNorm"` |
| (hardcoded) | `activation` | `"geglu"` |
| (hardcoded) | `attn_input_format` | `"bshd"` |
| (hardcoded) | `zero_centered_gamma` | `True` |
| `config.fuse_qkv_params` | `fuse_qkv_params` | Whether to fuse QKV weights |
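The mapping above can be sketched as a small helper that assembles the keyword arguments for TE's `TransformerLayer` from a Gemma config object. The helper name `build_te_kwargs` and the plain-dict return value are illustrative assumptions; the actual adapter passes these values directly to `transformer_engine.pytorch.TransformerLayer`.

```python
def build_te_kwargs(config):
    """Sketch: map GemmaConfig fields to TE TransformerLayer parameters.

    `config` is assumed to expose the standard HuggingFace GemmaConfig
    attributes plus `fuse_qkv_params`.
    """
    return {
        # Direct config-to-parameter mappings
        "hidden_size": config.hidden_size,
        "ffn_hidden_size": config.intermediate_size,
        "num_attention_heads": config.num_attention_heads,
        "num_gqa_groups": config.num_key_value_heads,
        "kv_channels": config.head_dim,
        "layernorm_epsilon": config.rms_norm_eps,
        "fuse_qkv_params": config.fuse_qkv_params,
        # Hardcoded Gemma-specific architecture choices
        "bias": False,
        "normalization": "RMSNorm",
        "activation": "geglu",
        "attn_input_format": "bshd",
        "zero_centered_gamma": True,
    }
```

Keeping the hardcoded values out of the config makes the adapter's contract explicit: only the dimensional parameters vary per checkpoint, while the Gemma architecture choices are fixed.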
Theoretical Basis
The Gemma decoder layer follows the same core Transformer architecture as other decoder-only models but with several Gemma-specific design choices:
- GeGLU Activation: The feed-forward network uses Gated GELU (GeGLU) instead of standard GELU or SiLU. GeGLU splits the first linear projection into two halves, a gate and a value, applying GELU to the gate before element-wise multiplication with the value. This mirrors the SwiGLU gating used in Llama, but gates with GELU rather than SiLU.
- Zero-Centered Gamma in RMSNorm: Unlike standard RMSNorm where `gamma` is initialized to 1, Gemma initializes the `weight` parameter to 0 with the effective gamma computed as `1 + weight`. This is equivalent at initialization but provides different gradient dynamics during training, allowing the model to learn deviations from the identity normalization more easily.
- Configurable Head Dimension: While many models derive `head_dim` as `hidden_size / num_heads`, Gemma allows an independently configurable `head_dim`. This permits architectures where the per-head dimension does not evenly divide the hidden size when using grouped-query attention (GQA).
- No Biases: All linear projections (QKV, output projection, FFN layers) omit bias terms, following the trend in modern large language models toward bias-free architectures for improved training stability.
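Two of these choices, GeGLU and zero-centered RMSNorm, are compact enough to sketch in stdlib-only Python on plain lists. The helper names are illustrative, and the real implementations operate on batched tensors inside TE's fused kernels:

```python
import math

def gelu(x):
    # Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def geglu(projected):
    # GeGLU: split the first FFN projection's output into a gate half
    # and a value half, apply GELU to the gate, multiply element-wise.
    half = len(projected) // 2
    gate, value = projected[:half], projected[half:]
    return [gelu(g) * v for g, v in zip(gate, value)]

def zero_centered_rmsnorm(x, weight, eps=1e-6):
    # Gemma-style RMSNorm: `weight` is initialized to 0 and the
    # effective gamma is (1 + weight), so a freshly initialized layer
    # behaves like plain RMSNorm with gamma = 1.
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [(v / rms) * (1.0 + w) for v, w in zip(x, weight)]
```

Note that GeGLU halves the projection width: the FFN's first linear layer produces `2 * ffn_hidden_size` features, of which only `ffn_hidden_size` survive the gating.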
Usage
Use this principle when accelerating pretrained HuggingFace Gemma models with TransformerEngine FP8 for training or inference. The decoder layer adapter allows replacing the standard HF GemmaDecoderLayer with the TE-optimized version while preserving compatibility with the rest of the HF model code, including the generation pipeline and weight loading.
This is particularly useful when:
- Fine-tuning Gemma models with FP8 mixed precision for faster training
- Running Gemma inference with KV caching and optional CUDA graph capture
- Integrating Gemma into pipelines that leverage TE's fused attention and MLP kernels