Implementation: NVIDIA TransformerEngine TEGemmaDecoderLayer
Overview
TE-accelerated replacement for HuggingFace's GemmaDecoderLayer using TransformerEngine's TransformerLayer.
Description
TEGemmaDecoderLayer subclasses te.pytorch.TransformerLayer to serve as a drop-in replacement for HuggingFace's GemmaDecoderLayer within the Gemma model. It maps GemmaConfig fields to TE TransformerLayer constructor parameters with Gemma-specific settings: zero_centered_gamma=True, activation="geglu", bias=False, normalization="RMSNorm", and attn_input_format="bshd".
The class tracks its layer index (layer_idx) and increments it by 1 when passing it to TE, since TE layer numbering starts from 1 rather than 0 as in HF. The forward() method filters out HuggingFace-specific keyword arguments that TE's TransformerLayer does not accept, and passes through rotary_pos_emb for RoPE support as well as inference_params for KV cache management.
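The kwarg-filtering and tuple-wrapping behavior described above can be sketched independently of TransformerEngine. The classes below are stand-ins (hypothetical names), not the real TE or HF APIs; a real TE `TransformerLayer` would raise on the unfiltered HF kwargs:

```python
# HF-specific kwargs that TE's TransformerLayer does not accept.
HF_ONLY_KWARGS = ("position_ids", "past_key_value", "output_attentions",
                  "use_cache", "cache_position")

class FakeTransformerLayer:
    """Stand-in for te.pytorch.TransformerLayer (hypothetical)."""
    def forward(self, hidden_states, rotary_pos_emb=None, **kwargs):
        # A real TE layer would reject the HF-only kwargs above.
        assert not set(HF_ONLY_KWARGS) & set(kwargs), "unfiltered HF kwargs"
        return hidden_states

class FakeDecoderLayer(FakeTransformerLayer):
    """Stand-in for TEGemmaDecoderLayer's forward() pattern."""
    def forward(self, *args, **kwargs):
        # Drop HF-specific kwargs that TE does not understand.
        for key in HF_ONLY_KWARGS:
            kwargs.pop(key, None)
        # Rename rope_emb to TE's rotary_pos_emb and wrap the result
        # in a tuple, matching what HF's GemmaModel expects.
        rope_emb = kwargs.pop("rope_emb", None)
        return (super().forward(*args, rotary_pos_emb=rope_emb, **kwargs),)

layer = FakeDecoderLayer()
out = layer.forward([1.0, 2.0], rope_emb=None, use_cache=True, position_ids=[0, 1])
# out == ([1.0, 2.0],)
```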
This is a Wrapper Doc.
Source
docs/examples/te_gemma/te_gemma.py, class TEGemmaDecoderLayer at lines 139-191.
Signature
```python
class TEGemmaDecoderLayer(te.pytorch.TransformerLayer):
    def __init__(self, config: GemmaConfig, layer_idx: int, *args, **kwargs):
        self.gemma_config = config
        super().__init__(
            hidden_size=config.hidden_size,
            ffn_hidden_size=config.intermediate_size,
            num_attention_heads=config.num_attention_heads,
            bias=False,
            layernorm_epsilon=config.rms_norm_eps,
            hidden_dropout=0,
            attention_dropout=0,
            fuse_qkv_params=config.fuse_qkv_params,
            normalization="RMSNorm",
            activation="geglu",
            attn_input_format="bshd",
            num_gqa_groups=config.num_key_value_heads,
            kv_channels=self.gemma_config.head_dim,
            layer_number=layer_idx + 1,  # Layer numbers in TE start from 1, not 0 as in HF.
            zero_centered_gamma=True,
        )

    def forward(self, *args, **kwargs):
        # Filter out HF-specific kwargs not accepted by TE's TransformerLayer.
        for key in ("position_ids", "past_key_value", "output_attentions",
                    "use_cache", "cache_position"):
            kwargs.pop(key, None)
        # Extract rope_emb and delegate to TransformerLayer.forward;
        # wrap the output in a tuple (output,) for HF compatibility.
        rope_emb = kwargs.pop("rope_emb", None)
        return (super().forward(*args, rotary_pos_emb=rope_emb, **kwargs),)
```
I/O
Input:
- `config: GemmaConfig` -- HuggingFace Gemma configuration object containing model hyperparameters
- `layer_idx: int` -- Zero-based layer index (converted to 1-based for TE internally)
- `*args, **kwargs` -- Additional positional and keyword arguments for HF compatibility
Forward Input:
- `hidden_states: torch.Tensor` -- Input hidden states tensor
- `rope_emb: torch.Tensor` (optional, via kwargs) -- Rotary position embedding tensor
- `inference_params: InferenceParams` (optional, via kwargs) -- KV cache manager for inference
- `attention_mask: torch.Tensor` (optional, via kwargs) -- Attention mask tensor
- `self_attn_mask_type: str` (optional, via kwargs) -- Mask type, e.g. `"padding_causal"`
Output:
- `tuple` containing a single `torch.Tensor` -- the layer output wrapped in a tuple for HF compatibility
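The tuple wrapping matters because HuggingFace's model code takes element `[0]` of each layer's return value while iterating over layers. A minimal sketch of that consumer side, using a stand-in layer class (hypothetical, not HF's actual code):

```python
class TupleReturningLayer:
    """Stand-in layer that returns (output,) like TEGemmaDecoderLayer.forward."""
    def __init__(self, offset):
        self.offset = offset

    def __call__(self, hidden_states):
        # Return a one-element tuple, as TEGemmaDecoderLayer.forward does.
        return (hidden_states + self.offset,)

layers = [TupleReturningLayer(1), TupleReturningLayer(2)]
hidden = 0
for layer in layers:
    layer_outputs = layer(hidden)
    hidden = layer_outputs[0]  # HF-style loops unpack the tuple like this
# hidden == 3
```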
Key Parameters
| Parameter | Type | Description |
|---|---|---|
| `hidden_size` | `int` | Model hidden dimension from `config.hidden_size` |
| `ffn_hidden_size` | `int` | FFN intermediate dimension from `config.intermediate_size` |
| `num_attention_heads` | `int` | Number of query attention heads |
| `num_gqa_groups` | `int` | Number of KV heads for grouped-query attention |
| `kv_channels` | `int` | Per-head dimension from `config.head_dim` |
| `zero_centered_gamma` | `bool` | `True`; RMSNorm weight initialized to 0, effective gamma = 1 + weight |
| `activation` | `str` | `"geglu"`; gated GELU activation for the FFN |
| `fuse_qkv_params` | `bool` | Whether to fuse QKV projections into a single parameter |
Notes
- The forward method filters out HF-specific keyword arguments (`position_ids`, `past_key_value`, `output_attentions`, `use_cache`, `cache_position`) that are not used by TE's `TransformerLayer`.
- The output is wrapped in a tuple `(output,)` to match the return signature expected by HuggingFace's `GemmaModel` layer iteration.
- Layer numbering is adjusted from 0-based (HF convention) to 1-based (TE convention) via `layer_number = layer_idx + 1`.
- Both `hidden_dropout` and `attention_dropout` are set to 0, consistent with Gemma's architecture.
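The `zero_centered_gamma` behavior noted above can be checked with a plain-Python RMSNorm. This is a numerical sketch of the formula (effective scale = 1 + weight), not TE's actual kernel:

```python
import math

def rmsnorm(x, weight, eps=1e-6, zero_centered_gamma=True):
    # With zero_centered_gamma=True the effective scale is (1 + weight),
    # so a weight initialized to zeros acts as an identity scale of 1.
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    scale = [(1.0 + w) if zero_centered_gamma else w for w in weight]
    return [v / rms * s for v, s in zip(x, scale)]

x = [3.0, 4.0]
# rms = sqrt((9 + 16) / 2) ≈ 3.5355, so the output is x normalized by that rms
out = rmsnorm(x, [0.0, 0.0])  # zero-initialized weight → pure normalization
```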